Creating Agents for ARC Tasks

In JaxARC, an agent is a function that maps an observation to an action. The environment follows the TimeStep pattern:

  • state holds the internal environment state

  • timestep holds the observation, reward, and termination flag

  • actions come from the environment’s action space; the random agents in this guide sample them directly

Setup: Create an Environment

First, let’s create a JaxARC environment. The configuration below uses the MiniARC dataset and disables logging and visualization so that steps run quickly.

from __future__ import annotations

import jax
import jax.random as jr

from jaxarc.configs import JaxArcConfig
from jaxarc.registration import available_task_ids, make
from jaxarc.utils.core import get_config

# Configure environment with visualization and logging disabled for speed
config_overrides = [
    "dataset=mini_arc",
    "action=raw",
    "wandb.enabled=false",
    "logging.log_operations=false",
    "logging.log_rewards=false",
    "visualization.enabled=false",
]

# Load configuration
hydra_config = get_config(overrides=config_overrides)
config = JaxArcConfig.from_hydra(hydra_config)

# Get a task from MiniARC
available_ids = available_task_ids("Mini", config=config, auto_download=False)
task_id = available_ids[0]

# Create environment
env, env_params = make(f"Mini-{task_id}", config=config)

print(f"Environment created for task: {task_id}")
print(f"Action space: {env.action_space(env_params)}")
Environment created for task: Most_Common_color_l6ab0lf3xztbyxsu3p
Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')

Understanding the Environment Loop

Before creating an agent, let’s understand how to interact with the environment using the TimeStep API.

# Initialize environment
key = jr.PRNGKey(42)
state, timestep = env.reset(key, env_params)

print(f"Observation shape: {timestep.observation.shape}")
print(f"Initial reward: {timestep.reward}")
print(f"Episode terminated: {timestep.last()}")

# Get action space for sampling
action_space = env.action_space(env_params)

# Sample and take a single action
key, action_key = jr.split(key)
action = action_space.sample(action_key)

print(f"\nSampled action: {action}")

# Step the environment
state, timestep = env.step(state, action, env_params)

print("\nAfter step:")
print(f"Reward: {timestep.reward}")
print(f"Episode done: {timestep.last()}")
Observation shape: (5, 5, 1)
Initial reward: 0.0
Episode terminated: False

Sampled action: {'operation': Array(6, dtype=int32), 'selection': Array([[0, 1, 0, 1, 0],
       [0, 1, 0, 1, 0],
       [0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1]], dtype=int32)}

After step:
Reward: -0.004999999888241291
Episode done: False
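
The sampled action above is a dictionary with an integer operation and a 5×5 binary selection mask, so a hand-written agent only needs to return the same structure. The following is a minimal sketch under that assumption; fill_row_agent is a hypothetical name, the constants are read off the action-space printout, and the observation is ignored for brevity.

```python
import jax.numpy as jnp
import jax.random as jr

NUM_OPERATIONS = 35   # from the action-space printout above
GRID_SHAPE = (5, 5)   # MiniARC grid size


def fill_row_agent(key, observation):
    """Hypothetical agent: a random operation applied to one random full row."""
    del observation  # a real agent would condition on this
    op_key, row_key = jr.split(key)
    operation = jr.randint(op_key, (), 0, NUM_OPERATIONS, dtype=jnp.int32)
    row = jr.randint(row_key, (), 0, GRID_SHAPE[0])
    # Ones on the chosen row, zeros elsewhere
    row_ids = jnp.arange(GRID_SHAPE[0])[:, None]
    selection = jnp.broadcast_to(row_ids == row, GRID_SHAPE).astype(jnp.int32)
    return {"operation": operation, "selection": selection}


action = fill_row_agent(jr.PRNGKey(0), jnp.zeros((5, 5, 1)))
print(action)
```

Passing this dictionary to env.step(state, action, env_params) in place of the sampled action should work if the environment accepts any action with this structure, which is an assumption here rather than something the output above confirms.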

Creating a Random Agent

The simplest agent samples random actions from the action space. This serves as a baseline for comparison with more sophisticated agents.
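
To get a feel for how weak this baseline is, the number of raw actions can be counted from the action-space printout earlier: 35 operations and a 5×5 binary selection mask. The arithmetic below uses only those two numbers; whether every combination behaves as a distinct action in the environment is an assumption.

```python
# Count the raw actions implied by the action space printed earlier:
# 35 discrete operations and a 5x5 binary selection mask.
num_operations = 35
grid_cells = 5 * 5

num_masks = 2**grid_cells                  # each cell is selected or not
num_actions = num_operations * num_masks   # operation and mask are independent

print(f"Selection masks: {num_masks:,}")    # 33,554,432
print(f"Raw actions:     {num_actions:,}")  # 1,174,405,120
```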

Single Episode

# Run a single episode with a random agent
def run_random_episode(env, env_params, key, max_steps=100):
    """Run one episode with random actions."""
    # Reset environment
    reset_key, loop_key = jr.split(key)
    state, timestep = env.reset(reset_key, env_params)

    action_space = env.action_space(env_params)
    episode_reward = 0.0
    step_count = 0

    # Run episode
    while not timestep.last() and step_count < max_steps:
        # Sample random action
        loop_key, action_key = jr.split(loop_key)
        action = action_space.sample(action_key)

        # Step environment
        state, timestep = env.step(state, action, env_params)

        episode_reward += float(timestep.reward)
        step_count += 1

    return episode_reward, step_count


# Test the agent
key = jr.PRNGKey(123)
reward, steps = run_random_episode(env, env_params, key, max_steps=50)

print("Episode completed!")
print(f"Total reward: {reward:.2f}")
print(f"Steps taken: {steps}")
print(f"Average reward per step: {reward / max(steps, 1):.3f}")  # guard against a zero-step episode
Episode completed!
Total reward: -1.30
Steps taken: 20
Average reward per step: -0.065
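
The episode loop above calls jr.split on every iteration. JAX random functions are deterministic in their key, so reusing a key repeats the same draw. A short illustration, where the bound 35 is taken from the operation space and everything else is illustrative:

```python
import jax.random as jr

key = jr.PRNGKey(0)

# Reusing one key gives the same "random" draw every time
same_a = jr.randint(key, (), 0, 35)
same_b = jr.randint(key, (), 0, 35)

# Splitting yields a fresh subkey per step, as in the loop above
draws = []
for _ in range(5):
    key, subkey = jr.split(key)
    draws.append(int(jr.randint(subkey, (), 0, 35)))

print(int(same_a), int(same_b))  # two identical values
print(draws)                     # five independently drawn values
```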

JAX-Accelerated Agent with Scan

For high performance, we can use jax.lax.scan to run many steps inside a single compiled loop. This pattern is used in PureJaxRL and similar high-throughput RL frameworks.

def make_jax_agent(env, env_params, num_steps):
    """Create a JIT-compiled agent using scan for efficiency."""
    action_space = env.action_space(env_params)

    def run_agent(key):
        # Reset environment
        reset_key, loop_key = jr.split(key)
        state, timestep = env.reset(reset_key, env_params)

        def step_fn(carry, _):
            """One step of the agent."""
            state, timestep, key = carry

            # Split key three ways: possible reset, action sampling, next iteration
            reset_key, action_key, next_key = jr.split(key, 3)

            # Handle episode termination with conditional reset
            def do_reset(_):
                return env.reset(reset_key, env_params)

            def continue_episode(_):
                return state, timestep

            state, timestep = jax.lax.cond(
                timestep.last(), do_reset, continue_episode, None
            )

            # Sample action and step
            action = action_space.sample(action_key)
            new_state, new_timestep = env.step(state, action, env_params)

            return (new_state, new_timestep, next_key), new_timestep.reward

        # Run scan over num_steps
        (final_state, final_timestep, _), rewards = jax.lax.scan(
            step_fn, (state, timestep, loop_key), None, length=num_steps
        )

        return rewards, final_timestep

    # JIT compile the entire function
    return jax.jit(run_agent)


# Create and run JIT-compiled agent
jax_agent = make_jax_agent(env, env_params, num_steps=100)

# First run includes compilation time
print("Compiling agent (first run)...")
key = jr.PRNGKey(456)
rewards, final_timestep = jax_agent(key)

print("Agent compiled and executed!")
print(f"Total reward: {float(rewards.sum()):.2f}")
print(f"Mean reward per step: {float(rewards.mean()):.3f}")
print(f"Max reward: {float(rewards.max()):.2f}")
print(f"Final episode terminated: {final_timestep.last()}")
Compiling agent (first run)...
Agent compiled and executed!
Total reward: -2.78
Mean reward per step: -0.028
Max reward: 0.63
Final episode terminated: False
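
The reset-inside-scan pattern is easier to follow on a toy problem. The sketch below replaces JaxARC with a hypothetical counter environment (toy_reset and toy_step are made-up stand-ins) whose episodes end after five steps, and applies the same idea: check the done flag at the start of each step, reset if needed, then step.

```python
import jax
import jax.numpy as jnp

EPISODE_LENGTH = 5  # toy episodes end after 5 steps


def toy_reset():
    return jnp.int32(0)  # state is just a step counter


def toy_step(t):
    t = t + 1
    done = t >= EPISODE_LENGTH
    return t, done


def rollout(num_steps):
    def step_fn(carry, _):
        t, done = carry
        # Reset first if the previous step ended the episode
        t = jax.lax.cond(done, lambda: toy_reset(), lambda: t)
        t, done = toy_step(t)
        return (t, done), (t, done)

    init = (toy_reset(), jnp.bool_(False))
    _, (ts, dones) = jax.lax.scan(step_fn, init, None, length=num_steps)
    return ts, dones


ts, dones = jax.jit(rollout, static_argnums=0)(12)
print(ts)     # counter values: 1..5, 1..5, then 1, 2
print(dones)  # True at the 5th and 10th steps
```

Because the carry has a fixed structure, the scan never stops early; finished episodes simply roll over into new ones, which is why the rewards array above can span several episodes.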

Vectorized Agent: Multiple Parallel Environments

JAX’s vmap lets us run many environments in parallel with a single function call. Each environment gets its own PRNG key, and the whole batch advances inside one compiled computation, which raises throughput.

def make_vectorized_agent(env, env_params, num_envs, num_steps):
    """Create a vectorized agent that runs multiple environments in parallel.

    num_envs is unused here: the batch size is set by the number of keys passed in.
    """

    # Create the single-env agent
    single_agent = make_jax_agent(env, env_params, num_steps)

    # Vectorize it across multiple environments
    vectorized_agent = jax.vmap(single_agent)

    return vectorized_agent


# Create vectorized agent
num_envs = 16
num_steps = 100

print(f"Creating vectorized agent with {num_envs} parallel environments...")
vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)

# Generate keys for each environment
key = jr.PRNGKey(789)
env_keys = jr.split(key, num_envs)

# Run all environments in parallel
print(f"Running {num_envs} environments × {num_steps} steps...")
all_rewards, all_final_timesteps = vec_agent(env_keys)

# Analyze results
print(f"\nResults across {num_envs} environments:")
print(f"Mean total reward: {float(all_rewards.sum(axis=1).mean()):.2f}")
print(f"Best environment reward: {float(all_rewards.sum(axis=1).max()):.2f}")
print(f"Worst environment reward: {float(all_rewards.sum(axis=1).min()):.2f}")
print(f"Mean reward per step: {float(all_rewards.mean()):.3f}")

# Total steps executed
total_steps = num_envs * num_steps
print(f"\nTotal steps executed: {total_steps:,}")
Creating vectorized agent with 16 parallel environments...
Running 16 environments × 100 steps...

Results across 16 environments:
Mean total reward: -4.22
Best environment reward: -1.86
Worst environment reward: -9.14
Mean reward per step: -0.042

Total steps executed: 1,600
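
The analysis above sums over axis=1 because vmap adds a leading environment axis to every output. A small shape check makes this concrete; fake_rollout is a hypothetical stand-in for the single-environment agent that returns one made-up reward per step.

```python
import jax
import jax.random as jr


def fake_rollout(key, num_steps=100):
    """Stand-in for a single-environment rollout: one reward per step."""
    return jr.uniform(key, (num_steps,), minval=-0.1, maxval=0.0)


keys = jr.split(jr.PRNGKey(0), 16)        # one key per environment
rewards = jax.vmap(fake_rollout)(keys)    # shape (16, 100): (env, step)

per_env_return = rewards.sum(axis=1)      # shape (16,): one total per env
print(rewards.shape, per_env_return.shape)
```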

Performance Benchmark

Let’s measure the throughput of our vectorized agent to understand the performance benefits of JAX.

import time

# Benchmark configuration
num_envs = 64
num_steps = 256
num_runs = 3

print("Benchmarking vectorized agent...")
print(f"Configuration: {num_envs} envs × {num_steps} steps")
print("Warmup run (includes compilation)...\n")

# Create fresh agent
vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)
key = jr.PRNGKey(999)
env_keys = jr.split(key, num_envs)

# Warmup run (includes compilation)
start = time.perf_counter()  # monotonic, high-resolution timer
rewards, _ = vec_agent(env_keys)
_ = rewards.block_until_ready()  # Wait for computation
warmup_time = time.perf_counter() - start

print(f"Warmup complete: {warmup_time:.2f}s (includes JIT compilation)")

# Timed runs
print(f"\nRunning {num_runs} timed iterations...")
times = []

for i in range(num_runs):
    key, subkey = jr.split(key)
    env_keys = jr.split(subkey, num_envs)

    start = time.perf_counter()
    rewards, _ = vec_agent(env_keys)
    _ = rewards.block_until_ready()  # Wait for async dispatch to finish
    elapsed = time.perf_counter() - start
    times.append(elapsed)

    print(f"  Run {i + 1}: {elapsed:.3f}s")

# Calculate statistics
mean_time = sum(times) / len(times)
total_steps = num_envs * num_steps
sps = total_steps / mean_time

print("\nPerformance Results:")
print(f"Mean execution time: {mean_time:.3f}s")
print(f"Steps per second (SPS): {sps:,.0f}")
print(f"Total steps per run: {total_steps:,}")
Benchmarking vectorized agent...
Configuration: 64 envs × 256 steps
Warmup run (includes compilation)...

Warmup complete: 2.36s (includes JIT compilation)

Running 3 timed iterations...
  Run 1: 0.085s
  Run 2: 0.085s
  Run 3: 0.087s

Performance Results:
Mean execution time: 0.086s
Steps per second (SPS): 190,766
Total steps per run: 16,384
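
The same warmup-then-measure procedure can be wrapped in a small helper for reuse. This is a sketch, not part of JaxARC: time_jitted and work are hypothetical names, and jax.block_until_ready is used so that outputs with several arrays are all waited on.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, num_runs=3):
    """Return (warmup_seconds, mean_seconds) for a JAX function."""
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))      # first call includes compilation
    warmup = time.perf_counter() - start

    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
        times.append(time.perf_counter() - start)
    return warmup, sum(times) / len(times)


@jax.jit
def work(x):
    return jnp.tanh(x @ x).sum()


warmup, mean_time = time_jitted(work, jnp.ones((256, 256)))
print(f"warmup {warmup:.4f}s, steady-state {mean_time:.6f}s")
```

Without the blocking call, the timer would mostly measure how long JAX takes to enqueue the work rather than how long the work takes to run.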

Building Your Own Agent

To build an agent that learns rather than acts randomly:

  1. Define a neural network using Flax, Haiku, or Equinox

  2. Collect trajectories using the scan pattern shown above

  3. Compute losses from rewards and observations

  4. Update parameters using Optax optimizers

  5. Repeat the training loop
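
The five steps can be sketched end to end on a toy problem. To stay dependency-free, this example substitutes plain JAX for Flax and Optax (a vector of logits as the "network", hand-written SGD as the optimizer), uses a REINFORCE-style loss as one possible choice, and replaces JaxARC with a hypothetical reward that pays 1 when the agent picks operation 7. None of these choices come from JaxARC itself.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

NUM_OPERATIONS = 35
TARGET_OP = 7        # hypothetical "correct" operation for the toy reward
LEARNING_RATE = 1.0


def toy_reward(op):
    return jnp.where(op == TARGET_OP, 1.0, 0.0)


def loss_fn(logits, key, batch_size=256):
    # 2. Collect a batch of sampled actions and their rewards
    ops = jr.categorical(key, logits, shape=(batch_size,))
    rewards = toy_reward(ops)
    # 3. REINFORCE loss: -mean((r - baseline) * log pi(a))
    log_probs = jax.nn.log_softmax(logits)[ops]
    advantage = jax.lax.stop_gradient(rewards - rewards.mean())
    return -(advantage * log_probs).mean()


@jax.jit
def train_step(logits, key):
    grads = jax.grad(loss_fn)(logits, key)
    return logits - LEARNING_RATE * grads  # 4. plain SGD in place of Optax


logits = jnp.zeros(NUM_OPERATIONS)  # 1. the "network": one logit per operation
key = jr.PRNGKey(0)
for _ in range(300):                # 5. repeat
    key, subkey = jr.split(key)
    logits = train_step(logits, subkey)

p_target = float(jax.nn.softmax(logits)[TARGET_OP])
print(f"P(target op) = {p_target:.3f}")
```

In a real agent the logits would come from a network applied to timestep.observation, and the batch of actions and rewards would come from the scan and vmap rollouts shown earlier.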