Your First Complete Example¶
This guide walks through a complete, runnable example of using JaxARC to create and run a random agent. By the end, you’ll have a working baseline agent that you can build upon.
Complete Random Agent¶
Here’s a complete example that creates an environment, runs a random agent for multiple episodes, and tracks performance:
import jax
import jax.numpy as jnp
import jaxarc
def run_random_agent(num_episodes=10, max_steps=100, seed=0):
"""
Run a random agent on MiniARC for multiple episodes.
Args:
num_episodes: Number of episodes to run
max_steps: Maximum steps per episode
seed: Random seed for reproducibility
Returns:
Dictionary with episode statistics
"""
# Create environment (downloads dataset if needed)
env, env_params = jaxarc.make("Mini", auto_download=True)
# Initialize PRNG key
key = jax.random.PRNGKey(seed)
# Get action space
action_space = env.action_space(env_params)
# Track statistics
episode_rewards = []
episode_lengths = []
print(f"Running {num_episodes} episodes...")
print(f"Observation shape: {env.observation_shape()}")
print(f"Action space: {action_space}")
print("-" * 60)
for episode in range(num_episodes):
# Reset environment
key, reset_key = jax.random.split(key)
state, timestep = env.reset(reset_key, env_params=env_params)
episode_reward = 0.0
step_count = 0
# Run episode
while not timestep.last() and step_count < max_steps:
# Sample random action
key, action_key = jax.random.split(key)
action = action_space.sample(action_key)
# Take step
state, timestep = env.step(state, action, env_params=env_params)
# Track metrics
episode_reward += float(timestep.reward)
step_count += 1
# Store episode statistics
episode_rewards.append(episode_reward)
episode_lengths.append(step_count)
# Print episode summary
status = "✓ Solved" if timestep.last() else "✗ Max steps"
print(f"Episode {episode+1:2d}: {status} | "
f"Steps: {step_count:3d} | "
f"Reward: {episode_reward:6.2f}")
# Calculate overall statistics
stats = {
"mean_reward": jnp.mean(jnp.array(episode_rewards)),
"std_reward": jnp.std(jnp.array(episode_rewards)),
"mean_length": jnp.mean(jnp.array(episode_lengths)),
"success_rate": sum(1 for r in episode_rewards if r > 0) / num_episodes,
}
print("-" * 60)
print(f"Average reward: {stats['mean_reward']:.2f} ± {stats['std_reward']:.2f}")
print(f"Average episode length: {stats['mean_length']:.1f}")
print(f"Success rate: {stats['success_rate']:.1%}")
return stats
if __name__ == "__main__":
# Run the agent
stats = run_random_agent(num_episodes=10, max_steps=100, seed=42)
Understanding the Code¶
Let’s break down the key components:
1. Environment Setup¶
env, env_params = jaxarc.make("Mini", auto_download=True)
key = jax.random.PRNGKey(seed)
action_space = env.action_space(env_params)
jaxarc.make()creates the environment and parametersauto_download=Truedownloads the dataset if neededWe initialize a PRNG key for reproducibility
Get the action space once (it’s fixed for the environment)
2. Episode Loop¶
for episode in range(num_episodes):
key, reset_key = jax.random.split(key)
state, timestep = env.reset(reset_key, env_params=env_params)
while not timestep.last() and step_count < max_steps:
# ... episode logic
Outer loop iterates over episodes
Each episode starts with
reset()which returns(state, timestep)Inner loop runs until
timestep.last()or max stepsWe always split the PRNG key before using it
3. Action Sampling¶
key, action_key = jax.random.split(key)
action = action_space.sample(action_key)
state, timestep = env.step(state, action, env_params=env_params)
Split key to get a fresh random key
Sample action from the action space
Step returns new
(state, timestep)tuple (immutable!)Always pass
env_paramstostep()
4. Metrics Tracking¶
episode_reward += float(timestep.reward)
step_count += 1
Track cumulative reward per episode (from
timestep.reward)Count steps to measure episode length
Convert JAX arrays to Python floats for accumulation
Running the Example¶
Save the code above to a file (e.g., random_agent.py) and run it:
python random_agent.py
Expected output:
Running 10 episodes...
Observation shape: (5, 5, 1)
Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')
------------------------------------------------------------
Episode 1: ✓ Solved | Steps: 62 | Reward: -1.63
Episode 2: ✓ Solved | Steps: 57 | Reward: -1.64
Episode 3: ✗ Max steps | Steps: 100 | Reward: -0.34
Episode 4: ✓ Solved | Steps: 3 | Reward: -1.01
Episode 5: ✓ Solved | Steps: 11 | Reward: -1.37
Episode 6: ✓ Solved | Steps: 5 | Reward: -0.70
Episode 7: ✓ Solved | Steps: 4 | Reward: -1.02
Episode 8: ✓ Solved | Steps: 70 | Reward: -1.71
Episode 9: ✓ Solved | Steps: 41 | Reward: -1.68
Episode 10: ✗ Max steps | Steps: 100 | Reward: -0.82
------------------------------------------------------------
Average reward: -1.19 ± 0.46
Average episode length: 45.3
Success rate: 0.0%
Note: A random agent typically doesn’t solve ARC tasks (success rate near 0%), but this provides a baseline for comparison.