Wrappers API

Wrappers transform environment interfaces for different use cases.

JaxARC provides three types of wrappers:

  • Action Wrappers: Convert between action formats (dict → mask, bbox → mask, flatten)

  • Observation Wrappers: Add channels to observations (input grid, answer, clipboard, context)

  • Visualization Wrappers: Enhance rendering capabilities (step visualization)
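The wrapper pattern behind all three categories can be sketched without JaxARC at all. The classes below (ToyEnv, ToyWrapper, DoubleActionWrapper) are hypothetical stand-ins invented for illustration, not JaxARC's Environment or Wrapper classes; the sketch only assumes the functional style used on this page, where state is passed in and returned explicitly.

```python
class ToyEnv:
    """Hypothetical stand-in for an environment (not a JaxARC class)."""

    def step(self, state, action):
        # The new state is the old state plus the action; reward echoes the action.
        return state + action, {"reward": float(action)}


class ToyWrapper:
    """Hypothetical base wrapper: holds an inner env and delegates to it."""

    def __init__(self, env):
        self._env = env

    def step(self, state, action):
        return self._env.step(state, action)


class DoubleActionWrapper(ToyWrapper):
    """Transforms the action before delegating, as an action wrapper would."""

    def step(self, state, action):
        return self._env.step(state, 2 * action)


base_env = ToyEnv()
wrapped_env = DoubleActionWrapper(base_env)

base_state, base_info = base_env.step(0, 3)
wrapped_state, wrapped_info = wrapped_env.step(0, 3)
```

Because each wrapper exposes the same step signature as the environment it wraps, wrappers can be stacked in any number.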

Action Wrappers

PointActionWrapper

class jaxarc.wrappers.PointActionWrapper(env: Environment)

Bases: Wrapper

Point action wrapper with custom action space.

action_space(env_params: EnvParams | None = None) → DictSpace

Custom action space for point actions: (operation, row, col).

step(state: State, action: dict, env_params: EnvParams | None = None) → tuple[State, TimeStep]

Convert point to mask and delegate.
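A minimal sketch of what a point-to-mask conversion could look like, assuming the mask is a boolean array with the same height and width as the grid. The helper name point_to_mask and the 10×10 grid size are assumptions for illustration; the wrapper's actual internals are not shown on this page.

```python
import jax.numpy as jnp


def point_to_mask(row, col, grid_shape):
    """Return a boolean (H, W) mask that is True only at (row, col)."""
    height, width = grid_shape
    rows = jnp.arange(height)[:, None]  # column vector of row indices
    cols = jnp.arange(width)[None, :]   # row vector of column indices
    # Broadcasting compares every cell's coordinates against the target point.
    return (rows == row) & (cols == col)


mask = point_to_mask(5, 5, (10, 10))
```

Building the mask by comparison rather than by indexed assignment keeps the function free of Python-side branching, so it should also work when row and col are traced values.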

BboxActionWrapper

class jaxarc.wrappers.BboxActionWrapper(env: Environment)

Bases: Wrapper

Bbox action wrapper with custom action space.

action_space(env_params: EnvParams | None = None) → DictSpace

Custom action space for bbox actions: (operation, r1, c1, r2, c2).

step(state: State, action: dict, env_params: EnvParams | None = None) → tuple[State, TimeStep]

Convert bbox to mask and delegate.
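A bounding-box conversion can be sketched the same way. This assumes both corners are inclusive and may be given in either order; those conventions, the helper name bbox_to_mask, and the 10×10 grid size are assumptions, not confirmed behavior of BboxActionWrapper.

```python
import jax.numpy as jnp


def bbox_to_mask(r1, c1, r2, c2, grid_shape):
    """Return a boolean (H, W) mask covering the box, corners inclusive."""
    height, width = grid_shape
    rows = jnp.arange(height)[:, None]
    cols = jnp.arange(width)[None, :]
    # Sort the corners so the box is valid whichever corner comes first.
    r_lo, r_hi = jnp.minimum(r1, r2), jnp.maximum(r1, r2)
    c_lo, c_hi = jnp.minimum(c1, c2), jnp.maximum(c1, c2)
    return (rows >= r_lo) & (rows <= r_hi) & (cols >= c_lo) & (cols <= c_hi)


bbox_mask = bbox_to_mask(1, 2, 3, 5, (10, 10))
```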

FlattenActionWrapper

class jaxarc.wrappers.FlattenActionWrapper(env: Environment)

Bases: Wrapper[State]

A general-purpose wrapper to flatten any composite discrete action space.

This wrapper can handle any combination of DictSpace, MultiDiscreteSpace, and DiscreteSpace, converting them into a single, unified DiscreteSpace.

action_space(env_params: EnvParams | None = None) → Space

Returns the single, flattened DiscreteSpace.

step(state: State, action: Action, env_params: EnvParams | None = None) → tuple[State, TimeStep]

Un-flattens the action and steps the underlying environment.
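One common way to flatten a composite discrete action is mixed-radix encoding: treat each component as a digit whose base is the size of its sub-space. The sketch below illustrates that general technique; it is not taken from the FlattenActionWrapper source, and the component sizes (3, 4, 5) and function names are hypothetical.

```python
import math


def flatten_action(components, sizes):
    """Encode a tuple of discrete components as a single integer."""
    flat = 0
    for value, size in zip(components, sizes):
        flat = flat * size + value
    return flat


def unflatten_action(flat, sizes):
    """Invert flatten_action: recover the component tuple from the integer."""
    components = []
    for size in reversed(sizes):
        components.append(flat % size)
        flat //= size
    return tuple(reversed(components))


sizes = (3, 4, 5)                  # hypothetical sizes of three sub-spaces
num_actions = math.prod(sizes)     # size of the flattened discrete space
flat = flatten_action((2, 1, 3), sizes)
restored = unflatten_action(flat, sizes)
```

Every integer in range(num_actions) maps to exactly one component tuple, so an agent that outputs a single discrete action can still reach every composite action.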

Observation Wrappers

InputGridObservationWrapper

class jaxarc.wrappers.InputGridObservationWrapper(env: Environment)

Bases: BaseObservationWrapper

Adds the original input grid of the current pair as a new observation channel.

AnswerObservationWrapper

class jaxarc.wrappers.AnswerObservationWrapper(env: Environment)

Bases: BaseObservationWrapper

Adds the task’s target grid (solution) as a new observation channel.

ClipboardObservationWrapper

class jaxarc.wrappers.ClipboardObservationWrapper(env: Environment)

Bases: BaseObservationWrapper

Adds the agent’s clipboard as a new observation channel.

ContextualObservationWrapper

class jaxarc.wrappers.ContextualObservationWrapper(env, num_context_pairs: int = 5)

Bases: BaseObservationWrapper

Adds N context demonstration pairs to the observation.

This wrapper adds 2 * num_context_pairs channels to the observation, representing the input and output grids of other demonstration pairs from the same task. This provides the agent with contextual examples.
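The channel arithmetic can be sketched as follows. This assumes a channels-last observation of shape (H, W, C) and an interleaved ordering (input, output, input, output, ...); both conventions, the function name add_context_channels, and the array sizes are assumptions for illustration and may differ from what the wrapper actually does.

```python
import jax.numpy as jnp


def add_context_channels(obs, context_inputs, context_outputs):
    """Append 2 * N context channels to an (H, W, C) observation.

    context_inputs and context_outputs each have shape (N, H, W).
    """
    height, width = obs.shape[:2]
    # Pair each input with its output: (N, 2, H, W).
    pairs = jnp.stack([context_inputs, context_outputs], axis=1)
    # Merge the pair axis into the channel axis: (2N, H, W).
    channels = pairs.reshape((-1, height, width))
    # Move channels last to match the observation layout: (H, W, 2N).
    channels = jnp.moveaxis(channels, 0, -1)
    return jnp.concatenate([obs, channels], axis=-1)


num_context_pairs = 5
obs = jnp.zeros((10, 10, 1), dtype=jnp.int32)
context_inputs = jnp.ones((num_context_pairs, 10, 10), dtype=jnp.int32)
context_outputs = 2 * jnp.ones((num_context_pairs, 10, 10), dtype=jnp.int32)
augmented = add_context_channels(obs, context_inputs, context_outputs)
```

With the default num_context_pairs of 5, a single-channel observation grows to 1 + 2 * 5 = 11 channels.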

Visualization Wrappers

StepVisualizationWrapper

class jaxarc.wrappers.StepVisualizationWrapper(env)

Bases: Wrapper

Wrapper that enables detailed step visualization by tracking transition history.

Enables env.render(mode="detailed"), which returns a rich SVG of the last transition.

render(state: State, mode: str | None = None) → Any

Render the environment.

reset(key, env_params: EnvParams | None = None) → tuple[Any, Any]

Reset the environment.

Parameters:
  • key – A JAX PRNG key for random number generation.

  • env_params – Optional environment parameters.

Returns:

A tuple of the initial state and the first TimeStep.

step(state: State, action: Any, env_params: EnvParams | None = None) → tuple[Any, Any]

Take a step in the environment.

Parameters:
  • state – The current environment state.

  • action – The action to take.

  • env_params – Optional environment parameters.

Returns:

A tuple of the new state and the resulting TimeStep.

Usage Example

import jax

from jaxarc import make
from jaxarc.wrappers import PointActionWrapper, InputGridObservationWrapper

# Create base environment
env, env_params = make("Mini")

# Add wrappers
env = PointActionWrapper(env)
env = InputGridObservationWrapper(env)

# Use wrapped environment
key = jax.random.PRNGKey(0)  # PRNG key for reset; the seed value is arbitrary
state, timestep = env.reset(key, env_params)
action = {"operation": 2, "row": 5, "col": 5}  # point action: (operation, row, col)
state, timestep = env.step(state, action, env_params)

See Also