Creating Agents for ARC Tasks
In JaxARC, an agent is simply a function that takes an observation and returns an action. The environment uses the TimeStep pattern where:
state contains the internal environment state
timestep contains observations, rewards, and termination flags
Actions are sampled from the environment’s action space
Setup: Create an Environment
First, let’s create a JaxARC environment. We’ll use a simple configuration optimized for agent development.
from __future__ import annotations
import jax
import jax.random as jr
from jaxarc.configs import JaxArcConfig
from jaxarc.registration import available_task_ids, make
from jaxarc.utils.core import get_config
# Configure environment with visualization and logging disabled for speed
config_overrides = [
    "dataset=mini_arc",
    "action=raw",
    "wandb.enabled=false",
    "logging.log_operations=false",
    "logging.log_rewards=false",
    "visualization.enabled=false",
]
# Load configuration
hydra_config = get_config(overrides=config_overrides)
config = JaxArcConfig.from_hydra(hydra_config)
# Get a task from MiniARC
available_ids = available_task_ids("Mini", config=config, auto_download=False)
task_id = available_ids[0]
# Create environment
env, env_params = make(f"Mini-{task_id}", config=config)
print(f"Environment created for task: {task_id}")
print(f"Action space: {env.action_space(env_params)}")
Environment created for task: Most_Common_color_l6ab0lf3xztbyxsu3p
Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')
2025-11-03 22:12:27.955 | DEBUG | jaxarc.utils.dataset_manager:validate_dataset:212 - Dataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-03 22:12:27.955 | DEBUG | jaxarc.utils.dataset_manager:ensure_dataset_available:81 - Dataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-03 22:12:27.958 | INFO | jaxarc.parsers.mini_arc:_validate_grid_constraints:104 - MiniARC parser configured with optimal 5x5 grid constraints
2025-11-03 22:12:27.959 | INFO | jaxarc.parsers.mini_arc:_scan_available_tasks:131 - Found 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)
2025-11-03 22:12:27.962 | DEBUG | jaxarc.utils.dataset_manager:validate_dataset:212 - Dataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-03 22:12:27.962 | DEBUG | jaxarc.utils.dataset_manager:ensure_dataset_available:81 - Dataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-03 22:12:27.963 | INFO | jaxarc.parsers.mini_arc:_validate_grid_constraints:104 - MiniARC parser configured with optimal 5x5 grid constraints
2025-11-03 22:12:27.965 | INFO | jaxarc.parsers.mini_arc:_scan_available_tasks:131 - Found 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)
2025-11-03 22:12:27.966 | DEBUG | jaxarc.parsers.mini_arc:_load_task_from_disk:171 - Loaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk
2025-11-03 22:12:28.332 | DEBUG | jaxarc.parsers.base_parser:_log_parsing_stats:479 - Task Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5
2025-11-03 22:12:28.332 | DEBUG | jaxarc.utils.task_manager:get_global_task_manager:236 - Created global task ID manager
2025-11-03 22:12:28.332 | DEBUG | jaxarc.utils.task_manager:register_task:72 - Registered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0
Understanding the Environment Loop
Before creating an agent, let’s understand how to interact with the environment using the TimeStep API.
# Initialize environment
key = jr.PRNGKey(42)
state, timestep = env.reset(key, env_params)
print(f"Observation shape: {timestep.observation.shape}")
print(f"Initial reward: {timestep.reward}")
print(f"Episode terminated: {timestep.last()}")
# Get action space for sampling
action_space = env.action_space(env_params)
# Sample and take a single action
key, action_key = jr.split(key)
action = action_space.sample(action_key)
print(f"\nSampled action: {action}")
# Step the environment
state, timestep = env.step(state, action, env_params)
print("\nAfter step:")
print(f"Reward: {timestep.reward}")
print(f"Episode done: {timestep.last()}")
Observation shape: (5, 5, 1)
Initial reward: 0.0
Episode terminated: False
Sampled action: {'operation': Array(6, dtype=int32), 'selection': Array([[0, 1, 0, 1, 0],
[0, 1, 0, 1, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 1]], dtype=int32)}
After step:
Reward: -0.004999999888241291
Episode done: False
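Since an agent is just a function from an observation to an action, it can also build the action dictionary by hand instead of calling the action space's sampler. The sketch below mirrors the structure of the sampled action printed above (an integer operation plus a 5x5 binary selection mask). The function name nonzero_cell_agent is hypothetical, the constant 35 is copied from the printed action space, and treating channel 0 of the observation as a color grid with 0 as background is an assumption about the observation layout, not something the output above confirms. Whether env.step accepts a hand-built dictionary of this form is likewise assumed.

```python
import jax.numpy as jnp
import jax.random as jr

NUM_OPERATIONS = 35  # copied from the printed action space; adjust if yours differs


def nonzero_cell_agent(observation, key):
    """Hypothetical agent: random operation, select every non-zero cell."""
    # Draw an operation index uniformly from [0, NUM_OPERATIONS)
    operation = jr.randint(key, (), 0, NUM_OPERATIONS, dtype=jnp.int32)
    # ASSUMPTION: channel 0 holds cell colors and 0 means background
    selection = (observation[..., 0] != 0).astype(jnp.int32)
    return {"operation": operation, "selection": selection}


# Stand-in observation with the (5, 5, 1) shape shown above
obs = jnp.zeros((5, 5, 1), dtype=jnp.int32).at[1, 2, 0].set(3)
action = nonzero_cell_agent(obs, jr.PRNGKey(0))
print(action["operation"], action["selection"].shape)
```

Because the function is pure and takes its PRNG key explicitly, it can be dropped into a jitted or scanned loop in place of the sampler.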
Creating a Random Agent
The simplest agent samples random actions from the action space. This serves as a baseline for comparison with more sophisticated agents.
Single Episode
# Run a single episode with a random agent
def run_random_episode(env, env_params, key, max_steps=100):
    """Run one episode with random actions."""
    # Reset environment
    reset_key, loop_key = jr.split(key)
    state, timestep = env.reset(reset_key, env_params)
    action_space = env.action_space(env_params)
    episode_reward = 0.0
    step_count = 0
    # Run episode
    while not timestep.last() and step_count < max_steps:
        # Sample random action
        loop_key, action_key = jr.split(loop_key)
        action = action_space.sample(action_key)
        # Step environment
        state, timestep = env.step(state, action, env_params)
        episode_reward += float(timestep.reward)
        step_count += 1
    return episode_reward, step_count
# Test the agent
key = jr.PRNGKey(123)
reward, steps = run_random_episode(env, env_params, key, max_steps=50)
print("Episode completed!")
print(f"Total reward: {reward:.2f}")
print(f"Steps taken: {steps}")
print(f"Average reward per step: {reward / steps:.3f}")
Episode completed!
Total reward: -1.30
Steps taken: 20
Average reward per step: -0.065
JAX-Accelerated Agent with Scan
For high performance, we can use jax.lax.scan to run many steps inside a single compiled loop. This pattern is used in PureJaxRL and similar high-throughput RL frameworks.
def make_jax_agent(env, env_params, num_steps):
    """Create a JIT-compiled agent using scan for efficiency."""
    action_space = env.action_space(env_params)

    def run_agent(key):
        # Reset environment
        reset_key, loop_key = jr.split(key)
        state, timestep = env.reset(reset_key, env_params)

        def step_fn(carry, _):
            """One step of the agent."""
            state, timestep, key = carry
            # Split key for a possible reset, action sampling, and the next iteration
            reset_key, action_key, next_key = jr.split(key, 3)

            # Handle episode termination with conditional reset
            def do_reset(_):
                return env.reset(reset_key, env_params)

            def continue_episode(_):
                return state, timestep

            state, timestep = jax.lax.cond(
                timestep.last(), do_reset, continue_episode, None
            )
            # Sample action and step
            action = action_space.sample(action_key)
            new_state, new_timestep = env.step(state, action, env_params)
            return (new_state, new_timestep, next_key), new_timestep.reward

        # Run scan over num_steps
        (final_state, final_timestep, _), rewards = jax.lax.scan(
            step_fn, (state, timestep, loop_key), None, length=num_steps
        )
        return rewards, final_timestep

    # JIT compile the entire function
    return jax.jit(run_agent)
# Create and run JIT-compiled agent
jax_agent = make_jax_agent(env, env_params, num_steps=100)
# First run includes compilation time
print("Compiling agent (first run)...")
key = jr.PRNGKey(456)
rewards, final_timestep = jax_agent(key)
print("Agent compiled and executed!")
print(f"Total reward: {float(rewards.sum()):.2f}")
print(f"Mean reward per step: {float(rewards.mean()):.3f}")
print(f"Max reward: {float(rewards.max()):.2f}")
print(f"Final episode terminated: {final_timestep.last()}")
Compiling agent (first run)...
Agent compiled and executed!
Total reward: -2.78
Mean reward per step: -0.028
Max reward: 0.63
Final episode terminated: False
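Because the scan loop resets the environment whenever an episode ends, the rewards array it returns can span several episodes, so its sum is not a per-episode return. One way to recover per-episode returns is sketched below, assuming the scan also emitted a boolean done flag per step (for example new_timestep.last() as a second output, which is a modification that run_agent as written does not make). The helper name episode_returns and the sample arrays are hypothetical.

```python
import jax
import jax.numpy as jnp


def episode_returns(rewards, dones):
    """Return an array holding the episode return at terminal steps, NaN elsewhere."""

    def step(running, inputs):
        reward, done = inputs
        total = running + reward
        # Restart the running sum after a terminal step
        next_running = jnp.where(done, 0.0, total)
        # Emit the finished episode's return only on terminal steps
        return next_running, jnp.where(done, total, jnp.nan)

    _, per_step = jax.lax.scan(step, jnp.float32(0.0), (rewards, dones))
    return per_step


# Hypothetical reward stream covering two complete episodes
rewards = jnp.array([1.0, 2.0, 3.0, 4.0])
dones = jnp.array([False, True, False, True])
returns = episode_returns(rewards, dones)
mean_return = jnp.nanmean(returns)  # averages over completed episodes only
print(returns, mean_return)
```

Steps belonging to an episode that is still running when the scan ends never produce a value, so partial episodes are excluded from the average.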
Vectorized Agent: Multiple Parallel Environments
JAX’s vmap allows us to run multiple environments in parallel with a single function call. This dramatically increases throughput.
def make_vectorized_agent(env, env_params, num_envs, num_steps):
    """Create a vectorized agent that runs multiple environments in parallel."""
    # Create the single-env agent
    single_agent = make_jax_agent(env, env_params, num_steps)
    # Vectorize it across multiple environments.
    # Note: num_envs is not used here; the batch size comes from the
    # leading axis of the keys passed to the returned function.
    vectorized_agent = jax.vmap(single_agent)
    return vectorized_agent
# Create vectorized agent
num_envs = 16
num_steps = 100
print(f"Creating vectorized agent with {num_envs} parallel environments...")
vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)
# Generate keys for each environment
key = jr.PRNGKey(789)
env_keys = jr.split(key, num_envs)
# Run all environments in parallel
print(f"Running {num_envs} environments × {num_steps} steps...")
all_rewards, all_final_timesteps = vec_agent(env_keys)
# Analyze results
print(f"\nResults across {num_envs} environments:")
print(f"Mean total reward: {float(all_rewards.sum(axis=1).mean()):.2f}")
print(f"Best environment reward: {float(all_rewards.sum(axis=1).max()):.2f}")
print(f"Worst environment reward: {float(all_rewards.sum(axis=1).min()):.2f}")
print(f"Mean reward per step: {float(all_rewards.mean()):.3f}")
# Total steps executed
total_steps = num_envs * num_steps
print(f"\nTotal steps executed: {total_steps:,}")
Creating vectorized agent with 16 parallel environments...
Running 16 environments × 100 steps...
Results across 16 environments:
Mean total reward: -4.22
Best environment reward: -1.86
Worst environment reward: -9.14
Mean reward per step: -0.042
Total steps executed: 1,600
Performance Benchmark
Let’s measure the throughput of our vectorized agent to understand the performance benefits of JAX.
import time
# Benchmark configuration
num_envs = 64
num_steps = 256
num_runs = 3
print("Benchmarking vectorized agent...")
print(f"Configuration: {num_envs} envs × {num_steps} steps")
print("Warmup run (includes compilation)...\n")
# Create fresh agent
vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)
key = jr.PRNGKey(999)
env_keys = jr.split(key, num_envs)
# Warmup run (includes compilation)
start = time.time()
rewards, _ = vec_agent(env_keys)
_ = rewards.block_until_ready() # Wait for computation
warmup_time = time.time() - start
print(f"Warmup complete: {warmup_time:.2f}s (includes JIT compilation)")
# Timed runs
print(f"\nRunning {num_runs} timed iterations...")
times = []
for i in range(num_runs):
    key, subkey = jr.split(key)
    env_keys = jr.split(subkey, num_envs)
    start = time.time()
    rewards, _ = vec_agent(env_keys)
    _ = rewards.block_until_ready()
    elapsed = time.time() - start
    times.append(elapsed)
    print(f" Run {i + 1}: {elapsed:.3f}s")
# Calculate statistics
mean_time = sum(times) / len(times)
total_steps = num_envs * num_steps
sps = total_steps / mean_time
print("\nPerformance Results:")
print(f"Mean execution time: {mean_time:.3f}s")
print(f"Steps per second (SPS): {sps:,.0f}")
print(f"Total steps per run: {total_steps:,}")
Benchmarking vectorized agent...
Configuration: 64 envs × 256 steps
Warmup run (includes compilation)...
Warmup complete: 2.36s (includes JIT compilation)
Running 3 timed iterations...
Run 1: 0.085s
Run 2: 0.085s
Run 3: 0.087s
Performance Results:
Mean execution time: 0.086s
Steps per second (SPS): 190,766
Total steps per run: 16,384
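The warmup-then-measure procedure above can be wrapped in a small reusable helper. The sketch below is one possible shape for it, not part of JaxARC: time_jitted and toy_workload are hypothetical names, and the workload is a stand-in matrix computation rather than an environment rollout. It uses time.perf_counter, which is generally preferred for measuring short intervals, and jax.block_until_ready so that JAX's asynchronous dispatch does not make the timings look shorter than they are.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, repeats=3):
    """Return (first_call_seconds, mean_steady_state_seconds) for fn(*args)."""
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))  # first call traces and compiles
    first = time.perf_counter() - start

    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait until the result is computed
        timings.append(time.perf_counter() - start)
    return first, sum(timings) / len(timings)


@jax.jit
def toy_workload(x):
    # Stand-in computation; replace with a vectorized agent call
    return jnp.tanh(x @ x.T).sum()


first, steady = time_jitted(toy_workload, jnp.ones((256, 256)))
print(f"first call: {first:.4f}s, steady state: {steady:.6f}s")
```

Dividing the number of environment steps per call by the steady-state time gives the steps-per-second figure reported above.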
Building Your Own Agent
To create a learning agent (not just random):
Define a neural network using Flax, Haiku, or Equinox
Collect trajectories using the scan pattern shown above
Compute losses from rewards and observations
Update parameters using Optax optimizers
Repeat the training loop
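The steps above can be sketched end to end on a toy problem. To stay dependency-free, this example swaps in a hand-written linear policy for a Flax, Haiku, or Equinox network and a plain SGD update for an Optax optimizer, and it uses a made-up one-step task (pick the action matching a one-hot observation) rather than a JaxARC environment. All names here (policy_logits, collect, loss_fn, train_step, NUM_ACTIONS) are hypothetical, and the loss is a basic REINFORCE estimator with a mean-reward baseline, chosen only for brevity.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

NUM_ACTIONS = 4  # toy stand-in, not the JaxARC action space


def policy_logits(params, obs):
    """Step 1: a linear 'network' mapping observations to action logits."""
    return obs @ params["w"] + params["b"]


def collect(params, key, num_steps):
    """Step 2: gather (obs, action, reward) tuples with lax.scan."""

    def step(key, _):
        key, obs_key, act_key = jr.split(key, 3)
        target = jr.randint(obs_key, (), 0, NUM_ACTIONS)
        obs = jax.nn.one_hot(target, NUM_ACTIONS)
        action = jr.categorical(act_key, policy_logits(params, obs))
        reward = (action == target).astype(jnp.float32)
        return key, (obs, action, reward)

    _, batch = jax.lax.scan(step, key, None, length=num_steps)
    return batch


def loss_fn(params, batch):
    """Step 3: REINFORCE loss with a mean-reward baseline."""
    obs, actions, rewards = batch
    log_probs = jax.nn.log_softmax(policy_logits(params, obs))
    chosen = jnp.take_along_axis(log_probs, actions[:, None], axis=1)[:, 0]
    advantages = rewards - rewards.mean()
    return -(chosen * advantages).mean()


@jax.jit
def train_step(params, key):
    """Steps 2 to 4: collect data, differentiate, apply plain SGD."""
    batch = collect(params, key, num_steps=256)
    grads = jax.grad(loss_fn)(params, batch)
    # An Optax optimizer would replace this manual update
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)
    return new_params, batch[2].mean()


params = {"w": jnp.zeros((NUM_ACTIONS, NUM_ACTIONS)), "b": jnp.zeros(NUM_ACTIONS)}
key = jr.PRNGKey(0)
history = []
for _ in range(200):  # Step 5: repeat the training loop
    key, step_key = jr.split(key)
    params, mean_reward = train_step(params, step_key)
    history.append(float(mean_reward))
print(f"mean reward: first 10 steps {sum(history[:10]) / 10:.2f}, "
      f"last 10 steps {sum(history[-10:]) / 10:.2f}")
```

For a real ARC agent, the collection function would call env.reset and env.step as in the scan example earlier, and the policy would need to output both an operation and a selection mask.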