Using Wrappers

Wrappers transform environments without modifying their core logic. JaxARC provides wrappers for:

  • Action transformation - Convert between different action formats

  • Observation augmentation - Add channels to observations

  • Action space flattening - Simplify complex action spaces

Wrappers follow a delegation pattern:

  1. The core environment handles only Action objects (mask-based selections).

  2. Wrappers convert user-friendly formats to and from masks, then delegate to the environment they wrap.

  3. Wrappers are composable, so several can be stacked on one environment.
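The three points above can be sketched in plain Python. This illustrates the pattern only and is not JaxARC's code: `ToyEnv`, `Wrapper`, `DictActionWrapper`, and `DoubleAmountWrapper` are hypothetical, and the toy environment stands in for a core environment that accepts a single action format. As in the JaxARC examples, `step` takes the state explicitly instead of mutating it.

```python
from typing import Any


class ToyEnv:
    """Hypothetical core env: state is an int, actions must be plain ints."""

    def reset(self) -> int:
        return 0

    def step(self, state: int, action: int) -> int:
        return state + action


class Wrapper:
    """Base wrapper: forwards any attribute it does not define to the inner env."""

    def __init__(self, env: Any) -> None:
        self._env = env

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails, so methods a subclass defines win.
        return getattr(self._env, name)


class DictActionWrapper(Wrapper):
    """Hypothetical action wrapper: converts {"amount": n} into the core int format."""

    def step(self, state: int, action: dict) -> int:
        return self._env.step(state, int(action["amount"]))


class DoubleAmountWrapper(Wrapper):
    """Hypothetical second wrapper: rescales the dict action, then delegates."""

    def step(self, state: int, action: dict) -> int:
        return self._env.step(state, {"amount": 2 * action["amount"]})


# Stack: DoubleAmountWrapper -> DictActionWrapper -> ToyEnv
env = DoubleAmountWrapper(DictActionWrapper(ToyEnv()))
state = env.reset()                     # not overridden: delegated down to ToyEnv
state = env.step(state, {"amount": 3})  # 3 is doubled to 6, converted, then added
print(state)  # 6
```

Because `__getattr__` is consulted only when normal lookup fails, each wrapper overrides just the methods it changes and everything else falls through to the wrapped environment.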

Setup: Base Environment

Let’s start with a base environment that uses mask-based actions.

from __future__ import annotations

import jax.random as jr

from jaxarc.configs import JaxArcConfig
from jaxarc.registration import make
from jaxarc.utils.core import get_config

# Set up the environment with raw (mask-based) actions and minimal logging
config_overrides = [
    "dataset=mini_arc",
    "action=raw",
    "wandb.enabled=false",
    "logging.log_operations=false",
    "logging.log_rewards=false",
    "visualization.enabled=false",
]

hydra_config = get_config(overrides=config_overrides)
config = JaxArcConfig.from_hydra(hydra_config)

# Create base environment
env, env_params = make("Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", config=config)

# Check the action space
action_space = env.action_space(env_params)
print(f"Base action space: {action_space}")
print(f"Action keys: {list(action_space.spaces.keys())}")
Base action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')
Action keys: ['operation', 'selection']
2025-11-18 22:47:09.240 | DEBUG    | jaxarc.utils.dataset_manager:validate_dataset:212 - Dataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-18 22:47:09.240 | DEBUG    | jaxarc.utils.dataset_manager:ensure_dataset_available:81 - Dataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-18 22:47:09.243 | INFO     | jaxarc.parsers.mini_arc:_validate_grid_constraints:104 - MiniARC parser configured with optimal 5x5 grid constraints
2025-11-18 22:47:09.245 | INFO     | jaxarc.parsers.mini_arc:_scan_available_tasks:131 - Found 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)
2025-11-18 22:47:09.246 | DEBUG    | jaxarc.parsers.mini_arc:_load_task_from_disk:171 - Loaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk
2025-11-18 22:47:09.658 | DEBUG    | jaxarc.parsers.base_parser:_log_parsing_stats:479 - Task Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5
2025-11-18 22:47:09.658 | DEBUG    | jaxarc.utils.task_manager:get_global_task_manager:236 - Created global task ID manager
2025-11-18 22:47:09.659 | DEBUG    | jaxarc.utils.task_manager:register_task:72 - Registered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0
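The raw selection field is a 5x5 grid of binary choices (the MultiDiscreteSpace above, with 2 values per cell), so the number of distinct raw actions is the number of operations times the number of possible masks. A quick calculation shows why simpler action formats help agents that choose from a finite list:

```python
num_operations = 35                  # DiscreteSpace(num_values=35) from the output above
grid_h, grid_w = 5, 5                # MiniARC grids are 5x5
num_masks = 2 ** (grid_h * grid_w)   # each cell is independently selected or not
raw_actions = num_operations * num_masks
point_actions = num_operations * grid_h * grid_w  # one (row, col) cell per action

print(f"raw mask actions: {raw_actions:,}")    # 1,174,405,120
print(f"point actions:    {point_actions:,}")  # 875
```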

Action Wrappers

Action wrappers convert user-friendly action formats into the mask-based Action objects that the core environment expects.

PointActionWrapper

Converts point-based actions {"operation": op, "row": r, "col": c} to mask selections.
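Conceptually, a point action selects exactly one cell. A minimal NumPy sketch of that conversion, assuming the mask is a boolean grid with the working grid's shape (`point_to_mask` is a hypothetical helper, not part of JaxARC's API, and the real wrapper may differ in detail):

```python
import numpy as np


def point_to_mask(row: int, col: int, shape: tuple = (5, 5)) -> np.ndarray:
    """Hypothetical helper: boolean mask with only (row, col) selected."""
    mask = np.zeros(shape, dtype=bool)
    mask[row, col] = True
    return mask


action = {"operation": 2, "row": 2, "col": 3}
mask = point_to_mask(action["row"], action["col"])
print(mask.astype(int))
```

In JAX code the in-place assignment would become `jnp.zeros(shape, dtype=bool).at[row, col].set(True)`, since JAX arrays are immutable.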

from jaxarc.wrappers import PointActionWrapper

# Wrap environment
point_env = PointActionWrapper(env)

# Check new action space
point_action_space = point_env.action_space(env_params)
print(f"Point action space: {point_action_space}")
print(f"Action keys: {list(point_action_space.spaces.keys())}")

# Reset and take a point action
key = jr.PRNGKey(42)
state, timestep = point_env.reset(key, env_params)

print(f"\nInitial observation shape: {timestep.observation.shape}")

# Take a point action
action = {"operation": 2, "row": 2, "col": 3}
state, timestep = point_env.step(state, action, env_params)

print(f"Point action executed: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
Point action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')
Action keys: ['operation', 'row', 'col']

Initial observation shape: (5, 5, 1)
Point action executed: {'operation': 2, 'row': 2, 'col': 3}
Reward: -0.005

BboxActionWrapper

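When the agent should act on a rectangular region rather than a single cell (for example, to select, copy, or cut a block), use BboxActionWrapper. It takes two corners, (r1, c1) and (r2, c2), and converts the rectangle between them into a selection mask: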

from jaxarc.wrappers import BboxActionWrapper

# Wrap environment
bbox_env = BboxActionWrapper(env)

# Check action space
bbox_action_space = bbox_env.action_space(env_params)
print(f"Bbox action space: {bbox_action_space}")
print(f"Action keys: {list(bbox_action_space.spaces.keys())}")

# Reset and take a bbox action
key = jr.PRNGKey(43)
state, timestep = bbox_env.reset(key, env_params)

# Select a 2x3 region: rows 1-2 and columns 1-3 (corners inclusive)
action = {"operation": 0, "r1": 1, "c1": 1, "r2": 2, "c2": 3}
state, timestep = bbox_env.step(state, action, env_params)

print(f"\nBbox action executed: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
Bbox action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), r1=DiscreteSpace(num_values=5, dtype=int32, name=''), c1=DiscreteSpace(num_values=5, dtype=int32, name=''), r2=DiscreteSpace(num_values=5, dtype=int32, name=''), c2=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='bbox_action')
Action keys: ['operation', 'r1', 'c1', 'r2', 'c2']

Bbox action executed: {'operation': 0, 'r1': 1, 'c1': 1, 'r2': 2, 'c2': 3}
Reward: -0.005
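The bounding-box conversion works the same way, filling a rectangle instead of a single cell. The sketch below is hypothetical (`bbox_to_mask` is not a JaxARC function). It assumes both corners are inclusive, which matches the 2x3 region selected above, and it orders the corners with min/max so reversed corners describe the same rectangle; whether the real wrapper normalizes corners this way is an assumption.

```python
import numpy as np


def bbox_to_mask(r1: int, c1: int, r2: int, c2: int, shape: tuple = (5, 5)) -> np.ndarray:
    """Hypothetical helper: boolean mask covering an inclusive rectangle."""
    top, bottom = min(r1, r2), max(r1, r2)
    left, right = min(c1, c2), max(c1, c2)
    rows = np.arange(shape[0])[:, None]  # column vector of row indices
    cols = np.arange(shape[1])[None, :]  # row vector of column indices
    # Broadcasting the range checks yields an (H, W) boolean grid.
    return (rows >= top) & (rows <= bottom) & (cols >= left) & (cols <= right)


mask = bbox_to_mask(1, 1, 2, 3)
print(mask.astype(int))
```

Using comparisons rather than `mask[top:bottom + 1, left:right + 1] = True` keeps the output shape fixed regardless of the corner values, which is what `jax.jit` needs when the corners are traced values.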

FlattenActionWrapper

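For RL algorithms that expect a single discrete action space, FlattenActionWrapper flattens the composite action space into one DiscreteSpace whose size is the product of the component sizes (35 × 5 × 5 = 875 for the point action space):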

from jaxarc.wrappers import FlattenActionWrapper

# Wrap the point-based environment rather than the raw one: its 35 × 5 × 5
# combinations are far fewer than those of the raw mask space
flat_env = FlattenActionWrapper(point_env)

# Check action space
flat_action_space = flat_env.action_space(env_params)
print(f"Flattened action space: {flat_action_space}")

# Reset and take a flattened action
key = jr.PRNGKey(44)
state, timestep = flat_env.reset(key, env_params)

# Sample a random action (reusing `key` for brevity; in practice, split it with jr.split first)
action = flat_action_space.sample(key)
state, timestep = flat_env.step(state, action, env_params)

print(f"\nFlattened action: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
Flattened action space: DiscreteSpace(num_values=875, dtype=int32, name='')

Flattened action: 752
Reward: -0.005
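Flattening a composite discrete space amounts to a mixed-radix encoding: each tuple of sub-actions maps to one integer and back. The sketch below uses NumPy's `ravel_multi_index` and `unravel_index` with dimensions (operation, row, col) = (35, 5, 5). The component order and row-major layout are assumptions, so the integer a real FlattenActionWrapper assigns to a given dict action may differ.

```python
import numpy as np

dims = (35, 5, 5)  # (operation, row, col) sizes, as in the point action space above


def flatten_action(action: dict) -> int:
    """Hypothetical encoder: dict action -> single discrete index (row-major)."""
    return int(np.ravel_multi_index((action["operation"], action["row"], action["col"]), dims))


def unflatten_action(index: int) -> dict:
    """Hypothetical decoder: single discrete index -> dict action."""
    op, row, col = np.unravel_index(index, dims)
    return {"operation": int(op), "row": int(row), "col": int(col)}


num_actions = int(np.prod(dims))
index = flatten_action({"operation": 2, "row": 2, "col": 3})
print(num_actions)              # 875
print(index)                    # 63  (2 * 25 + 2 * 5 + 3)
print(unflatten_action(index))  # {'operation': 2, 'row': 2, 'col': 3}
```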

Observation Wrappers

Observation wrappers add channels to the observation tensor, providing the agent with additional context.
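Each observation wrapper appends channels along the last axis, which is why the shapes below grow from (5, 5, 1) one channel at a time. A NumPy sketch of the stacking, with hypothetical placeholder grids standing in for the input grid, answer grid, and clipboard (this is not the wrappers' actual code):

```python
import numpy as np

h, w = 5, 5
working = np.zeros((h, w, 1), dtype=np.int32)  # base observation: (H, W, 1)

# Hypothetical extra grids, each (H, W); real wrappers read these from the env state.
input_grid = np.ones((h, w), dtype=np.int32)
answer_grid = np.full((h, w), 2, dtype=np.int32)
clipboard = np.full((h, w), 3, dtype=np.int32)

obs = working
for extra in (input_grid, answer_grid, clipboard):
    # Add a trailing channel axis, then concatenate along it.
    obs = np.concatenate([obs, extra[..., None]], axis=-1)
    print(obs.shape)
```

In this sketch the channel index of each grid follows the order in which it was appended, so the order of wrapping determines the layout.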

Basic Observation Wrappers

Each of these wrappers appends one channel to the observation:

from jaxarc.wrappers import (
    AnswerObservationWrapper,
    ClipboardObservationWrapper,
    InputGridObservationWrapper,
)

# Start fresh
key = jr.PRNGKey(45)
state, timestep = env.reset(key, env_params)
print(f"Base observation shape: {timestep.observation.shape}")

# Add input grid channel
env_with_input = InputGridObservationWrapper(env)
state, timestep = env_with_input.reset(key, env_params)
print(f"+ InputGridObservationWrapper: {timestep.observation.shape}")

# Add answer grid channel
env_with_answer = AnswerObservationWrapper(env_with_input)
state, timestep = env_with_answer.reset(key, env_params)
print(f"+ AnswerObservationWrapper: {timestep.observation.shape}")

# Add clipboard channel
env_with_clipboard = ClipboardObservationWrapper(env_with_answer)
state, timestep = env_with_clipboard.reset(key, env_params)
print(f"+ ClipboardObservationWrapper: {timestep.observation.shape}")

print(f"\nTotal channels so far: {timestep.observation.shape[-1]}")
Base observation shape: (5, 5, 1)
+ InputGridObservationWrapper: (5, 5, 2)
+ AnswerObservationWrapper: (5, 5, 3)
+ ClipboardObservationWrapper: (5, 5, 4)

Total channels so far: 4

ContextualObservationWrapper

The ContextualObservationWrapper adds demonstration pairs from the task to the observation. This gives the agent access to other input/output examples that illustrate the task’s transformation pattern.

Key features:

  • Adds 2 * num_context_pairs channels (input + output for each pair)

  • During training: excludes the current pair being solved

  • During testing: includes all demonstration pairs (since we’re solving a test pair)

  • Pads with zeros if fewer demonstration pairs are available than requested
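The selection and padding rules listed above can be sketched as follows. `build_context_channels` is a hypothetical helper written for illustration; in particular, taking the first remaining pairs in order when more are available than requested is an assumption, not documented JaxARC behavior.

```python
from typing import Optional

import numpy as np


def build_context_channels(inputs: np.ndarray, outputs: np.ndarray,
                           num_context_pairs: int,
                           current_pair: Optional[int]) -> np.ndarray:
    """Hypothetical helper: stack demo pairs as 2 * num_context_pairs channels.

    inputs, outputs: (num_pairs, H, W) demonstration grids.
    current_pair: index being solved in training mode, or None in test mode.
    """
    num_pairs, h, w = inputs.shape
    # Training: drop the pair being solved; testing: keep every demonstration pair.
    keep = [i for i in range(num_pairs) if i != current_pair][:num_context_pairs]
    channels = np.zeros((h, w, 2 * num_context_pairs), dtype=inputs.dtype)
    for slot, i in enumerate(keep):
        channels[..., 2 * slot] = inputs[i]       # input grid of the pair
        channels[..., 2 * slot + 1] = outputs[i]  # output grid of the pair
    return channels  # unused slots stay zero (padding)


# Three demonstration pairs of 5x5 grids, like the task loaded above.
inputs = np.stack([np.full((5, 5), k + 1) for k in range(3)])
outputs = np.stack([np.full((5, 5), k + 11) for k in range(3)])

train_ctx = build_context_channels(inputs, outputs, 3, current_pair=0)
test_ctx = build_context_channels(inputs, outputs, 3, current_pair=None)
print(train_ctx.shape, test_ctx.shape)  # (5, 5, 6) (5, 5, 6)
```

With three demonstration pairs and three requested, training mode fills two slots and leaves the last two channels as zero padding, while test mode fills all six.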

from jaxarc.wrappers import ContextualObservationWrapper

# Add 3 demonstration pairs as context
env_with_context = ContextualObservationWrapper(env_with_clipboard, num_context_pairs=3)

key = jr.PRNGKey(45)
state, timestep = env_with_context.reset(key, env_params)

print("With ContextualObservationWrapper (3 pairs):")
print(f"  Observation shape: {timestep.observation.shape}")
print(f"  Added channels: {3 * 2} (3 pairs × 2 channels per pair)")

print(f"\nTotal channels: {timestep.observation.shape[-1]}")
With ContextualObservationWrapper (3 pairs):
  Observation shape: (5, 5, 10)
  Added channels: 6 (3 pairs × 2 channels per pair)

Total channels: 10

Combining Action and Observation Wrappers

You can chain both types of wrappers together:

# Create a fully wrapped environment
wrapped_env = PointActionWrapper(env)
wrapped_env = InputGridObservationWrapper(wrapped_env)
wrapped_env = AnswerObservationWrapper(wrapped_env)

# Reset and inspect
key = jr.PRNGKey(46)
state, timestep = wrapped_env.reset(key, env_params)

print("Wrapped environment:")
print(f"  Action space: {wrapped_env.action_space(env_params)}")
print(f"  Observation shape: {timestep.observation.shape}")

# Take a point action
action = {"operation": 1, "row": 1, "col": 1}
state, timestep = wrapped_env.step(state, action, env_params)

print("\nAction executed successfully with enhanced observations")
Wrapped environment:
  Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')
  Observation shape: (5, 5, 3)

Action executed successfully with enhanced observations
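When a stack grows longer, applying wrappers from a list keeps the order explicit. Below is a small, hypothetical helper built on `functools.reduce` in which the first wrapper in the list is innermost. It is shown with a toy `Tag` class so it runs on its own; with the classes above it would be called as `apply_wrappers(env, [PointActionWrapper, InputGridObservationWrapper, AnswerObservationWrapper])`.

```python
from functools import partial, reduce
from typing import Any, Callable, Sequence


def apply_wrappers(env: Any, wrappers: Sequence[Callable[[Any], Any]]) -> Any:
    """Hypothetical helper: wrap env with each callable in order (first = innermost)."""
    return reduce(lambda inner, wrap: wrap(inner), wrappers, env)


class Tag:
    """Toy wrapper that records its name, standing in for a real wrapper class."""

    def __init__(self, env: Any, name: str) -> None:
        self.env, self.name = env, name

    def describe(self) -> str:
        inner = self.env.describe() if isinstance(self.env, Tag) else str(self.env)
        return f"{self.name}({inner})"


wrapped = apply_wrappers("base_env", [partial(Tag, name="Point"),
                                      partial(Tag, name="InputGrid"),
                                      partial(Tag, name="Answer")])
print(wrapped.describe())  # Answer(InputGrid(Point(base_env)))
```

`functools.partial` also covers wrappers that take extra arguments, such as `partial(ContextualObservationWrapper, num_context_pairs=3)`.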

Summary

Wrapper Type                    Purpose                            Example Use Case

Action Wrappers
  PointActionWrapper            Dict actions with single points    Agents that select one cell at a time
  BboxActionWrapper             Dict actions with bounding boxes   Agents that work with regions
  FlattenActionWrapper          Single discrete action space       Standard RL algorithms (DQN, PPO)

Observation Wrappers
  InputGridObservationWrapper   Add input grid channel             Always-visible reference
  AnswerObservationWrapper      Add answer grid channel            Training with supervision
  ClipboardObservationWrapper   Add clipboard channel              Copy-paste operations
  ContextualObservationWrapper  Add demonstration pairs            Few-shot learning, pattern recognition

Visualization Wrappers
  StepVisualizationWrapper      Enable detailed SVG rendering      Debugging agent actions and transitions

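Wrappers change how an agent interacts with the environment without altering its core logic: action wrappers offer simpler action formats, observation wrappers add context channels, and visualization wrappers support debugging. Because each wrapper delegates to the one it wraps, they can be stacked to fit a particular agent's training or evaluation setup.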