Wrappers API
Wrappers transform an environment's interface for different use cases, changing how actions are specified, what observations contain, or how the environment is rendered.
JaxARC provides three types of wrappers:
- Action Wrappers: Convert between action formats (dict → mask, bbox → mask, flatten)
- Observation Wrappers: Add channels to observations (input grid, answer, clipboard, context)
- Visualization Wrappers: Enhance rendering capabilities (step visualization)
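All three kinds follow the same pattern: a wrapper holds an inner environment, exposes the same `reset`/`step(state, action, env_params)` interface, and transforms data on the way in or out. The following is a minimal plain-Python sketch of that delegation pattern. `CountingEnv`, `ActionMappingWrapper` and `convert_action` are hypothetical names, not part of jaxarc; the documented wrappers derive from `Wrapper` or `BaseObservationWrapper` instead of a plain class.

```python
from typing import Any, Optional, Tuple


class CountingEnv:
    """Hypothetical stand-in environment: the state is an integer step counter."""

    def reset(self, key: Any, env_params: Optional[Any] = None) -> Tuple[int, dict]:
        state = 0
        return state, {"observation": state}

    def step(self, state: int, action: Any, env_params: Optional[Any] = None) -> Tuple[int, dict]:
        new_state = state + 1
        # Echo the received action so the conversion is visible to the caller.
        return new_state, {"observation": new_state, "action": action}


class ActionMappingWrapper:
    """Hypothetical wrapper: converts the action, then delegates to the inner env."""

    def __init__(self, env):
        self._env = env

    def reset(self, key, env_params=None):
        # Reset is passed through unchanged.
        return self._env.reset(key, env_params)

    def step(self, state, action, env_params=None):
        # Translate the caller's action format into the inner env's format.
        return self._env.step(state, self.convert_action(action), env_params)

    def convert_action(self, action):
        # Example conversion: a (row, col) tuple becomes a dict.
        row, col = action
        return {"row": row, "col": col}

    def __getattr__(self, name):
        # Forward any other attribute (e.g. render) to the wrapped environment.
        return getattr(self._env, name)
```

Forwarding unknown attributes through `__getattr__` keeps methods of the base environment reachable after several wrappers are stacked.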
Action Wrappers
PointActionWrapper
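No API details for PointActionWrapper are rendered on this page. Based on the "dict → mask" conversion listed in the overview and the `{"operation", "row", "col"}` action used in the usage example, one plausible conversion looks like the sketch below. It uses NumPy; the function name `point_to_mask`, the output key `"selection"` and the clipping of out-of-range points are all assumptions, not the library's behavior.

```python
import numpy as np


def point_to_mask(action: dict, height: int, width: int) -> dict:
    """Hypothetical: turn {"operation", "row", "col"} into {"operation", "selection"}."""
    mask = np.zeros((height, width), dtype=bool)
    # Clip so an out-of-range point still selects a cell on the grid (an assumption).
    row = int(np.clip(action["row"], 0, height - 1))
    col = int(np.clip(action["col"], 0, width - 1))
    mask[row, col] = True
    return {"operation": action["operation"], "selection": mask}
```

In JAX code the in-place assignment would become `jnp.zeros(...).at[row, col].set(True)`, since JAX arrays are immutable.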
BboxActionWrapper
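The overview lists "bbox → mask" as one of the action conversions, but BboxActionWrapper's interface is not rendered here. The sketch below assumes a bounding box given by two inclusive corners `(r1, c1)` and `(r2, c2)` in either order; the key names, the function name `bbox_to_mask` and the `"selection"` output key are hypothetical.

```python
import numpy as np


def bbox_to_mask(action: dict, height: int, width: int) -> dict:
    """Hypothetical: turn two inclusive corners into a rectangular boolean mask."""
    # Accept corners in either order.
    r1, r2 = sorted((action["r1"], action["r2"]))
    c1, c2 = sorted((action["c1"], action["c2"]))
    # Broadcast row indices (H, 1) against column indices (1, W).
    rows = np.arange(height)[:, None]
    cols = np.arange(width)[None, :]
    mask = (rows >= r1) & (rows <= r2) & (cols >= c1) & (cols <= c2)
    return {"operation": action["operation"], "selection": mask}
```

Building the mask by comparison instead of slicing keeps every array shape fixed, which is what `jax.jit` requires when the corners are traced values.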
FlattenActionWrapper
- class jaxarc.wrappers.FlattenActionWrapper(env: Environment)
Bases: Wrapper[State]
A general-purpose wrapper that flattens any composite discrete action space.
This wrapper can handle any combination of DictSpace, MultiDiscreteSpace, and DiscreteSpace, converting them into a single, unified DiscreteSpace.
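Flattening a composite discrete space is essentially a mixed-radix encoding: with component sizes n_1, ..., n_k, the flat space has n_1 × ... × n_k actions, and every combination of sub-actions maps to exactly one index. The plain-Python sketch below shows that arithmetic. `flatten_index` and `unflatten_index` are hypothetical helpers, and the component order (for a DictSpace, some fixed key order) is an assumption; FlattenActionWrapper may order components differently.

```python
from math import prod
from typing import List


def flatten_index(parts: List[int], sizes: List[int]) -> int:
    """Encode one sub-action per component into a single integer (row-major)."""
    index = 0
    for value, size in zip(parts, sizes):
        if not 0 <= value < size:
            raise ValueError(f"sub-action {value} out of range for size {size}")
        index = index * size + value
    return index


def unflatten_index(index: int, sizes: List[int]) -> List[int]:
    """Decode a flat integer back into one sub-action per component."""
    if not 0 <= index < prod(sizes):
        raise ValueError(f"flat index {index} out of range")
    parts = []
    # Peel off the last (fastest-varying) component first.
    for size in reversed(sizes):
        index, value = divmod(index, size)
        parts.append(value)
    return parts[::-1]
```

For example, with sizes `[3, 30, 30]` (say, operation, row and column), the sub-action `[2, 5, 5]` maps to `(2 * 30 + 5) * 30 + 5 = 1955`, one of 2700 flat actions.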
Observation Wrappers
InputGridObservationWrapper
AnswerObservationWrapper
ClipboardObservationWrapper
ContextualObservationWrapper
- class jaxarc.wrappers.ContextualObservationWrapper(env, num_context_pairs: int = 5)
Bases: BaseObservationWrapper
Adds num_context_pairs context demonstration pairs to the observation.
This wrapper adds 2 * num_context_pairs channels to the observation, representing the input and output grids of other demonstration pairs from the same task. This provides the agent with contextual examples.
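The channel arithmetic above can be sketched with NumPy. This assumes a channel-last `(H, W, C)` observation, interleaved `(input_i, output_i)` channel order, and zero-padding when the task has fewer demonstration pairs than `num_context_pairs`; all three are assumptions, and `add_context_channels` is a hypothetical function. Choosing which "other" pairs to include is left to the caller.

```python
import numpy as np


def add_context_channels(obs, demo_inputs, demo_outputs, num_context_pairs=5):
    """Hypothetical: append 2 * num_context_pairs channels to an (H, W, C) observation.

    demo_inputs and demo_outputs have shape (P, H, W). Slots beyond P are
    filled with all-zero grids (an assumption about missing pairs).
    """
    if num_context_pairs == 0:
        return obs
    height, width = obs.shape[:2]
    channels = []
    for i in range(num_context_pairs):
        if i < demo_inputs.shape[0]:
            # Interleave each pair as (input_i, output_i).
            channels.extend([demo_inputs[i], demo_outputs[i]])
        else:
            blank = np.zeros((height, width), dtype=obs.dtype)
            channels.extend([blank, blank])
    context = np.stack(channels, axis=-1).astype(obs.dtype)
    return np.concatenate([obs, context], axis=-1)
```

With a one-channel observation and `num_context_pairs=3`, the result has 1 + 2 * 3 = 7 channels regardless of how many demonstrations exist, so the observation shape stays static.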
Visualization Wrappers
StepVisualizationWrapper
- class jaxarc.wrappers.StepVisualizationWrapper(env)
Bases: Wrapper
Wrapper that enables detailed step visualization by tracking transition history.
Enables env.render(mode="detailed"), which returns a rich SVG of the last transition.
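How StepVisualizationWrapper stores its history is not documented here. One JAX-friendly approach is to carry the previous state and last action inside the wrapper's own state, because JAX environments are functional and `step` should not mutate Python attributes. The sketch assumes that design; `TrackedState`, `TransitionTrackingWrapper` and `last_transition` are hypothetical names, and the actual wrapper may work differently.

```python
from typing import Any, NamedTuple, Optional


class TrackedState(NamedTuple):
    """Hypothetical wrapper state: the inner state plus the last transition."""
    inner: Any
    prev_inner: Optional[Any]
    last_action: Optional[Any]


class TransitionTrackingWrapper:
    """Hypothetical wrapper that records (previous state, action, new state)."""

    def __init__(self, env):
        self._env = env

    def reset(self, key, env_params=None):
        inner, timestep = self._env.reset(key, env_params)
        # No transition has happened yet.
        return TrackedState(inner, None, None), timestep

    def step(self, state, action, env_params=None):
        inner, timestep = self._env.step(state.inner, action, env_params)
        # Keep the before/after states and the action so a renderer can draw them.
        return TrackedState(inner, state.inner, action), timestep

    def last_transition(self, state):
        if state.prev_inner is None:
            return None
        return state.prev_inner, state.last_action, state.inner
```

Because `reset` stores `None` while `step` stores real values, the pytree structure changes after the first step. A version meant for `jax.jit` or `jax.lax.scan` would use fixed-shape placeholders (for example a zero-filled copy of the state) instead of `None`.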
- reset(key, env_params: EnvParams | None = None) → tuple[Any, Any]
Reset the environment.
- Parameters:
key – A JAX PRNG key for random number generation.
env_params – Optional environment parameters.
- Returns:
A tuple of the initial state and the first TimeStep.
- step(state: State, action: Any, env_params: EnvParams | None = None) → tuple[Any, Any]
Take a step in the environment.
- Parameters:
state – The current environment state.
action – The action to take.
env_params – Optional environment parameters.
- Returns:
A tuple of the new state and the resulting TimeStep.
Usage Example
import jax
from jaxarc import make
from jaxarc.wrappers import PointActionWrapper, InputGridObservationWrapper
# Create base environment
env, env_params = make("Mini")
# Add wrappers (each one wraps the environment produced by the previous line)
env = PointActionWrapper(env)
env = InputGridObservationWrapper(env)
# Create a PRNG key for reset
key = jax.random.PRNGKey(0)
# Use wrapped environment
state, timestep = env.reset(key, env_params)
# Point action: operation index 2 applied at grid cell (row 5, col 5)
action = {"operation": 2, "row": 5, "col": 5}
state, timestep = env.step(state, action, env_params)
See Also
Using Wrappers - Tutorial on using wrappers
Core API - Core environment API