Core API
The core API provides essential functions and classes for creating and interacting with JaxARC environments.
Quick Example
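import jax
import jax.numpy as jnp
from jaxarc import make, Action
# Create the environment and its default parameters
env, env_params = make("Mini")
# Reset to start an episode
key = jax.random.PRNGKey(42)
state, timestep = env.reset(key, env_params)
# Build an action: an operation ID (0-34) and a boolean selection mask shaped
# like the working grid. Operation 0 and "select every valid cell" are
# arbitrary choices for illustration.
action = Action(operation=jnp.int32(0), selection=state.working_grid_mask)
# Take one step
state, timestep = env.step(state, action, env_params)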
Environment Creation
Core Classes
Environment
- class jaxarc.Environment(config: JaxArcConfig, buffer: Any, episode_mode: int = 0, subset_indices: Any | None = None)
Bases: Environment
JaxARC environment implementing Stoa API patterns.
Delegates to the functional API while providing a clean object-oriented interface.
- observation_space(_env_params: EnvParams | None = None) → BoundedArraySpace
Return the ARC observation space.
- render(state: State, mode: str | None = None) → Any
Render the environment state.
- Parameters:
state – The current environment state.
mode – The rendering mode (“rgb_array”, “ansi”, or “svg”). If None, the default mode from the configuration is used.
- Returns:
The rendered output: a NumPy array for “rgb_array”, a string for “ansi”, or an SVG string for “svg”.
- reset(rng_key: jax.Array, env_params: EnvParams | None = None) → tuple[State, TimeStep]
Reset the environment via the functional API. If env_params is given, it overrides the parameters for this call only.
- state_space(_env_params: EnvParams | None = None) → Space
Return the state space of the environment.
- step(state: State, action: Action | dict, env_params: EnvParams | None = None) → tuple[State, TimeStep]
Advance the environment by one step via the functional API. The action may be an Action or a dict; if env_params is given, it overrides the parameters for this call only.
- property unwrapped: Environment
Return the unwrapped environment.
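The delegation pattern described above (an object-oriented Environment whose reset and step forward to pure functions, with an optional per-call env_params override) can be sketched with a toy environment. Everything in this sketch is hypothetical: ToyParams, ToyState, functional_reset, functional_step and ToyEnv are illustrative names, not JaxARC's API, and the "fall back to stored params when env_params is None" rule is an assumption about what the per-call override means. The point it shows is that pure step functions compose with jax.jit and jax.lax.scan.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    # Hypothetical static configuration (stands in for EnvParams).
    max_steps: int


class ToyState(NamedTuple):
    # Hypothetical dynamic state (stands in for State).
    counter: jax.Array
    step_count: jax.Array


def functional_reset(key: jax.Array, params: ToyParams) -> ToyState:
    # Pure function: the initial counter is drawn from the PRNG key.
    counter = jax.random.randint(key, (), 0, 10)
    return ToyState(counter=counter, step_count=jnp.int32(0))


def functional_step(state: ToyState, action: jax.Array, params: ToyParams):
    # Pure function: add the action to the counter and flag the episode end.
    new_state = ToyState(
        counter=state.counter + action,
        step_count=state.step_count + 1,
    )
    done = new_state.step_count >= params.max_steps
    return new_state, done


class ToyEnv:
    """Thin object-oriented wrapper that only delegates to the pure functions."""

    def __init__(self, params: ToyParams):
        self.params = params

    def reset(self, key, env_params: Optional[ToyParams] = None):
        # Assumed override rule: per-call params win, else the stored defaults.
        params = self.params if env_params is None else env_params
        return functional_reset(key, params)

    def step(self, state, action, env_params: Optional[ToyParams] = None):
        params = self.params if env_params is None else env_params
        return functional_step(state, action, params)


env = ToyEnv(ToyParams(max_steps=5))
state = env.reset(jax.random.PRNGKey(0))


# Because step is pure, a whole rollout can be compiled with lax.scan.
@jax.jit
def rollout(state, actions):
    def body(s, a):
        s, done = env.step(s, a)
        return s, done

    return jax.lax.scan(body, state, actions)


final_state, dones = rollout(state, jnp.ones(5, dtype=jnp.int32))
```

Keeping the wrapper thin means compiled code only ever runs the pure functions; the Python object just supplies default parameters.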
State
- class jaxarc.State(working_grid: GridArray, working_grid_mask: MaskArray, input_grid: GridArray, input_grid_mask: MaskArray, target_grid: GridArray, target_grid_mask: MaskArray, selected: SelectionArray, clipboard: GridArray, step_count: StepCount, allowed_operations_mask: OperationMask, similarity_score: SimilarityScore, key: PRNGKey, task_idx: TaskIndex, pair_idx: PairIndex)
Bases: Module
Environment state.
Contains only the dynamic variables that change during an episode; static configuration lives in EnvParams.
- allowed_operations_mask: Bool[Array, '35']
- clipboard: Int[Array, '*batch height width']
- input_grid: Int[Array, '*batch height width']
- input_grid_mask: Bool[Array, '*batch height width']
- key: Int[Array, '2']
- pair_idx: Int[Array, '']
- selected: Bool[Array, '*batch height width']
- similarity_score: Float[Array, '*batch']
- step_count: Int[Array, '']
- target_grid: Int[Array, '*batch height width']
- target_grid_mask: Bool[Array, '*batch height width']
- task_idx: Int[Array, '']
- working_grid: Int[Array, '*batch height width']
- working_grid_mask: Bool[Array, '*batch height width']
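Each grid in State is paired with a boolean mask of the same shape. A reasonable reading, assumed here rather than confirmed by this reference, is that the mask marks the valid cells of a grid padded to a fixed maximum size. Under that assumption, a per-cell agreement score between working_grid and target_grid can be computed as below. The function masked_agreement is hypothetical, and JaxARC's own similarity_score may be defined differently.

```python
import jax.numpy as jnp


def masked_agreement(working_grid, target_grid, target_mask):
    """Fraction of valid target cells whose color matches (hypothetical metric).

    Accepts a single (height, width) grid or grids with leading batch
    dimensions, mirroring the '*batch height width' shapes above.
    """
    matches = (working_grid == target_grid) & target_mask
    n_match = matches.sum(axis=(-2, -1))
    n_valid = target_mask.sum(axis=(-2, -1))
    # Guard against an all-False mask to avoid division by zero.
    return n_match / jnp.maximum(n_valid, 1)


# A 3x3 target padded into a 4x4 canvas: only the top-left 3x3 block is valid.
target_mask = jnp.zeros((4, 4), dtype=bool).at[:3, :3].set(True)
target_grid = jnp.where(target_mask, 2, 0)
working_grid = target_grid.at[0, 0].set(5)  # one wrong cell inside the mask

score = masked_agreement(working_grid, target_grid, target_mask)

# Batched: stack two working grids along a leading axis; the target broadcasts.
batch_scores = masked_agreement(
    jnp.stack([working_grid, target_grid]), target_grid, target_mask
)
```

Here score is 8/9 (eight of nine valid cells match), and the padding cells outside the mask never affect the result.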
Action
- class jaxarc.Action(operation: jnp.int32, selection: SelectionArray)
Bases: Module
Simple action representation.
- operation
ARC operation ID (0-34).
- Type:
jnp.int32
- selection
Boolean mask indicating the selected cells.
- Type:
jaxtyping.Bool[Array, '*batch height width']
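An Action pairs an operation ID with a selection mask, and State carries a 35-entry allowed_operations_mask. The sketch below builds a rectangular selection and checks whether an operation ID is permitted by indexing that mask. It uses plain JAX arrays in place of jaxarc.Action and jaxarc.State so it runs on its own; rect_selection and is_allowed are hypothetical helpers, and the meaning of individual operation IDs is not specified in this reference.

```python
import jax
import jax.numpy as jnp

NUM_OPERATIONS = 35  # matches the Bool[Array, '35'] allowed_operations_mask


def rect_selection(height, width, top, left, bottom, right):
    """Boolean mask selecting the inclusive rectangle rows top..bottom, cols left..right."""
    rows = jnp.arange(height)[:, None]
    cols = jnp.arange(width)[None, :]
    return (rows >= top) & (rows <= bottom) & (cols >= left) & (cols <= right)


def is_allowed(allowed_operations_mask, operation):
    """True if the operation ID is enabled; works under jit with a traced ID."""
    return allowed_operations_mask[operation]


# Hypothetical mask that enables only operation IDs 0-9.
allowed = jnp.arange(NUM_OPERATIONS) < 10
selection = rect_selection(5, 5, top=1, left=1, bottom=2, right=3)

ok = jax.jit(is_allowed)(allowed, jnp.int32(3))
blocked = jax.jit(is_allowed)(allowed, jnp.int32(20))

# vmap checks a batch of operation IDs against the same mask.
batch_ok = jax.vmap(is_allowed, in_axes=(None, 0))(allowed, jnp.array([0, 9, 10, 34]))
```

Following the constructor signature above, the pair would then be passed as Action(operation=jnp.int32(3), selection=selection).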