Types API

Type definitions for JaxARC environment parameters and timesteps.

Module Contents

Type definitions for the JaxARC project.

This module contains the core data structures used throughout the project, including grid representations, task data, agent states, and environment states. All types are designed to work with JAX transformations and carry validation and JAXTyping annotations.

This module also provides the core JAX array type aliases using JAXTyping for the JaxARC environment.

Key Features:
  • Core grid and mask array types with batch support
  • Action space type definitions
  • Task data structure types
  • Environment state types
  • Essential utility types

The JAXTyping *batch modifier lets one annotation describe both single arrays of shape (height, width) and batched arrays of shape (batch1, batch2, …, height, width).
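A function written against the *batch convention should only reduce or index the trailing (height, width) axes, so it works unchanged for one grid and for any number of leading batch axes. The sketch below illustrates this with a hypothetical helper, count_valid_cells, written in plain jax.numpy; the JAXTyping annotation appears only in the docstring and is not checked at runtime here.

```python
import jax.numpy as jnp


def count_valid_cells(mask):
    """Count True cells per grid (hypothetical helper).

    `mask` follows the Bool[Array, "*batch height width"] convention:
    shape (H, W) for one grid or (B1, ..., H, W) for a batch. Reducing
    over the last two axes leaves any leading batch axes intact.
    """
    return jnp.sum(mask, axis=(-2, -1))


# Single grid: a 3x4 valid region inside a 5x5 padded canvas.
single = jnp.zeros((5, 5), dtype=bool).at[:3, :4].set(True)
# Two grids with an extra batch axis of size 1: shape (2, 1, 5, 5).
batched = jnp.stack([single, jnp.ones((5, 5), dtype=bool)])[:, None]

print(count_valid_cells(single))         # 12
print(count_valid_cells(batched).shape)  # (2, 1): batch axes preserved
```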

class jaxarc.types.ARCOperationType

Bases: object

ARC operation types (grid operations plus SUBMIT only).

The pair-control operations (IDs 35-41) have been removed to simplify the action space, so the valid operation IDs are 0-34.

CLEAR = 31
COPY = 28
COPY_INPUT = 32
CUT = 30
FILL_0 = 0
FILL_1 = 1
FILL_2 = 2
FILL_3 = 3
FILL_4 = 4
FILL_5 = 5
FILL_6 = 6
FILL_7 = 7
FILL_8 = 8
FILL_9 = 9
FLIP_HORIZONTAL = 26
FLIP_VERTICAL = 27
FLOOD_FILL_0 = 10
FLOOD_FILL_1 = 11
FLOOD_FILL_2 = 12
FLOOD_FILL_3 = 13
FLOOD_FILL_4 = 14
FLOOD_FILL_5 = 15
FLOOD_FILL_6 = 16
FLOOD_FILL_7 = 17
FLOOD_FILL_8 = 18
FLOOD_FILL_9 = 19
MOVE_DOWN = 21
MOVE_LEFT = 22
MOVE_RIGHT = 23
MOVE_UP = 20
PASTE = 29
RESIZE = 33
ROTATE_C = 24
ROTATE_CC = 25
SUBMIT = 34
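Sorted by value, the IDs above fall into contiguous families: FILL_0-FILL_9 (0-9), FLOOD_FILL_0-FLOOD_FILL_9 (10-19), moves (20-23), rotations (24-25), flips (26-27), COPY/PASTE/CUT (28-30), then CLEAR, COPY_INPUT, RESIZE and SUBMIT (31-34). The sketch below decodes an ID using those ranges; describe_operation is a hypothetical helper, and it re-declares the documented values locally instead of importing jaxarc.types.ARCOperationType so that it runs standalone.

```python
# Values copied from ARCOperationType; redeclared locally (assumption) so
# the sketch does not depend on jaxarc being installed.
FILL_0, FLOOD_FILL_0 = 0, 10
MOVE_UP, ROTATE_C, FLIP_HORIZONTAL = 20, 24, 26
COPY, CLEAR, COPY_INPUT, RESIZE, SUBMIT = 28, 31, 32, 33, 34
NUM_OPERATIONS = SUBMIT + 1  # valid IDs are 0-34


def describe_operation(op_id: int) -> str:
    """Map an operation ID to a coarse family name (and color, if any)."""
    if not 0 <= op_id < NUM_OPERATIONS:
        raise ValueError(f"invalid operation id {op_id}; expected 0-{NUM_OPERATIONS - 1}")
    if op_id < FLOOD_FILL_0:
        return f"fill color {op_id - FILL_0}"
    if op_id < MOVE_UP:
        return f"flood fill color {op_id - FLOOD_FILL_0}"
    if op_id < ROTATE_C:
        return "move"
    if op_id < FLIP_HORIZONTAL:
        return "rotate"
    if op_id < COPY:
        return "flip"
    if op_id < CLEAR:
        return "clipboard"  # COPY (28), PASTE (29), CUT (30)
    return {CLEAR: "clear", COPY_INPUT: "copy input", RESIZE: "resize", SUBMIT: "submit"}[op_id]


print(describe_operation(17))  # flood fill color 7
```

Because the removed pair-control IDs (35-41) fall outside 0-34, the range check rejects them instead of silently mapping them to a family.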
jaxarc.types.ColorValue

Scalar integer representing a color value (0-9).

jaxarc.types.DiscountValue

Float array representing discount value(s).

class jaxarc.types.EnvParams(dataset: DatasetConfig, action: ActionConfig, reward: RewardConfig, grid_initialization: GridInitializationConfig, max_episode_steps: int, buffer: Any = None, subset_indices: Any = None, episode_mode: int = 0)

Bases: Module

Environment parameters: only what the environment logic needs.

This is NOT a rename of JaxArcConfig. JaxArcConfig contains framework concerns (logging, visualization, storage) that don’t belong in environment parameters.

EnvParams carries a JAX-native task buffer so that reset() is compatible with jit and vmap. The buffer is a stacked pytree of JAX arrays (batched JaxArcTask fields), and the optional subset_indices define a view into it.

action: ActionConfig
buffer: Any = None
dataset: DatasetConfig
episode_mode: int = 0
classmethod from_config(config: JaxArcConfig, *, episode_mode: int = 0, buffer: Any = None, subset_indices: Any = None) -> EnvParams

Extract environment parameters from the unified JaxArcConfig.

Parameters:
  • config – Full project configuration

  • episode_mode – 0=train, 1=test

  • buffer – Batched pytree of JAX arrays (stacked JaxArcTask fields)

  • subset_indices – Optional indices defining a subview into the buffer

grid_initialization: GridInitializationConfig
max_episode_steps: int
reward: RewardConfig
subset_indices: Any = None
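The description above only states that buffer is a stacked pytree and that subset_indices select a view into it. The sketch below shows one jit-compatible way such a view could be sampled; sample_task and the dict-shaped toy buffer are assumptions for illustration, not jaxarc's reset() implementation.

```python
import jax
import jax.numpy as jnp


def sample_task(buffer, subset_indices, key):
    """Pick one task from a stacked pytree buffer (hypothetical helper).

    Every leaf of `buffer` has a leading task axis. When `subset_indices`
    is given, sampling is restricted to those rows of the buffer.
    """
    if subset_indices is None:
        num_tasks = jax.tree_util.tree_leaves(buffer)[0].shape[0]
        row = jax.random.randint(key, (), 0, num_tasks)
    else:
        pos = jax.random.randint(key, (), 0, subset_indices.shape[0])
        row = subset_indices[pos]
    # Index every leaf with the same (traced) row; output shapes stay static.
    return jax.tree_util.tree_map(lambda leaf: leaf[row], buffer)


# Toy buffer: 4 tasks, each with a 3x3 grid and a scalar task_index.
buffer = {
    "grids": jnp.arange(4 * 3 * 3).reshape(4, 3, 3),
    "task_index": jnp.arange(4),
}
subset = jnp.array([1, 3])  # view containing tasks 1 and 3 only

task = jax.jit(sample_task)(buffer, subset, jax.random.PRNGKey(0))
```

The subset is stored as indices rather than as a sliced copy, so every view shares one buffer and the sampled task always has the same static shapes.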
class jaxarc.types.Grid(data: GridArray, mask: MaskArray)

Bases: Module

Represents a grid in the ARC challenge using Equinox Module.

Equinox registers each Module as a PyTree automatically, so Grid instances can be passed directly through JAX transformations such as jit and vmap.

data

The grid's color values as an integer array of shape (height, width), optionally with leading batch dimensions

Type:

jaxtyping.Int[Array, '*batch height width']

mask

Boolean mask marking valid (non-padding) cells, with the same shape as data

Type:

jaxtyping.Bool[Array, '*batch height width']

data: Int[Array, '*batch height width']
mask: Bool[Array, '*batch height width']
property shape: tuple[int, int]

Get the shape of the valid region in the grid.

Uses the mask to determine the dimensions of the valid region rather than the padded array dimensions.

Returns:

Tuple of (height, width) representing the valid region dimensions
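The docs do not say exactly how the mask is reduced. One plausible approach, assuming the valid cells form a top-left-aligned rectangle (as when a grid is padded to a fixed canvas), counts the rows and columns that contain at least one valid cell. valid_shape is a hypothetical standalone function, not the property's actual code.

```python
import jax.numpy as jnp


def valid_shape(mask):
    """Return (height, width) of the valid region of a single 2D mask.

    Assumes the valid cells form a top-left-aligned rectangle.
    """
    height = jnp.sum(jnp.any(mask, axis=1))  # rows containing a valid cell
    width = jnp.sum(jnp.any(mask, axis=0))   # columns containing a valid cell
    # int() needs concrete values, so this sketch must run outside jit.
    return int(height), int(width)


# A 2x3 ARC grid padded to a 30x30 canvas.
mask = jnp.zeros((30, 30), dtype=bool).at[:2, :3].set(True)
print(valid_shape(mask))  # (2, 3)
```

Returning Python ints matches the tuple[int, int] return type, but it also means this version cannot be called on traced arrays inside jit or vmap.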

jaxarc.types.GridArray

Integer array representing ARC grid(s) with color values 0-9.

jaxarc.types.GridHeight

Scalar integer representing grid height.

jaxarc.types.GridWidth

Scalar integer representing grid width.

class jaxarc.types.JaxArcTask(input_grids_examples: TaskInputGrids, input_masks_examples: TaskInputMasks, output_grids_examples: TaskOutputGrids, output_masks_examples: TaskOutputMasks, num_train_pairs: int, test_input_grids: TaskInputGrids, test_input_masks: TaskInputMasks, true_test_output_grids: TaskOutputGrids, true_test_output_masks: TaskOutputMasks, num_test_pairs: int, task_index: TaskIndex)

Bases: Module

JAX-compatible ARC task data, stored as an Equinox Module with fixed-size arrays.

All task data is padded to maximum dimensions so that every array has a fixed shape, which enables efficient batch processing and JAX transformations. All arrays carry JAXTyping annotations for type safety and documentation.
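A minimal sketch of how variable-size grids could be packed into the fixed-size (max_pairs, max_height, max_width) layout described here, with a matching mask and pair count. pad_pairs is a hypothetical helper, not jaxarc's loader; it builds the arrays on the host with NumPy and converts them to JAX arrays once.

```python
import jax.numpy as jnp
import numpy as np


def pad_pairs(grids, max_pairs, max_height, max_width, fill=0):
    """Pad variable-size grids into fixed-size arrays with validity masks.

    Returns (padded, masks, num_pairs); padded and masks have shape
    (max_pairs, max_height, max_width). Unused pair slots are all-False.
    """
    if len(grids) > max_pairs:
        raise ValueError(f"{len(grids)} pairs exceed max_pairs={max_pairs}")
    # Build on the host with NumPy (in-place writes), then convert once.
    padded = np.full((max_pairs, max_height, max_width), fill, dtype=np.int32)
    masks = np.zeros((max_pairs, max_height, max_width), dtype=bool)
    for i, grid in enumerate(grids):
        g = np.asarray(grid, dtype=np.int32)
        h, w = g.shape
        if h > max_height or w > max_width:
            raise ValueError(f"grid {i} of shape {g.shape} exceeds the canvas")
        padded[i, :h, :w] = g
        masks[i, :h, :w] = True
    return jnp.asarray(padded), jnp.asarray(masks), len(grids)


# Two training inputs of different sizes, padded to 3 slots of 4x4.
grids, masks, num_pairs = pad_pairs(
    [[[1, 2], [3, 4]], [[5]]], max_pairs=3, max_height=4, max_width=4
)
```

The fill value 0 is also a real ARC color, which is why a separate mask is needed to tell padding from content.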

Training examples:

input_grids_examples

Training input grids, padded to (max_pairs, max_height, max_width)

Type:

jaxtyping.Int[Array, 'max_pairs max_height max_width']

input_masks_examples

Masks marking the valid cells of each training input

Type:

jaxtyping.Bool[Array, 'max_pairs max_height max_width']

output_grids_examples

Training output grids, padded to (max_pairs, max_height, max_width)

Type:

jaxtyping.Int[Array, 'max_pairs max_height max_width']

output_masks_examples

Masks marking the valid cells of each training output

Type:

jaxtyping.Bool[Array, 'max_pairs max_height max_width']

num_train_pairs

Number of valid training pairs

Type:

int

Test examples:

test_input_grids

Test input grids, padded to (max_pairs, max_height, max_width)

Type:

jaxtyping.Int[Array, 'max_pairs max_height max_width']

test_input_masks

Masks marking the valid cells of each test input

Type:

jaxtyping.Bool[Array, 'max_pairs max_height max_width']

true_test_output_grids

Ground-truth test outputs, padded to (max_pairs, max_height, max_width)

Type:

jaxtyping.Int[Array, 'max_pairs max_height max_width']

true_test_output_masks

Masks marking the valid cells of each ground-truth test output

Type:

jaxtyping.Bool[Array, 'max_pairs max_height max_width']

num_test_pairs

Number of valid test pairs

Type:

int

Task metadata:

task_index

Integer index identifying the task (a JAX-compatible scalar)

Type:

jaxtyping.Int[Array, '']

get_available_demo_pairs() -> Bool[Array, ...]

Get a mask of the available training pairs.

Returns:

JAX boolean array indicating which training pairs are available (based on num_train_pairs)

get_available_test_pairs() -> Bool[Array, ...]

Get a mask of the available test pairs.

Returns:

JAX boolean array indicating which test pairs are available (based on num_test_pairs)
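Both masks are described as derived from a pair count and the fixed array dimension. A minimal sketch, assuming the valid pairs occupy the first slots (consistent with the padded layout), compares an index range against the count; available_pairs is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def available_pairs(num_pairs, max_pairs):
    """Mask of shape (max_pairs,): True for the first num_pairs slots.

    max_pairs is the fixed array dimension (static); num_pairs may be a
    traced scalar, so the comparison stays valid under jit.
    """
    return jnp.arange(max_pairs) < num_pairs


eager_mask = available_pairs(3, max_pairs=5)
# max_pairs sets the output shape, so it must be static under jit.
jitted_mask = jax.jit(available_pairs, static_argnums=1)(jnp.int32(2), 4)
```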

get_demo_pair_data(pair_idx: int) -> tuple[Int[Array, '*batch height width'], Int[Array, '*batch height width'], Bool[Array, '*batch height width'], Bool[Array, '*batch height width']]

Get training pair data by index.

Parameters:

pair_idx – Index of the training pair to retrieve

Returns:

Tuple of (input_grid, output_grid, input_mask, output_mask)
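Because every pair array shares a fixed leading max_pairs axis, retrieving a pair by index keeps output shapes static even when the index is traced. The sketch below shows that pattern with a hypothetical demo_pair function operating on raw arrays rather than on a JaxArcTask.

```python
import jax
import jax.numpy as jnp


@jax.jit
def demo_pair(inputs, outputs, in_masks, out_masks, pair_idx):
    """Return (input_grid, output_grid, input_mask, output_mask) for one pair.

    `pair_idx` may be traced: indexing a fixed-size leading axis keeps
    every output shape static, so this works under jit and vmap.
    """
    return inputs[pair_idx], outputs[pair_idx], in_masks[pair_idx], out_masks[pair_idx]


inputs = jnp.arange(12).reshape(3, 2, 2)  # 3 pairs of 2x2 grids
outputs = inputs + 100
masks = jnp.ones((3, 2, 2), dtype=bool)

x_in, x_out, m_in, m_out = demo_pair(inputs, outputs, masks, masks, jnp.int32(1))
# vmap over pair indices to fetch every input grid at once.
all_inputs = jax.vmap(demo_pair, in_axes=(None, None, None, None, 0))(
    inputs, outputs, masks, masks, jnp.arange(3)
)[0]
```

JAX clamps out-of-range indices in gathers rather than raising, so an index past the last valid pair returns padded data silently; checking it against the availability mask first avoids that.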

get_grid_shape() -> tuple[int, int]

Get the grid dimensions for this task.

Returns:

Tuple of (height, width) for the grid dimensions

get_max_test_pairs() -> int

Get the maximum number of test pairs this task can hold.

Returns:

Maximum number of test pairs (array dimension)

get_max_train_pairs() -> int

Get the maximum number of training pairs this task can hold.

Returns:

Maximum number of training pairs (array dimension)

get_task_id() -> str | None

Get the task ID for this task.

This is a convenience method that looks up the task ID from the global task manager using the stored task_index.

Note: This method is NOT JAX-compatible and should not be used within JAX transformations (jit, vmap, etc.). Use only for debugging, logging, visualization, or other non-JAX code.

Returns:

String task ID if found in the global task manager, None otherwise

Example

```python
# `parser` is assumed to be an already-constructed dataset parser
task = parser.get_task_by_id("some_task")
task_id = task.get_task_id()  # Returns "some_task"
```
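The restriction comes from converting a JAX scalar to a Python value: that needs a concrete number, which does not exist while jit or vmap is tracing. The sketch below reproduces the failure with a hypothetical dictionary, TASK_IDS, standing in for the global task manager; lookup_task_id is likewise illustrative, not jaxarc code.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the global task manager: index -> task ID.
TASK_IDS = {0: "task_a", 1: "task_b"}


def lookup_task_id(task_index):
    """Python-side lookup; like get_task_id(), it needs a concrete integer."""
    return TASK_IDS.get(int(task_index))


print(lookup_task_id(jnp.int32(1)))  # task_b (eager call works)

try:
    jax.jit(lookup_task_id)(jnp.int32(1))  # traced: int() has no value to use
except TypeError as err:  # JAX concretization errors subclass TypeError
    print(type(err).__name__)
```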

get_task_summary() -> dict

Get a summary of task information.

Returns:

Dictionary containing task metadata

get_test_input_grid(pair_idx: int) -> Grid

Get the test input grid at the given index.

get_test_output_grid(pair_idx: int) -> Grid

Get the test output grid at the given index.

get_test_pair(pair_idx: int) -> TaskPair

Get the test pair at the given index.

get_test_pair_data(pair_idx: int) -> tuple[Int[Array, '*batch height width'], Bool[Array, '*batch height width']]

Get test pair input data by index (no target during evaluation).

Parameters:

pair_idx – Index of the test pair to retrieve

Returns:

Tuple of (input_grid, input_mask)

get_train_input_grid(pair_idx: int) -> Grid

Get the training input grid at the given index.

get_train_output_grid(pair_idx: int) -> Grid

Get the training output grid at the given index.

get_train_pair(pair_idx: int) -> TaskPair

Get the training pair at the given index.

input_grids_examples: Int[Array, 'max_pairs max_height max_width']
input_masks_examples: Bool[Array, 'max_pairs max_height max_width']
is_demo_pair_available(pair_idx: int) -> Array

Check if a specific demonstration pair is available.

Parameters:

pair_idx – Index of the demonstration pair to check

Returns:

JAX boolean scalar array indicating if the pair is available

is_test_pair_available(pair_idx: int) -> Array

Check if a specific test pair is available.

Parameters:

pair_idx – Index of the test pair to check

Returns:

JAX boolean scalar array indicating if the pair is available
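Since these checks return a JAX boolean scalar rather than a Python bool, a plain `if` on the result fails under jit. The sketch below selects between two pair indices with jnp.where instead; select_pair and its bounds check are hypothetical and only mimic the kind of result is_demo_pair_available returns.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_pair(pair_idx, num_pairs, fallback_idx):
    """Use pair_idx if that pair is available, otherwise fall back.

    `available` is a traced boolean scalar, so `if available:` would
    raise under jit; jnp.where chooses between the two indices instead.
    """
    available = (pair_idx >= 0) & (pair_idx < num_pairs)
    return jnp.where(available, pair_idx, fallback_idx)


print(select_pair(jnp.int32(2), jnp.int32(3), jnp.int32(0)))  # 2
print(select_pair(jnp.int32(4), jnp.int32(3), jnp.int32(0)))  # 0
```

jnp.where suits choosing between two values; jax.lax.cond is the alternative when the two branches must run different computations.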

num_test_pairs: int
num_train_pairs: int
output_grids_examples: Int[Array, 'max_pairs max_height max_width']
output_masks_examples: Bool[Array, 'max_pairs max_height max_width']
task_index: Int[Array, '']
test_input_grids: Int[Array, 'max_pairs max_height max_width']
test_input_masks: Bool[Array, 'max_pairs max_height max_width']
true_test_output_grids: Int[Array, 'max_pairs max_height max_width']
true_test_output_masks: Bool[Array, 'max_pairs max_height max_width']
jaxarc.types.MaskArray

Boolean array representing valid/invalid cells in grid(s).

jaxarc.types.ObservationArray

Integer array representing observation(s) from the environment.

jaxarc.types.OperationId

Scalar integer representing an ARC operation (0-34).

jaxarc.types.OperationMask

Boolean mask indicating which operations are currently allowed.

jaxarc.types.PRNGKey

JAX PRNG key array with shape (2,).

jaxarc.types.PairIndex

Scalar integer representing current demonstration/test pair.

jaxarc.types.RewardValue

Float array representing reward value(s).

jaxarc.types.SelectionArray

Boolean array representing selected cells for operations.

jaxarc.types.SimilarityScore

Float array representing grid similarity score(s).

jaxarc.types.StepCount

Scalar integer representing current step count.

jaxarc.types.TaskIndex

Scalar integer representing task identifier.

jaxarc.types.TaskInputGrids

Training/test input grids padded to maximum dimensions.

jaxarc.types.TaskInputMasks

Training/test input masks padded to maximum dimensions.

jaxarc.types.TaskOutputGrids

Training/test output grids padded to maximum dimensions.

jaxarc.types.TaskOutputMasks

Training/test output masks padded to maximum dimensions.
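OperationMask, OperationId and PRNGKey typically meet when an agent samples an action. One common pattern, shown here as an assumption rather than a documented jaxarc API, sets the logits of disallowed operations to -inf before sampling, so they receive zero probability. sample_allowed_operation and the "FILL only" mask are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_allowed_operation(key, logits, operation_mask):
    """Sample an operation ID whose OperationMask entry is True.

    Disallowed entries get -inf logits, so they are never chosen.
    Assumes at least one operation is allowed.
    """
    masked_logits = jnp.where(operation_mask, logits, -jnp.inf)
    return jax.random.categorical(key, masked_logits)


num_ops = 35  # valid operation IDs are 0-34
logits = jnp.zeros(num_ops)
# Allow only the ten FILL_k operations (IDs 0-9) in this example.
operation_mask = jnp.arange(num_ops) < 10

keys = jax.random.split(jax.random.PRNGKey(0), 100)
ops = jax.vmap(sample_allowed_operation, in_axes=(0, None, None))(
    keys, logits, operation_mask
)
```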

class jaxarc.types.TaskPair(input_grid: Grid, output_grid: Grid)

Bases: Module

Represents a single input-output pair in an ARC task using Equinox Module.

input_grid

Input grid for this pair

Type:

jaxarc.types.Grid

output_grid

Expected output grid for this pair

Type:

jaxarc.types.Grid

input_grid: Grid
output_grid: Grid

Usage Example

```python
import jax
from jaxarc import make, EnvParams, TimeStep

# Create environment
env, env_params = make("Mini")

# env_params is of type EnvParams
assert isinstance(env_params, EnvParams)

# Reset returns State and TimeStep
key = jax.random.PRNGKey(42)
state, timestep = env.reset(key, env_params)

# timestep is of type TimeStep
assert isinstance(timestep, TimeStep)

# TimeStep contains:
# - observation: jax.Array
# - reward: float
# - discount: float
# - step_type: StepType (FIRST, MID, LAST)
# - extras: dict

print(f"Observation shape: {timestep.observation.shape}")
print(f"Reward: {timestep.reward}")
print(f"Done: {timestep.step_type.last()}")
```

JAX Typing

JaxARC uses jaxtyping for array shape annotations:

```python
from jaxtyping import Array, Float, Int, Bool

# Example type hints
observation: Float[Array, "height width channels"]
selection: Bool[Array, "height width"]
operation: Int[Array, ""]
```