Core API

The core API provides essential functions and classes for creating and interacting with JaxARC environments.

Quick Example

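import jax
import jax.numpy as jnp
from jaxarc import make, Action

# Create environment
env, env_params = make("Mini")

# Reset and run episode
key = jax.random.PRNGKey(42)
state, timestep = env.reset(key, env_params)

# Build an action: an operation ID (0-34) plus a boolean mask of selected cells.
# Operation 0 and the single selected cell are arbitrary placeholder values.
selection = jnp.zeros_like(state.working_grid, dtype=bool).at[0, 0].set(True)
action = Action(operation=jnp.int32(0), selection=selection)
state, timestep = env.step(state, action, env_params)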

Environment Creation

jaxarc.make(id: str, **kwargs: Any) -> tuple[Any, Any]

Create an environment instance and EnvParams using a registered spec.

See EnvRegistry.make for details on supported kwargs.

Core Classes

Environment

class jaxarc.Environment(config: JaxArcConfig, buffer: Any, episode_mode: int = 0, subset_indices: Any | None = None)

Bases: Environment

JaxARC environment implementing Stoa API patterns.

Delegates to the functional API while providing a clean object-oriented interface.

action_space(_env_params: EnvParams | None = None) -> ARCActionSpace

Get ARC action space.

close() -> None

Close the environment.

discount_space(_env_params: EnvParams | None = None) -> BoundedArraySpace

Get discount space.

observation_shape() -> tuple[int, int, int]

Get observation shape.

observation_space(_env_params: EnvParams | None = None) -> BoundedArraySpace

Get ARC observation space.

render(state: State, mode: str | None = None) -> Any

Render the environment state.

Parameters:
  • state – The current environment state.

  • mode – The rendering mode ("rgb_array", "ansi", or "svg"). If None, uses the default mode from the configuration.

Returns:

The rendered output (numpy array, string, or SVG string).

reset(rng_key: jax.Array, env_params: EnvParams | None = None) -> tuple[State, TimeStep]

Reset the environment through the functional API. Pass env_params to override the environment's parameters for this call only.

reward_space(_env_params: EnvParams | None = None) -> BoundedArraySpace

Get reward space.

state_space(_env_params: EnvParams | None = None) -> Space

Return the state space of the environment.

step(state: State, action: Action | dict, env_params: EnvParams | None = None) -> tuple[State, TimeStep]

Advance the environment by one step through the functional API. Pass env_params to override the environment's parameters for this call only.

property unwrapped: Environment

Get the unwrapped environment.
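
The reset and step signatures above thread state explicitly and accept an optional per-call parameter override. The sketch below illustrates that calling pattern with a library-free toy environment. ToyEnv, ToyParams, ToyState, ToyTimeStep, and the reward/done fields are all hypothetical stand-ins, not JaxARC classes, and the fallback-to-defaults behavior is an assumption about what "override" means here.

```python
from typing import NamedTuple, Optional


class ToyParams(NamedTuple):
    """Hypothetical stand-in for EnvParams: static settings."""
    max_steps: int


class ToyState(NamedTuple):
    """Hypothetical stand-in for State: values that change every step."""
    step_count: int


class ToyTimeStep(NamedTuple):
    """Hypothetical stand-in for TimeStep; field names are illustrative."""
    reward: float
    done: bool


class ToyEnv:
    """Mirrors the reset/step signatures above with trivial dynamics."""

    def __init__(self, default_params: ToyParams):
        self.default_params = default_params

    def _resolve(self, env_params: Optional[ToyParams]) -> ToyParams:
        # Assumption: a per-call value wins, otherwise stored defaults apply.
        return self.default_params if env_params is None else env_params

    def reset(self, rng_key, env_params=None):
        # rng_key is accepted for signature parity; this toy is deterministic.
        return ToyState(step_count=0), ToyTimeStep(reward=0.0, done=False)

    def step(self, state, action, env_params=None):
        params = self._resolve(env_params)
        new_state = ToyState(step_count=state.step_count + 1)
        done = new_state.step_count >= params.max_steps
        return new_state, ToyTimeStep(reward=float(action), done=done)


def rollout(env, rng_key, action, env_params=None):
    """Run one episode, passing the returned state into the next call."""
    state, timestep = env.reset(rng_key, env_params)
    total_reward = 0.0
    while not timestep.done:
        state, timestep = env.step(state, action, env_params)
        total_reward += timestep.reward
    return state.step_count, total_reward


env = ToyEnv(ToyParams(max_steps=3))
default_steps, default_reward = rollout(env, rng_key=0, action=1)
override_steps, override_reward = rollout(
    env, rng_key=0, action=1, env_params=ToyParams(max_steps=5)
)
print(default_steps, override_steps)
```

Because the environment object never stores the episode state, the same env can serve both rollouts with different parameters. In real JAX code the Python while loop would typically be replaced by a construct such as jax.lax.scan.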

State

class jaxarc.State(working_grid: GridArray, working_grid_mask: MaskArray, input_grid: GridArray, input_grid_mask: MaskArray, target_grid: GridArray, target_grid_mask: MaskArray, selected: SelectionArray, clipboard: GridArray, step_count: StepCount, allowed_operations_mask: OperationMask, similarity_score: SimilarityScore, key: PRNGKey, task_idx: TaskIndex, pair_idx: PairIndex)

Bases: Module

Environment state.

Contains only truly dynamic variables that change during episodes. Static configuration is moved to EnvParams.

allowed_operations_mask: Bool[Array, '35']
clipboard: Int[Array, '*batch height width']
input_grid: Int[Array, '*batch height width']
input_grid_mask: Bool[Array, '*batch height width']
key: Int[Array, '2']
pair_idx: Int[Array, '']
selected: Bool[Array, '*batch height width']
similarity_score: Float[Array, '*batch']
step_count: Int[Array, '']
target_grid: Int[Array, '*batch height width']
target_grid_mask: Bool[Array, '*batch height width']
task_idx: Int[Array, '']
working_grid: Int[Array, '*batch height width']
working_grid_mask: Bool[Array, '*batch height width']
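
The split described above, with per-step values in State and static configuration in EnvParams, can be pictured with a small library-free sketch. SketchParams, SketchState, fill_selected, and the fill_color setting are hypothetical names chosen for illustration. Only working_grid and step_count are borrowed from the field list, and plain tuples stand in for JAX arrays.

```python
from dataclasses import dataclass, replace
from typing import Tuple

Grid = Tuple[Tuple[int, ...], ...]
Mask = Tuple[Tuple[bool, ...], ...]


@dataclass(frozen=True)
class SketchParams:
    """Hypothetical static configuration (plays the role of EnvParams)."""
    max_steps: int
    fill_color: int


@dataclass(frozen=True)
class SketchState:
    """Hypothetical dynamic state holding two of the fields listed above."""
    working_grid: Grid
    step_count: int


def fill_selected(state: SketchState, selected: Mask, params: SketchParams) -> SketchState:
    """Return a new state with the selected cells painted; inputs stay untouched."""
    new_grid = tuple(
        tuple(params.fill_color if sel else cell for cell, sel in zip(row, sel_row))
        for row, sel_row in zip(state.working_grid, selected)
    )
    # Frozen dataclasses cannot be mutated, so build an updated copy instead.
    return replace(state, working_grid=new_grid, step_count=state.step_count + 1)


params = SketchParams(max_steps=10, fill_color=7)
state = SketchState(working_grid=((0, 0), (0, 0)), step_count=0)
selected = ((True, False), (False, True))
next_state = fill_selected(state, selected, params)
print(next_state.working_grid, next_state.step_count)
```

The original state is left unchanged and the parameters are only read, which is the property that lets a state object be passed through pure functions.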

Action

class jaxarc.Action(operation: jnp.int32, selection: SelectionArray)

Bases: Module

Simple action representation.

operation

ARC operation ID (0-34)

Type:

jax.numpy.int32

selection

Boolean mask indicating selected cells

Type:

jaxtyping.Bool[Array, '*batch height width']

operation: int32
selection: Bool[Array, '*batch height width']
validate(grid_shape: tuple[int, int], max_operations: int = 35) -> Action

Validate action parameters.

Parameters:
  • grid_shape – Shape of the grid (height, width)

  • max_operations – Maximum number of operations

Returns:

Validated action with clipped operation
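
As a rough illustration of what "clipped operation" could mean, the sketch below clamps the operation ID into the range 0 to max_operations - 1 and returns a new action. SketchAction is a hypothetical stand-in that uses plain Python values instead of JAX arrays. Both the clamping bounds and the selection shape check are assumptions, since the reference above does not spell out the exact behavior of Action.validate.

```python
from dataclasses import dataclass, replace
from typing import Tuple

Mask = Tuple[Tuple[bool, ...], ...]


@dataclass(frozen=True)
class SketchAction:
    """Hypothetical stand-in for Action with plain Python fields."""
    operation: int
    selection: Mask

    def validate(self, grid_shape: Tuple[int, int], max_operations: int = 35) -> "SketchAction":
        height, width = grid_shape
        # Assumption: the selection mask must match the grid shape exactly.
        rows_ok = len(self.selection) == height
        cols_ok = all(len(row) == width for row in self.selection)
        if not (rows_ok and cols_ok):
            raise ValueError(f"selection does not match grid shape {grid_shape}")
        # Assumption: out-of-range IDs are clamped into [0, max_operations - 1].
        clipped = min(max(self.operation, 0), max_operations - 1)
        return replace(self, operation=clipped)


mask = ((True, False), (False, False))
too_large = SketchAction(operation=99, selection=mask).validate((2, 2))
negative = SketchAction(operation=-3, selection=mask).validate((2, 2))
in_range = SketchAction(operation=12, selection=mask).validate((2, 2))
print(too_large.operation, negative.operation, in_range.operation)
```

Returning a new object rather than mutating in place matches the signature above, where validate returns an Action.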