{ "cells": [ { "cell_type": "markdown", "id": "3f5e102a", "metadata": {}, "source": [ "# Using Wrappers\n", "\n", "Wrappers transform environments without modifying their core logic. JaxARC provides wrappers for:\n", "\n", "- **Action transformation** - Convert between different action formats\n", "- **Observation augmentation** - Add channels to observations\n", "- **Action space flattening** - Simplify complex action spaces\n", "\n", "Wrappers follow the delegation pattern:\n", "1. **Core environment** handles only `Action` objects (mask-based selections)\n", "2. **Wrappers** convert user-friendly formats to/from masks\n", "3. **Composable** - stack multiple wrappers easily" ] }, { "cell_type": "markdown", "id": "4f06eb91", "metadata": {}, "source": [ "## Setup: Base Environment\n", "\n", "Let's start with a base environment that uses mask-based actions." ] }, { "cell_type": "code", "execution_count": 1, "id": "3804be97", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-11-18 22:47:09.240\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mvalidate_dataset\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1mDataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.240\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mensure_dataset_available\u001b[0m:\u001b[36m81\u001b[0m - \u001b[34m\u001b[1mDataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_validate_grid_constraints\u001b[0m:\u001b[36m104\u001b[0m - \u001b[1mMiniARC parser configured with optimal 5x5 grid constraints\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.245\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_scan_available_tasks\u001b[0m:\u001b[36m131\u001b[0m - \u001b[1mFound 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.246\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_load_task_from_disk\u001b[0m:\u001b[36m171\u001b[0m - \u001b[34m\u001b[1mLoaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.658\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.base_parser\u001b[0m:\u001b[36m_log_parsing_stats\u001b[0m:\u001b[36m479\u001b[0m - \u001b[34m\u001b[1mTask Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 
5x5\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.658\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mget_global_task_manager\u001b[0m:\u001b[36m236\u001b[0m - \u001b[34m\u001b[1mCreated global task ID manager\u001b[0m\n", "\u001b[32m2025-11-18 22:47:09.659\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mregister_task\u001b[0m:\u001b[36m72\u001b[0m - \u001b[34m\u001b[1mRegistered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Base action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')\n", "Action keys: ['operation', 'selection']\n" ] } ], "source": [ "from __future__ import annotations\n", "\n", "import jax.random as 
jr\n", "\n", "from jaxarc.configs import JaxArcConfig\n", "from jaxarc.registration import make\n", "from jaxarc.utils.core import get_config\n", "\n", "# Setup environment with minimal logging\n", "config_overrides = [\n", " \"dataset=mini_arc\",\n", " \"action=raw\",\n", " \"wandb.enabled=false\",\n", " \"logging.log_operations=false\",\n", " \"logging.log_rewards=false\",\n", " \"visualization.enabled=false\",\n", "]\n", "\n", "hydra_config = get_config(overrides=config_overrides)\n", "config = JaxArcConfig.from_hydra(hydra_config)\n", "\n", "# Create base environment\n", "env, env_params = make(\"Mini-Most_Common_color_l6ab0lf3xztbyxsu3p\", config=config)\n", "\n", "# Check the action space\n", "action_space = env.action_space(env_params)\n", "print(f\"Base action space: {action_space}\")\n", "print(f\"Action keys: {list(action_space.spaces.keys())}\")" ] }, { "cell_type": "markdown", "id": "54246678", "metadata": {}, "source": [ "## Action Wrappers\n", "\n", "Action wrappers convert user-friendly action formats into the mask-based `Action` objects that the core environment expects.\n", "\n", "### PointActionWrapper\n", "\n", "Converts point-based actions `{\"operation\": op, \"row\": r, \"col\": c}` to mask selections." 
] }, { "cell_type": "code", "execution_count": 2, "id": "073965d9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Point action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')\n", "Action keys: ['operation', 'row', 'col']\n", "\n", "Initial observation shape: (5, 5, 1)\n", "Point action executed: {'operation': 2, 'row': 2, 'col': 3}\n", "Reward: -0.005\n" ] } ], "source": [ "from jaxarc.wrappers import PointActionWrapper\n", "\n", "# Wrap environment\n", "point_env = PointActionWrapper(env)\n", "\n", "# Check new action space\n", "point_action_space = point_env.action_space(env_params)\n", "print(f\"Point action space: {point_action_space}\")\n", "print(f\"Action keys: {list(point_action_space.spaces.keys())}\")\n", "\n", "# Reset and take a point action\n", "key = jr.PRNGKey(42)\n", "state, timestep = point_env.reset(key, env_params)\n", "\n", "print(f\"\\nInitial observation shape: {timestep.observation.shape}\")\n", "\n", "# Take a point action\n", "action = {\"operation\": 2, \"row\": 2, \"col\": 3}\n", "state, timestep = point_env.step(state, action, env_params)\n", "\n", "print(f\"Point action executed: {action}\")\n", "print(f\"Reward: {float(timestep.reward):.3f}\")" ] }, { "cell_type": "markdown", "id": "9cbed684", "metadata": {}, "source": [ "### BboxActionWrapper\n", "\n", "For operations that require a rectangular region (selection, copy, cut), use `BboxActionWrapper`:" ] }, { "cell_type": "code", "execution_count": 3, "id": "40733570", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Bbox action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, 
name='operation'), r1=DiscreteSpace(num_values=5, dtype=int32, name=''), c1=DiscreteSpace(num_values=5, dtype=int32, name=''), r2=DiscreteSpace(num_values=5, dtype=int32, name=''), c2=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='bbox_action')\n", "Action keys: ['operation', 'r1', 'c1', 'r2', 'c2']\n", "\n", "Bbox action executed: {'operation': 0, 'r1': 1, 'c1': 1, 'r2': 2, 'c2': 3}\n", "Reward: -0.005\n" ] } ], "source": [ "from jaxarc.wrappers import BboxActionWrapper\n", "\n", "# Wrap environment\n", "bbox_env = BboxActionWrapper(env)\n", "\n", "# Check action space\n", "bbox_action_space = bbox_env.action_space(env_params)\n", "print(f\"Bbox action space: {bbox_action_space}\")\n", "print(f\"Action keys: {list(bbox_action_space.spaces.keys())}\")\n", "\n", "# Reset and take a bbox action\n", "key = jr.PRNGKey(43)\n", "state, timestep = bbox_env.reset(key, env_params)\n", "\n", "# Select a 2x3 region\n", "action = {\"operation\": 0, \"r1\": 1, \"c1\": 1, \"r2\": 2, \"c2\": 3}\n", "state, timestep = bbox_env.step(state, action, env_params)\n", "\n", "print(f\"\\nBbox action executed: {action}\")\n", "print(f\"Reward: {float(timestep.reward):.3f}\")" ] }, { "cell_type": "markdown", "id": "4db7c0f3", "metadata": {}, "source": [ "### FlattenActionWrapper\n", "\n", "For RL algorithms that work with single discrete action spaces, `FlattenActionWrapper` flattens the composite action space:" ] }, { "cell_type": "code", "execution_count": 4, "id": "6812630d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Flattened action space: DiscreteSpace(num_values=875, dtype=int32, name='')\n", "\n", "Flattened action: 752\n", "Reward: -0.005\n" ] } ], "source": [ "from jaxarc.wrappers import FlattenActionWrapper\n", "\n", "# Wrap environment\n", "# Using PointActionWrapper here to reduce the action space size for demonstration\n", "flat_env = FlattenActionWrapper(point_env)\n", "\n", "# Check action space\n", "flat_action_space = 
flat_env.action_space(env_params)\n", "print(f\"Flattened action space: {flat_action_space}\")\n", "\n", "# Reset and take a flattened action\n", "key = jr.PRNGKey(44)\n", "state, timestep = flat_env.reset(key, env_params)\n", "\n", "# Sample a random action\n", "action = flat_action_space.sample(key)\n", "state, timestep = flat_env.step(state, action, env_params)\n", "\n", "print(f\"\\nFlattened action: {action}\")\n", "print(f\"Reward: {float(timestep.reward):.3f}\")" ] }, { "cell_type": "markdown", "id": "2f9263d5", "metadata": {}, "source": [ "## Observation Wrappers\n", "\n", "Observation wrappers add channels to the observation tensor, providing the agent with additional context.\n", "\n", "### Basic Observation Wrappers\n", "\n", "These wrappers add single-channel context:" ] }, { "cell_type": "code", "execution_count": 5, "id": "0fc3e6cc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Base observation shape: (5, 5, 1)\n", "+ InputGridObservationWrapper: (5, 5, 2)\n", "+ AnswerObservationWrapper: (5, 5, 3)\n", "+ ClipboardObservationWrapper: (5, 5, 4)\n", "\n", "Total channels so far: 4\n" ] } ], "source": [ "from jaxarc.wrappers import (\n", " AnswerObservationWrapper,\n", " ClipboardObservationWrapper,\n", " InputGridObservationWrapper,\n", ")\n", "\n", "# Start fresh\n", "key = jr.PRNGKey(45)\n", "state, timestep = env.reset(key, env_params)\n", "print(f\"Base observation shape: {timestep.observation.shape}\")\n", "\n", "# Add input grid channel\n", "env_with_input = InputGridObservationWrapper(env)\n", "state, timestep = env_with_input.reset(key, env_params)\n", "print(f\"+ InputGridObservationWrapper: {timestep.observation.shape}\")\n", "\n", "# Add answer grid channel\n", "env_with_answer = AnswerObservationWrapper(env_with_input)\n", "state, timestep = env_with_answer.reset(key, 
env_params)\n", "print(f\"+ AnswerObservationWrapper: {timestep.observation.shape}\")\n", "\n", "# Add clipboard channel\n", "env_with_clipboard = ClipboardObservationWrapper(env_with_answer)\n", "state, timestep = env_with_clipboard.reset(key, env_params)\n", "print(f\"+ ClipboardObservationWrapper: {timestep.observation.shape}\")\n", "\n", "print(f\"\\nTotal channels so far: {timestep.observation.shape[-1]}\")" ] }, { "cell_type": "markdown", "id": "94c606c9", "metadata": {}, "source": [ "### ContextualObservationWrapper\n", "\n", "The `ContextualObservationWrapper` adds **demonstration pairs** from the task to the observation. This gives the agent access to other input/output examples that illustrate the task's transformation pattern.\n", "\n", "Key features:\n", "- Adds `2 * num_context_pairs` channels (input + output for each pair)\n", "- During **training**: excludes the current pair being solved\n", "- During **testing**: includes all demonstration pairs (since we're solving a test pair)\n", "- Pads with zeros if fewer demonstration pairs are available than requested" ] }, { "cell_type": "code", "execution_count": 6, "id": "04c0e435", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "With ContextualObservationWrapper (3 pairs):\n", " Observation shape: (5, 5, 10)\n", " Added channels: 6 (3 pairs × 2 channels per pair)\n", "\n", "Total channels: 10\n" ] } ], "source": [ "from jaxarc.wrappers import ContextualObservationWrapper\n", "\n", "# Add 3 demonstration pairs as context\n", "env_with_context = ContextualObservationWrapper(env_with_clipboard, num_context_pairs=3)\n", "\n", "key = jr.PRNGKey(45)\n", "state, timestep = env_with_context.reset(key, env_params)\n", "\n", "print(\"With ContextualObservationWrapper (3 pairs):\")\n", "print(f\" Observation shape: {timestep.observation.shape}\")\n", "print(f\" Added channels: {3 * 2} (3 pairs × 2 channels per pair)\")\n", "\n", "print(f\"\\nTotal channels: 
{timestep.observation.shape[-1]}\")" ] }, { "cell_type": "markdown", "id": "5d535660", "metadata": {}, "source": [ "## Combining Action and Observation Wrappers\n", "\n", "You can chain both types of wrappers together:" ] }, { "cell_type": "code", "execution_count": 7, "id": "0860b5a7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wrapped environment:\n", " Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')\n", " Observation shape: (5, 5, 3)\n", "\n", "Action executed successfully with enhanced observations\n" ] } ], "source": [ "# Create a fully wrapped environment\n", "wrapped_env = PointActionWrapper(env)\n", "wrapped_env = InputGridObservationWrapper(wrapped_env)\n", "wrapped_env = AnswerObservationWrapper(wrapped_env)\n", "\n", "# Reset and inspect\n", "key = jr.PRNGKey(46)\n", "state, timestep = wrapped_env.reset(key, env_params)\n", "\n", "print(\"Wrapped environment:\")\n", "print(f\" Action space: {wrapped_env.action_space(env_params)}\")\n", "print(f\" Observation shape: {timestep.observation.shape}\")\n", "\n", "# Take a point action\n", "action = {\"operation\": 1, \"row\": 1, \"col\": 1}\n", "state, timestep = wrapped_env.step(state, action, env_params)\n", "\n", "print(\"\\nAction executed successfully with enhanced observations\")" ] }, { "cell_type": "markdown", "id": "0f751160", "metadata": {}, "source": [ "## Summary\n", "\n", "\n", "| Wrapper Type | Purpose | Example Use Case |\n", "|-------------|---------|------------------|\n", "| **Action Wrappers** | | |\n", "| `PointActionWrapper` | Dict actions with single points | Agents that select one cell at a time |\n", "| `BboxActionWrapper` | Dict actions with bounding boxes | Agents that work with regions |\n", "| `FlattenActionWrapper` | Single discrete action space | Standard RL 
algorithms (DQN, PPO) |\n", "| **Observation Wrappers** | | |\n", "| `InputGridObservationWrapper` | Add input grid channel | Always visible reference |\n", "| `AnswerObservationWrapper` | Add answer grid channel | Training with supervision |\n", "| `ClipboardObservationWrapper` | Add clipboard channel | Copy-paste operations |\n", "| `ContextualObservationWrapper` | Add demonstration pairs | Few-shot learning, pattern recognition |\n", "| **Visualization Wrappers** | | |\n", "| `StepVisualizationWrapper` | Enable detailed SVG rendering | Debugging agent actions and transitions |\n", "\n", "Wrappers enhance environment usability without altering core logic. They enable flexible action formats, richer observations, and better visualization, facilitating effective agent training and evaluation." ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }