Using Wrappers¶
Wrappers transform environments without modifying their core logic. JaxARC provides wrappers for:
Action transformation - Convert between different action formats
Observation augmentation - Add channels to observations
Action space flattening - Simplify complex action spaces
Wrappers follow the delegation pattern:
Core environment handles only Action objects (mask-based selections)
Wrappers convert user-friendly formats to/from masks
Composable - stack multiple wrappers easily
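The delegation pattern above can be sketched with a toy environment. All names here (`ToyEnv`, `Wrapper`, `StringActionWrapper`, `ClipRewardWrapper`) are hypothetical and are not JaxARC classes; the sketch only shows how a wrapper forwards calls to the environment it wraps and how wrappers stack.

```python
class ToyEnv:
    """Hypothetical core environment: actions are ints, reward is the negated action."""

    def reset(self):
        return 0  # initial state

    def step(self, state, action):
        new_state = state + action
        reward = -float(action)
        return new_state, reward


class Wrapper:
    """Hypothetical base wrapper: holds an inner env and delegates every call to it."""

    def __init__(self, env):
        self._env = env

    def reset(self):
        return self._env.reset()

    def step(self, state, action):
        return self._env.step(state, action)


class StringActionWrapper(Wrapper):
    """Translates a user-friendly action (a string) into the core format (an int)."""

    _NAMES = {"noop": 0, "small": 1, "large": 5}

    def step(self, state, action):
        return self._env.step(state, self._NAMES[action])


class ClipRewardWrapper(Wrapper):
    """Post-processes the inner step's result without touching the core logic."""

    def step(self, state, action):
        new_state, reward = self._env.step(state, action)
        return new_state, max(reward, -1.0)


# Wrappers compose: the outermost wrapper sees the call first.
env = ClipRewardWrapper(StringActionWrapper(ToyEnv()))
state = env.reset()
state, reward = env.step(state, "large")
```

The core environment never learns about strings or clipping; each wrapper changes one thing and hands the rest to the layer below it.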
Setup: Base Environment¶
Let’s start with a base environment that uses mask-based actions.
from __future__ import annotations
import jax.random as jr
from jaxarc.configs import JaxArcConfig
from jaxarc.registration import make
from jaxarc.utils.core import get_config
# Setup environment with minimal logging
config_overrides = [
    "dataset=mini_arc",
    "action=raw",
    "wandb.enabled=false",
    "logging.log_operations=false",
    "logging.log_rewards=false",
    "visualization.enabled=false",
]
hydra_config = get_config(overrides=config_overrides)
config = JaxArcConfig.from_hydra(hydra_config)
# Create base environment
env, env_params = make("Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", config=config)
# Check the action space
action_space = env.action_space(env_params)
print(f"Base action space: {action_space}")
print(f"Action keys: {list(action_space.spaces.keys())}")
Base action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')
Action keys: ['operation', 'selection']
2025-11-18 22:47:09.240 | DEBUG | jaxarc.utils.dataset_manager:validate_dataset:212 - Dataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-18 22:47:09.240 | DEBUG | jaxarc.utils.dataset_manager:ensure_dataset_available:81 - Dataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-18 22:47:09.243 | INFO | jaxarc.parsers.mini_arc:_validate_grid_constraints:104 - MiniARC parser configured with optimal 5x5 grid constraints
2025-11-18 22:47:09.245 | INFO | jaxarc.parsers.mini_arc:_scan_available_tasks:131 - Found 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)
2025-11-18 22:47:09.246 | DEBUG | jaxarc.parsers.mini_arc:_load_task_from_disk:171 - Loaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk
2025-11-18 22:47:09.658 | DEBUG | jaxarc.parsers.base_parser:_log_parsing_stats:479 - Task Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5
2025-11-18 22:47:09.658 | DEBUG | jaxarc.utils.task_manager:get_global_task_manager:236 - Created global task ID manager
2025-11-18 22:47:09.659 | DEBUG | jaxarc.utils.task_manager:register_task:72 - Registered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0
Action Wrappers¶
Action wrappers convert user-friendly action formats into the mask-based Action objects that the core environment expects.
PointActionWrapper¶
Converts point-based actions {"operation": op, "row": r, "col": c} to mask selections.
from jaxarc.wrappers import PointActionWrapper
# Wrap environment
point_env = PointActionWrapper(env)
# Check new action space
point_action_space = point_env.action_space(env_params)
print(f"Point action space: {point_action_space}")
print(f"Action keys: {list(point_action_space.spaces.keys())}")
# Reset and take a point action
key = jr.PRNGKey(42)
state, timestep = point_env.reset(key, env_params)
print(f"\nInitial observation shape: {timestep.observation.shape}")
# Take a point action
action = {"operation": 2, "row": 2, "col": 3}
state, timestep = point_env.step(state, action, env_params)
print(f"Point action executed: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
Point action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')
Action keys: ['operation', 'row', 'col']
Initial observation shape: (5, 5, 1)
Point action executed: {'operation': 2, 'row': 2, 'col': 3}
Reward: -0.005
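To make the point-to-mask conversion concrete, here is a minimal sketch of what such a translation could look like. The helper `point_to_mask` is hypothetical and is not the wrapper's actual implementation; plain Python lists are used instead of JAX arrays so the example has no dependencies.

```python
def point_to_mask(row, col, height, width):
    """Return a height x width boolean mask with only the cell (row, col) selected."""
    if not (0 <= row < height and 0 <= col < width):
        raise ValueError(f"point ({row}, {col}) is outside a {height}x{width} grid")
    return [[r == row and c == col for c in range(width)] for r in range(height)]


# Same point as in the action above: row 2, column 3 on a 5x5 grid.
mask = point_to_mask(row=2, col=3, height=5, width=5)
num_selected = sum(cell for mask_row in mask for cell in mask_row)
```

Exactly one cell is set in the resulting mask, which is the kind of selection the core environment consumes.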
BboxActionWrapper¶
For operations that require a rectangular region (selection, copy, cut), use BboxActionWrapper:
from jaxarc.wrappers import BboxActionWrapper
# Wrap environment
bbox_env = BboxActionWrapper(env)
# Check action space
bbox_action_space = bbox_env.action_space(env_params)
print(f"Bbox action space: {bbox_action_space}")
print(f"Action keys: {list(bbox_action_space.spaces.keys())}")
# Reset and take a bbox action
key = jr.PRNGKey(43)
state, timestep = bbox_env.reset(key, env_params)
# Select a 2x3 region
action = {"operation": 0, "r1": 1, "c1": 1, "r2": 2, "c2": 3}
state, timestep = bbox_env.step(state, action, env_params)
print(f"\nBbox action executed: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
Bbox action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), r1=DiscreteSpace(num_values=5, dtype=int32, name=''), c1=DiscreteSpace(num_values=5, dtype=int32, name=''), r2=DiscreteSpace(num_values=5, dtype=int32, name=''), c2=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='bbox_action')
Action keys: ['operation', 'r1', 'c1', 'r2', 'c2']
Bbox action executed: {'operation': 0, 'r1': 1, 'c1': 1, 'r2': 2, 'c2': 3}
Reward: -0.005
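The bounding-box case can be sketched the same way. The helper `bbox_to_mask` is hypothetical, and it assumes that both corners are inclusive and may be given in any order; the source does not state how the wrapper treats these cases.

```python
def bbox_to_mask(r1, c1, r2, c2, height, width):
    """Boolean mask covering the rectangle with corners (r1, c1) and (r2, c2).

    Assumption: both corners are inclusive and may be given in any order.
    """
    top, bottom = min(r1, r2), max(r1, r2)
    left, right = min(c1, c2), max(c1, c2)
    return [
        [top <= r <= bottom and left <= c <= right for c in range(width)]
        for r in range(height)
    ]


# Same box as in the action above: rows 1 to 2, columns 1 to 3 on a 5x5 grid.
mask = bbox_to_mask(r1=1, c1=1, r2=2, c2=3, height=5, width=5)
num_selected = sum(cell for mask_row in mask for cell in mask_row)
```

Under the inclusive assumption the box spans 2 rows and 3 columns, which matches the "2x3 region" comment in the example.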
FlattenActionWrapper¶
For RL algorithms that work with single discrete action spaces, FlattenActionWrapper flattens the composite action space:
from jaxarc.wrappers import FlattenActionWrapper
# Wrap environment
# Using PointActionWrapper here to reduce the action space size for demonstration
flat_env = FlattenActionWrapper(point_env)
# Check action space
flat_action_space = flat_env.action_space(env_params)
print(f"Flattened action space: {flat_action_space}")
# Reset and take a flattened action
key = jr.PRNGKey(44)
state, timestep = flat_env.reset(key, env_params)
# Sample a random action
action = flat_action_space.sample(key)
state, timestep = flat_env.step(state, action, env_params)
print(f"\nFlattened action: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
Flattened action space: DiscreteSpace(num_values=875, dtype=int32, name='')
Flattened action: 752
Reward: -0.005
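The size of the flattened space follows from the point action space: 35 operations × 5 rows × 5 columns = 875. The sketch below shows one way to encode and decode such an index. The functions `flatten_action` and `unflatten_action` are hypothetical, and the operation-major ordering is an assumption; the ordering FlattenActionWrapper actually uses may differ.

```python
# Sizes taken from the point action space printed earlier.
NUM_OPERATIONS, NUM_ROWS, NUM_COLS = 35, 5, 5


def flatten_action(operation, row, col):
    """Encode (operation, row, col) as one integer (assumed operation-major order)."""
    return (operation * NUM_ROWS + row) * NUM_COLS + col


def unflatten_action(index):
    """Inverse of flatten_action: recover the dict action from the integer."""
    rest, col = divmod(index, NUM_COLS)
    operation, row = divmod(rest, NUM_ROWS)
    return {"operation": operation, "row": row, "col": col}


num_actions = NUM_OPERATIONS * NUM_ROWS * NUM_COLS

# Every index decodes to an action that encodes back to the same index.
round_trip = all(
    flatten_action(**unflatten_action(i)) == i for i in range(num_actions)
)
```

A bijection like this is what lets an algorithm that emits a single integer still address every (operation, row, col) combination.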
Observation Wrappers¶
Observation wrappers add channels to the observation tensor, providing the agent with additional context.
Basic Observation Wrappers¶
These wrappers add single-channel context:
from jaxarc.wrappers import (
    AnswerObservationWrapper,
    ClipboardObservationWrapper,
    InputGridObservationWrapper,
)
# Start fresh
key = jr.PRNGKey(45)
state, timestep = env.reset(key, env_params)
print(f"Base observation shape: {timestep.observation.shape}")
# Add input grid channel
env_with_input = InputGridObservationWrapper(env)
state, timestep = env_with_input.reset(key, env_params)
print(f"+ InputGridObservationWrapper: {timestep.observation.shape}")
# Add answer grid channel
env_with_answer = AnswerObservationWrapper(env_with_input)
state, timestep = env_with_answer.reset(key, env_params)
print(f"+ AnswerObservationWrapper: {timestep.observation.shape}")
# Add clipboard channel
env_with_clipboard = ClipboardObservationWrapper(env_with_answer)
state, timestep = env_with_clipboard.reset(key, env_params)
print(f"+ ClipboardObservationWrapper: {timestep.observation.shape}")
print(f"\nTotal channels so far: {timestep.observation.shape[-1]}")
Base observation shape: (5, 5, 1)
+ InputGridObservationWrapper: (5, 5, 2)
+ AnswerObservationWrapper: (5, 5, 3)
+ ClipboardObservationWrapper: (5, 5, 4)
Total channels so far: 4
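Each wrapper above grows the last axis of the observation by one. A minimal sketch of that channel stacking follows, using nested lists in (height, width, channels) layout to stay dependency-free. The helpers `add_channel` and `shape` are hypothetical, and the grid contents are made-up placeholder values; with arrays the same step would be a concatenation along the last axis.

```python
def add_channel(observation, grid):
    """Append `grid` (H x W) as a new last channel of `observation` (H x W x C)."""
    if len(grid) != len(observation) or any(
        len(g_row) != len(o_row) for g_row, o_row in zip(grid, observation)
    ):
        raise ValueError("grid and observation must share the same height and width")
    # `cell + [value]` builds a new list, so the input observation is not mutated.
    return [
        [cell + [value] for cell, value in zip(obs_row, grid_row)]
        for obs_row, grid_row in zip(observation, grid)
    ]


def shape(observation):
    """(height, width, channels) of a nested-list observation."""
    return (len(observation), len(observation[0]), len(observation[0][0]))


H, W = 5, 5
working = [[[0] for _ in range(W)] for _ in range(H)]  # base observation: 1 channel
input_grid = [[1] * W for _ in range(H)]   # placeholder values
answer_grid = [[2] * W for _ in range(H)]  # placeholder values
clipboard = [[0] * W for _ in range(H)]    # placeholder values

obs = working
for extra in (input_grid, answer_grid, clipboard):
    obs = add_channel(obs, extra)
final_shape = shape(obs)
```

Starting from one channel and adding three gives four, matching the shapes printed above.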
ContextualObservationWrapper¶
The ContextualObservationWrapper adds demonstration pairs from the task to the observation. This gives the agent access to other input/output examples that illustrate the task’s transformation pattern.
Key features:
Adds 2 * num_context_pairs channels (input + output for each pair)
During training: excludes the current pair being solved
During testing: includes all demonstration pairs (since we’re solving a test pair)
Pads with zeros if fewer demonstration pairs are available than requested
from jaxarc.wrappers import ContextualObservationWrapper
# Add 3 demonstration pairs as context
env_with_context = ContextualObservationWrapper(env_with_clipboard, num_context_pairs=3)
key = jr.PRNGKey(45)
state, timestep = env_with_context.reset(key, env_params)
print("With ContextualObservationWrapper (3 pairs):")
print(f" Observation shape: {timestep.observation.shape}")
print(f" Added channels: {3 * 2} (3 pairs × 2 channels per pair)")
print(f"\nTotal channels: {timestep.observation.shape[-1]}")
With ContextualObservationWrapper (3 pairs):
Observation shape: (5, 5, 10)
Added channels: 6 (3 pairs × 2 channels per pair)
Total channels: 10
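The selection and padding rules listed under "Key features" can be sketched as follows. The function `context_channels` is hypothetical, and the channel order (input grid followed by output grid for each pair) is an assumption rather than the wrapper's documented layout.

```python
def context_channels(demo_pairs, num_context_pairs, current_index=None):
    """Build 2 * num_context_pairs grids from a non-empty list of demonstration pairs.

    demo_pairs: list of (input_grid, output_grid) tuples, all of the same size.
    current_index: index of the pair being solved during training (excluded),
        or None during testing (every pair is eligible).
    """
    height, width = len(demo_pairs[0][0]), len(demo_pairs[0][0][0])
    zero_grid = [[0] * width for _ in range(height)]
    eligible = [pair for i, pair in enumerate(demo_pairs) if i != current_index]
    channels = []
    for slot in range(num_context_pairs):
        if slot < len(eligible):
            channels.extend(eligible[slot])  # input grid, then output grid
        else:
            channels.extend([zero_grid, zero_grid])  # pad missing pairs with zeros
    return channels


def filled(value, size=5):
    """Placeholder grid of one repeated value, used only for this illustration."""
    return [[value] * size for _ in range(size)]


pairs = [(filled(1), filled(2)), (filled(3), filled(4)), (filled(5), filled(6))]
train_ctx = context_channels(pairs, num_context_pairs=3, current_index=0)
test_ctx = context_channels(pairs, num_context_pairs=3)
```

With three demonstration pairs, the training case has only two eligible pairs, so its last two channels are zero padding; the testing case uses all three. Both return six channels, which is why the observation grows from 4 to 10 channels in the output above.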
Combining Action and Observation Wrappers¶
You can chain both types of wrappers together:
# Create a fully wrapped environment
wrapped_env = PointActionWrapper(env)
wrapped_env = InputGridObservationWrapper(wrapped_env)
wrapped_env = AnswerObservationWrapper(wrapped_env)
# Reset and inspect
key = jr.PRNGKey(46)
state, timestep = wrapped_env.reset(key, env_params)
print("Wrapped environment:")
print(f" Action space: {wrapped_env.action_space(env_params)}")
print(f" Observation shape: {timestep.observation.shape}")
# Take a point action
action = {"operation": 1, "row": 1, "col": 1}
state, timestep = wrapped_env.step(state, action, env_params)
print("\nAction executed successfully with enhanced observations")
Wrapped environment:
Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')
Observation shape: (5, 5, 3)
Action executed successfully with enhanced observations
Summary¶
| Wrapper Type | Purpose | Example Use Case |
|---|---|---|
| Action Wrappers | | |
| PointActionWrapper | Dict actions with single points | Agents that select one cell at a time |
| BboxActionWrapper | Dict actions with bounding boxes | Agents that work with regions |
| FlattenActionWrapper | Single discrete action space | Standard RL algorithms (DQN, PPO) |
| Observation Wrappers | | |
| InputGridObservationWrapper | Add input grid channel | Always visible reference |
| AnswerObservationWrapper | Add answer grid channel | Training with supervision |
| ClipboardObservationWrapper | Add clipboard channel | Copy-paste operations |
| ContextualObservationWrapper | Add demonstration pairs | Few-shot learning, pattern recognition |
| Visualization Wrappers | | |
| | Enable detailed SVG rendering | Debugging agent actions and transitions |
Wrappers enhance environment usability without altering core logic. They enable flexible action formats, richer observations, and better visualization, facilitating effective agent training and evaluation.