{ "cells": [ { "cell_type": "markdown", "id": "23741476", "metadata": {}, "source": [ "# Creating Agents for ARC Tasks\n", "\n", "In JaxARC, an agent is simply a function that takes an observation and returns an action. The environment uses the `TimeStep` pattern where:\n", "- `state` contains the internal environment state\n", "- `timestep` contains observations, rewards, and termination flags\n", "- Actions are sampled from the environment's action space" ] }, { "cell_type": "markdown", "id": "75fffcd7", "metadata": {}, "source": [ "## Setup: Create an Environment\n", "\n", "First, let's create a JaxARC environment. We'll use a simple configuration optimized for agent development." ] }, { "cell_type": "code", "execution_count": 1, "id": "db082d78", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-11-03 22:12:27.955\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mvalidate_dataset\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1mDataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.955\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mensure_dataset_available\u001b[0m:\u001b[36m81\u001b[0m - \u001b[34m\u001b[1mDataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.958\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_validate_grid_constraints\u001b[0m:\u001b[36m104\u001b[0m - \u001b[1mMiniARC parser configured with optimal 5x5 grid constraints\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.959\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_scan_available_tasks\u001b[0m:\u001b[36m131\u001b[0m - \u001b[1mFound 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 
grids)\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.962\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mvalidate_dataset\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1mDataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.962\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mensure_dataset_available\u001b[0m:\u001b[36m81\u001b[0m - \u001b[34m\u001b[1mDataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.963\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_validate_grid_constraints\u001b[0m:\u001b[36m104\u001b[0m - \u001b[1mMiniARC parser configured with optimal 5x5 grid constraints\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.965\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_scan_available_tasks\u001b[0m:\u001b[36m131\u001b[0m - \u001b[1mFound 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)\u001b[0m\n", "\u001b[32m2025-11-03 22:12:27.966\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_load_task_from_disk\u001b[0m:\u001b[36m171\u001b[0m - \u001b[34m\u001b[1mLoaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk\u001b[0m\n", "\u001b[32m2025-11-03 22:12:28.332\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.base_parser\u001b[0m:\u001b[36m_log_parsing_stats\u001b[0m:\u001b[36m479\u001b[0m - \u001b[34m\u001b[1mTask Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5\u001b[0m\n", "\u001b[32m2025-11-03 22:12:28.332\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mget_global_task_manager\u001b[0m:\u001b[36m236\u001b[0m - 
\u001b[34m\u001b[1mCreated global task ID manager\u001b[0m\n", "\u001b[32m2025-11-03 22:12:28.332\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mregister_task\u001b[0m:\u001b[36m72\u001b[0m - \u001b[34m\u001b[1mRegistered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Environment created for task: Most_Common_color_l6ab0lf3xztbyxsu3p\n", "Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')\n" ] } ], "source": [ "from __future__ import annotations\n", "\n", "import jax\n", "import jax.random as jr\n", "\n", "from jaxarc.configs import JaxArcConfig\n", "from jaxarc.registration import available_task_ids, make\n", "from jaxarc.utils.core import get_config\n", "\n", "# Configure environment with visualization and logging disabled for speed\n", "config_overrides = [\n", " \"dataset=mini_arc\",\n", " \"action=raw\",\n", " \"wandb.enabled=false\",\n", " \"logging.log_operations=false\",\n", " \"logging.log_rewards=false\",\n", " \"visualization.enabled=false\",\n", "]\n", "\n", "# Load configuration\n", "hydra_config = get_config(overrides=config_overrides)\n", "config = JaxArcConfig.from_hydra(hydra_config)\n", "\n", "# Get a task from MiniARC\n", "available_ids = available_task_ids(\"Mini\", config=config, auto_download=False)\n", "task_id = available_ids[0]\n", "\n", "# Create environment\n", "env, env_params = make(f\"Mini-{task_id}\", config=config)\n", "\n", "print(f\"Environment created for task: {task_id}\")\n", "print(f\"Action space: {env.action_space(env_params)}\")" ] }, { "cell_type": "markdown", 
"id": "1f5acdb3", "metadata": {}, "source": [ "## Understanding the Environment Loop\n", "\n", "Before creating an agent, let's understand how to interact with the environment using the TimeStep API." ] }, { "cell_type": "code", "execution_count": 2, "id": "37a1b84a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Observation shape: (5, 5, 1)\n", "Initial reward: 0.0\n", "Episode terminated: False\n", "\n", "Sampled action: {'operation': Array(6, dtype=int32), 'selection': Array([[0, 1, 0, 1, 0],\n", " [0, 1, 0, 1, 0],\n", " [0, 0, 0, 1, 1],\n", " [0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 1]], dtype=int32)}\n", "\n", "After step:\n", "Reward: -0.004999999888241291\n", "Episode done: False\n" ] } ], "source": [ "# Initialize environment\n", "key = jr.PRNGKey(42)\n", "state, timestep = env.reset(key, env_params)\n", "\n", "print(f\"Observation shape: {timestep.observation.shape}\")\n", "print(f\"Initial reward: {timestep.reward}\")\n", "print(f\"Episode terminated: {timestep.last()}\")\n", "\n", "# Get action space for sampling\n", "action_space = env.action_space(env_params)\n", "\n", "# Sample and take a single action\n", "key, action_key = jr.split(key)\n", "action = action_space.sample(action_key)\n", "\n", "print(f\"\\nSampled action: {action}\")\n", "\n", "# Step the environment\n", "state, timestep = env.step(state, action, env_params)\n", "\n", "print(\"\\nAfter step:\")\n", "print(f\"Reward: {timestep.reward}\")\n", "print(f\"Episode done: {timestep.last()}\")" ] }, { "cell_type": "markdown", "id": "a7c198f6", "metadata": {}, "source": [ "## Creating a Random Agent\n", "\n", "The simplest agent samples random actions from the action space. 
This serves as a baseline for comparison with more sophisticated agents.\n", "\n", "### Single Episode" ] }, { "cell_type": "code", "execution_count": 3, "id": "93b8e3cb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode completed!\n", "Total reward: -1.30\n", "Steps taken: 20\n", "Average reward per step: -0.065\n" ] } ], "source": [ "# Run a single episode with a random agent\n", "def run_random_episode(env, env_params, key, max_steps=100):\n", "    \"\"\"Run one episode with random actions.\"\"\"\n", "    # Reset environment\n", "    reset_key, loop_key = jr.split(key)\n", "    state, timestep = env.reset(reset_key, env_params)\n", "\n", "    action_space = env.action_space(env_params)\n", "    episode_reward = 0.0\n", "    step_count = 0\n", "\n", "    # Run until the episode ends or max_steps is reached\n", "    while not timestep.last() and step_count < max_steps:\n", "        # Sample a random action with a fresh key\n", "        loop_key, action_key = jr.split(loop_key)\n", "        action = action_space.sample(action_key)\n", "\n", "        # Step environment\n", "        state, timestep = env.step(state, action, env_params)\n", "\n", "        episode_reward += float(timestep.reward)\n", "        step_count += 1\n", "\n", "    return episode_reward, step_count\n", "\n", "\n", "# Test the agent\n", "key = jr.PRNGKey(123)\n", "reward, steps = run_random_episode(env, env_params, key, max_steps=50)\n", "\n", "print(\"Episode completed!\")\n", "print(f\"Total reward: {reward:.2f}\")\n", "print(f\"Steps taken: {steps}\")\n", "print(f\"Average reward per step: {reward / steps:.3f}\")" ] }, { "cell_type": "markdown", "id": "73e6938d", "metadata": {}, "source": [ "## JAX-Accelerated Agent with Scan\n", "\n", "The Python loop above dispatches every step separately and converts the reward to a Python `float` each time, which forces a device-to-host sync. For high performance, we can instead use `jax.lax.scan` so the whole rollout is compiled into a single XLA computation. This pattern is used in PureJaxRL and similar high-throughput RL frameworks."
] }, { "cell_type": "code", "execution_count": 4, "id": "1b73acdc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Compiling agent (first run)...\n", "Agent compiled and executed!\n", "Total reward: -2.78\n", "Mean reward per step: -0.028\n", "Max reward: 0.63\n", "Final episode terminated: False\n" ] } ], "source": [ "def make_jax_agent(env, env_params, num_steps):\n", "    \"\"\"Create a JIT-compiled agent using scan for efficiency.\"\"\"\n", "    action_space = env.action_space(env_params)\n", "\n", "    def run_agent(key):\n", "        # Reset environment\n", "        reset_key, loop_key = jr.split(key)\n", "        state, timestep = env.reset(reset_key, env_params)\n", "\n", "        def step_fn(carry, _):\n", "            \"\"\"One step of the agent.\"\"\"\n", "            state, timestep, key = carry\n", "\n", "            # Split key: one for a possible reset, one for the action, one carried forward\n", "            reset_key, action_key, next_key = jr.split(key, 3)\n", "\n", "            # Auto-reset: if the previous step ended the episode, start a new one\n", "            def do_reset(_):\n", "                return env.reset(reset_key, env_params)\n", "\n", "            def continue_episode(_):\n", "                return state, timestep\n", "\n", "            state, timestep = jax.lax.cond(\n", "                timestep.last(), do_reset, continue_episode, None\n", "            )\n", "\n", "            # Sample action and step\n", "            action = action_space.sample(action_key)\n", "            new_state, new_timestep = env.step(state, action, env_params)\n", "\n", "            return (new_state, new_timestep, next_key), new_timestep.reward\n", "\n", "        # Run scan over num_steps\n", "        (final_state, final_timestep, _), rewards = jax.lax.scan(\n", "            step_fn, (state, timestep, loop_key), None, length=num_steps\n", "        )\n", "\n", "        return rewards, final_timestep\n", "\n", "    # JIT compile the entire function\n", "    return jax.jit(run_agent)\n", "\n", "\n", "# Create and run JIT-compiled agent\n", "jax_agent = make_jax_agent(env, env_params, num_steps=100)\n", "\n", "# First run includes compilation time\n", "print(\"Compiling agent (first run)...\")\n", "key = jr.PRNGKey(456)\n", "rewards, 
final_timestep = jax_agent(key)\n", "\n", "print(\"Agent compiled and executed!\")\n", "print(f\"Total reward: {float(rewards.sum()):.2f}\")\n", "print(f\"Mean reward per step: {float(rewards.mean()):.3f}\")\n", "print(f\"Max reward: {float(rewards.max()):.2f}\")\n", "print(f\"Final episode terminated: {final_timestep.last()}\")" ] }, { "cell_type": "markdown", "id": "29a51ac2", "metadata": {}, "source": [ "## Vectorized Agent: Multiple Parallel Environments\n", "\n", "JAX's `vmap` turns the single-environment agent into a function that takes a batch of PRNG keys, one per environment, and runs all environments in parallel in a single call. Because the batch executes as one vectorized computation, throughput is typically much higher than stepping environments one at a time from Python." ] }, { "cell_type": "code", "execution_count": 5, "id": "9b9a9b2d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating vectorized agent with 16 parallel environments...\n", "Running 16 environments × 100 steps...\n", "\n", "Results across 16 environments:\n", "Mean total reward: -4.22\n", "Best environment reward: -1.86\n", "Worst environment reward: -9.14\n", "Mean reward per step: -0.042\n", "\n", "Total steps executed: 1,600\n" ] } ], "source": [ "def make_vectorized_agent(env, env_params, num_envs, num_steps):\n", "    \"\"\"Create a vectorized agent that runs multiple environments in parallel.\"\"\"\n", "\n", "    # Create the single-env agent\n", "    single_agent = make_jax_agent(env, env_params, num_steps)\n", "\n", "    # Vectorize it over a leading batch axis of keys. The batch size is the\n", "    # number of keys passed at call time, so num_envs is not used here.\n", "    vectorized_agent = jax.vmap(single_agent)\n", "\n", "    return vectorized_agent\n", "\n", "\n", "# Create vectorized agent\n", "num_envs = 16\n", "num_steps = 100\n", "\n", "print(f\"Creating vectorized agent with {num_envs} parallel environments...\")\n", "vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)\n", "\n", "# Generate keys for each environment\n", "key = jr.PRNGKey(789)\n", "env_keys = jr.split(key, num_envs)\n", "\n", "# Run all environments in parallel\n", "print(f\"Running {num_envs} environments × {num_steps} steps...\")\n", "all_rewards, all_final_timesteps = vec_agent(env_keys)\n", "\n", "# Analyze results (all_rewards has shape [num_envs, num_steps])\n", "print(f\"\\nResults across {num_envs} environments:\")\n", "print(f\"Mean total reward: {float(all_rewards.sum(axis=1).mean()):.2f}\")\n", "print(f\"Best environment reward: {float(all_rewards.sum(axis=1).max()):.2f}\")\n", "print(f\"Worst environment reward: {float(all_rewards.sum(axis=1).min()):.2f}\")\n", "print(f\"Mean reward per step: {float(all_rewards.mean()):.3f}\")\n", "\n", "# Total steps executed\n", "total_steps = num_envs * num_steps\n", "print(f\"\\nTotal steps executed: {total_steps:,}\")" ] }, { "cell_type": "markdown", "id": "2632fc1e", "metadata": {}, "source": [ "## Performance Benchmark\n", "\n", "Let's measure the throughput of the vectorized agent in environment steps per second (SPS). The first call includes JIT compilation, so it is timed separately as a warmup, and the reported figure is the mean over the subsequent runs. Absolute numbers depend on your hardware and JAX backend." ] }, { "cell_type": "code", "execution_count": 6, "id": "fe59bf13", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Benchmarking vectorized agent...\n", "Configuration: 64 envs × 256 steps\n", "Warmup run (includes compilation)...\n", "\n", "Warmup complete: 2.36s (includes JIT compilation)\n", "\n", "Running 3 timed iterations...\n", "  Run 1: 0.085s\n", "  Run 2: 0.085s\n", "  Run 3: 0.087s\n", "\n", "Performance Results:\n", "Mean execution time: 0.086s\n", "Steps per second (SPS): 190,766\n", "Total steps per run: 16,384\n" ] } ], "source": [ "import time\n", "\n", "# Benchmark configuration\n", "num_envs = 64\n", "num_steps = 256\n", "num_runs = 3\n", "\n", "print(\"Benchmarking vectorized agent...\")\n", "print(f\"Configuration: {num_envs} envs × {num_steps} steps\")\n", "print(\"Warmup run (includes compilation)...\\n\")\n", "\n", "# Create fresh agent\n", "vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)\n", "key = jr.PRNGKey(999)\n", "env_keys = jr.split(key, num_envs)\n", "\n", "# Warmup run (includes compilation); perf_counter is a monotonic high-resolution clock\n", "start = time.perf_counter()\n", "rewards, _ = vec_agent(env_keys)\n", "_ = rewards.block_until_ready()  # JAX dispatches asynchronously; wait for the result\n", "warmup_time = time.perf_counter() - start\n", "\n", "print(f\"Warmup complete: {warmup_time:.2f}s (includes JIT compilation)\")\n", "\n", "# Timed runs\n", "print(f\"\\nRunning {num_runs} timed iterations...\")\n", "times = []\n", "\n", "for i in range(num_runs):\n", "    key, subkey = jr.split(key)\n", "    env_keys = jr.split(subkey, num_envs)\n", "\n", "    start = time.perf_counter()\n", "    rewards, _ = vec_agent(env_keys)\n", "    _ = rewards.block_until_ready()\n", "    elapsed = time.perf_counter() - start\n", "    times.append(elapsed)\n", "\n", "    print(f\"  Run {i + 1}: {elapsed:.3f}s\")\n", "\n", "# Calculate statistics\n", "mean_time = sum(times) / len(times)\n", "total_steps = num_envs * num_steps\n", "sps = total_steps / mean_time\n", "\n", "print(\"\\nPerformance Results:\")\n", "print(f\"Mean execution time: {mean_time:.3f}s\")\n", "print(f\"Steps per second (SPS): {sps:,.0f}\")\n", "print(f\"Total steps per run: {total_steps:,}\")" ] }, { "cell_type": "markdown", "id": "0829e4f8", "metadata": {}, "source": [ "## Building Your Own Agent\n", "\n", "To turn this into a learning agent rather than a random one:\n", "\n", "1. **Define a neural network** using Flax, Haiku, or Equinox\n", "2. **Collect trajectories** using the `scan` pattern shown above (vectorized with `vmap`)\n", "3. **Compute a loss** from the collected observations, actions, and rewards (for example, a policy-gradient objective)\n", "4. **Update parameters** using Optax optimizers\n", "5. **Repeat** the collect-and-update loop" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }