Your First Complete Example

This guide walks through a complete, runnable example of using JaxARC to create and run a random agent. By the end, you’ll have a working baseline agent that you can build upon.

Complete Random Agent

Here’s a complete example that creates an environment, runs a random agent for multiple episodes, and tracks performance:

import jax
import jax.numpy as jnp
import jaxarc

def run_random_agent(num_episodes=10, max_steps=100, seed=0):
    """
    Run a random agent on MiniARC for multiple episodes.

    Args:
        num_episodes: Number of episodes to run
        max_steps: Maximum steps per episode
        seed: Random seed for reproducibility

    Returns:
        Dictionary with episode statistics
    """
    # Create environment (downloads dataset if needed)
    env, env_params = jaxarc.make("Mini", auto_download=True)

    # Initialize PRNG key
    key = jax.random.PRNGKey(seed)

    # Get action space
    action_space = env.action_space(env_params)

    # Track statistics
    episode_rewards = []
    episode_lengths = []

    print(f"Running {num_episodes} episodes...")
    print(f"Observation shape: {env.observation_shape()}")
    print(f"Action space: {action_space}")
    print("-" * 60)

    for episode in range(num_episodes):
        # Reset environment
        key, reset_key = jax.random.split(key)
        state, timestep = env.reset(reset_key, env_params=env_params)

        episode_reward = 0.0
        step_count = 0

        # Run episode
        while not timestep.last() and step_count < max_steps:
            # Sample random action
            key, action_key = jax.random.split(key)
            action = action_space.sample(action_key)

            # Take step
            state, timestep = env.step(state, action, env_params=env_params)

            # Track metrics
            episode_reward += float(timestep.reward)
            step_count += 1

        # Store episode statistics
        episode_rewards.append(episode_reward)
        episode_lengths.append(step_count)

        # Print episode summary
        status = "✓ Solved" if timestep.last() else "✗ Max steps"
        print(f"Episode {episode+1:2d}: {status} | "
              f"Steps: {step_count:3d} | "
              f"Reward: {episode_reward:6.2f}")

    # Calculate overall statistics
    stats = {
        "mean_reward": jnp.mean(jnp.array(episode_rewards)),
        "std_reward": jnp.std(jnp.array(episode_rewards)),
        "mean_length": jnp.mean(jnp.array(episode_lengths)),
        "success_rate": sum(1 for r in episode_rewards if r > 0) / num_episodes,
    }

    print("-" * 60)
    print(f"Average reward: {stats['mean_reward']:.2f} ± {stats['std_reward']:.2f}")
    print(f"Average episode length: {stats['mean_length']:.1f}")
    print(f"Success rate: {stats['success_rate']:.1%}")

    return stats

if __name__ == "__main__":
    # Run the agent
    stats = run_random_agent(num_episodes=10, max_steps=100, seed=42)

Understanding the Code

Let’s break down the key components:

1. Environment Setup

env, env_params = jaxarc.make("Mini", auto_download=True)
key = jax.random.PRNGKey(seed)
action_space = env.action_space(env_params)
  • jaxarc.make() creates the environment and parameters

  • auto_download=True downloads the dataset if needed

  • We initialize a PRNG key for reproducibility

  • Get the action space once (it’s fixed for the environment)

2. Episode Loop

for episode in range(num_episodes):
    key, reset_key = jax.random.split(key)
    state, timestep = env.reset(reset_key, env_params=env_params)

    while not timestep.last() and step_count < max_steps:
        # ... episode logic
  • Outer loop iterates over episodes

  • Each episode starts with reset() which returns (state, timestep)

  • Inner loop runs until timestep.last() or max steps

  • We always split the PRNG key before using it

3. Action Sampling

key, action_key = jax.random.split(key)
action = action_space.sample(action_key)
state, timestep = env.step(state, action, env_params=env_params)
  • Split key to get a fresh random key

  • Sample action from the action space

  • Step returns new (state, timestep) tuple (immutable!)

  • Always pass env_params to step()

4. Metrics Tracking

episode_reward += float(timestep.reward)
step_count += 1
  • Track cumulative reward per episode (from timestep.reward)

  • Count steps to measure episode length

  • Convert JAX arrays to Python floats for accumulation

Running the Example

Save the code above to a file (e.g., random_agent.py) and run it:

python random_agent.py

Expected output:

Running 10 episodes...
Observation shape: (5, 5, 1)
Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')
------------------------------------------------------------
Episode  1: ✓ Solved | Steps:  62 | Reward:  -1.63
Episode  2: ✓ Solved | Steps:  57 | Reward:  -1.64
Episode  3: ✗ Max steps | Steps: 100 | Reward:  -0.34
Episode  4: ✓ Solved | Steps:   3 | Reward:  -1.01
Episode  5: ✓ Solved | Steps:  11 | Reward:  -1.37
Episode  6: ✓ Solved | Steps:   5 | Reward:  -0.70
Episode  7: ✓ Solved | Steps:   4 | Reward:  -1.02
Episode  8: ✓ Solved | Steps:  70 | Reward:  -1.71
Episode  9: ✓ Solved | Steps:  41 | Reward:  -1.68
Episode 10: ✗ Max steps | Steps: 100 | Reward:  -0.82
------------------------------------------------------------
Average reward: -1.19 ± 0.46
Average episode length: 45.3
Success rate: 0.0%

Note: A random agent typically doesn’t solve ARC tasks (success rate near 0%), but this provides a baseline for comparison.