Types API
Type definitions for JaxARC environment parameters and timesteps.
Module Contents
Type definitions for the JaxARC project.
This module contains all the core data structures used throughout the project, including grid representations, task data, agent states, and environment states. All types are designed to be JAX-compatible with proper validation and JAXTyping annotations.
This module also provides the core JAX array type aliases using JAXTyping for the JaxARC environment.
Key Features:
- Core grid and mask array types with batch support
- Action space type definitions
- Task data structure types
- Environment state types
- Essential utility types
The jaxtyping `*batch` modifier lets the same type describe both a single array of shape (height, width) and a batched array of shape (batch1, batch2, …, height, width).
- class jaxarc.types.ARCOperationType
Bases: object
ARC operation types (grid + submit only).
Pair control operations (35-41) have been removed to simplify the action space. Remaining valid operation IDs: 0-34.
- CLEAR = 31
- COPY = 28
- COPY_INPUT = 32
- CUT = 30
- FILL_0 = 0
- FILL_1 = 1
- FILL_2 = 2
- FILL_3 = 3
- FILL_4 = 4
- FILL_5 = 5
- FILL_6 = 6
- FILL_7 = 7
- FILL_8 = 8
- FILL_9 = 9
- FLIP_HORIZONTAL = 26
- FLIP_VERTICAL = 27
- FLOOD_FILL_0 = 10
- FLOOD_FILL_1 = 11
- FLOOD_FILL_2 = 12
- FLOOD_FILL_3 = 13
- FLOOD_FILL_4 = 14
- FLOOD_FILL_5 = 15
- FLOOD_FILL_6 = 16
- FLOOD_FILL_7 = 17
- FLOOD_FILL_8 = 18
- FLOOD_FILL_9 = 19
- MOVE_DOWN = 21
- MOVE_LEFT = 22
- MOVE_RIGHT = 23
- MOVE_UP = 20
- PASTE = 29
- RESIZE = 33
- ROTATE_C = 24
- ROTATE_CC = 25
- SUBMIT = 34
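The IDs follow a regular layout: FILL_k = k and FLOOD_FILL_k = 10 + k for colors 0-9, followed by the movement, rotation, flip, copy/paste, and control operations. The table below is transcribed from the listing above so the IDs can be decoded with plain integer arithmetic. `OPERATION_IDS`, `ID_TO_NAME`, and `fill_color` are hypothetical helpers, not part of jaxarc.

```python
from typing import Optional

# Values transcribed from the ARCOperationType listing above.
OPERATION_IDS = {f"FILL_{c}": c for c in range(10)}
OPERATION_IDS.update({f"FLOOD_FILL_{c}": 10 + c for c in range(10)})
OPERATION_IDS.update({
    "MOVE_UP": 20, "MOVE_DOWN": 21, "MOVE_LEFT": 22, "MOVE_RIGHT": 23,
    "ROTATE_C": 24, "ROTATE_CC": 25,
    "FLIP_HORIZONTAL": 26, "FLIP_VERTICAL": 27,
    "COPY": 28, "PASTE": 29, "CUT": 30, "CLEAR": 31,
    "COPY_INPUT": 32, "RESIZE": 33, "SUBMIT": 34,
})
# Inverse table: operation id -> member name.
ID_TO_NAME = {op_id: name for name, op_id in OPERATION_IDS.items()}


def fill_color(op_id: int) -> Optional[int]:
    """Color used by a FILL_k or FLOOD_FILL_k operation, or None for other ops."""
    if op_id not in ID_TO_NAME:
        raise ValueError(f"invalid operation id {op_id}; valid ids are 0-34")
    if op_id <= 9:
        return op_id        # FILL_0 .. FILL_9
    if op_id <= 19:
        return op_id - 10   # FLOOD_FILL_0 .. FLOOD_FILL_9
    return None


print(ID_TO_NAME[34], fill_color(13), fill_color(20))
```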
- jaxarc.types.ColorValue
Scalar integer representing a color value (0-9).
- jaxarc.types.DiscountValue
Float array representing discount value(s).
- class jaxarc.types.EnvParams(dataset: DatasetConfig, action: ActionConfig, reward: RewardConfig, grid_initialization: GridInitializationConfig, max_episode_steps: int, buffer: Any = None, subset_indices: Any = None, episode_mode: int = 0)
Bases: Module
Clean environment parameters: only what is needed for environment logic.
This is NOT a rename of JaxArcConfig. JaxArcConfig also contains framework concerns (logging, visualization, storage) that do not belong in environment parameters.
EnvParams carries a JAX-native task buffer so that reset() is compatible with jit and vmap. The buffer is a stacked pytree of JAX arrays (the batched JaxArcTask fields), and the optional subset indices define a view into that buffer.
- action: ActionConfig
- dataset: DatasetConfig
- classmethod from_config(config: JaxArcConfig, *, episode_mode: int = 0, buffer: Any = None, subset_indices: Any = None) → EnvParams
Extract environment parameters from the unified JaxArcConfig.
- Parameters:
config – Full project configuration
episode_mode – 0=train, 1=test
buffer – Batched pytree of JAX arrays (stacked JaxArcTask fields)
subset_indices – Optional indices defining a subview into the buffer
- grid_initialization: GridInitializationConfig
- reward: RewardConfig
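A minimal sketch of how a stacked buffer plus subset indices can support a jit- and vmap-compatible reset: sample a position inside the subset, map it to a buffer row, and slice that row out of every leaf. The dict-based `buffer`, its two fields, and `sample_task` are hypothetical stand-ins; the real buffer holds all JaxArcTask fields, and the environment's actual reset logic may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for EnvParams.buffer: each leaf stacks N tasks on axis 0.
num_tasks, max_pairs, max_h, max_w = 6, 4, 30, 30
buffer = {
    "input_grids_examples": jnp.zeros((num_tasks, max_pairs, max_h, max_w), jnp.int32),
    "task_index": jnp.arange(num_tasks, dtype=jnp.int32),
}
# Hypothetical subset: only buffer rows 1, 3 and 5 are eligible.
subset_indices = jnp.array([1, 3, 5], dtype=jnp.int32)


@jax.jit
def sample_task(key, buffer, subset_indices):
    # Choose a position within the subset (its length is static under jit) ...
    pos = jax.random.randint(key, (), 0, subset_indices.shape[0])
    row = subset_indices[pos]
    # ... then take that row from every leaf of the stacked pytree.
    return jax.tree_util.tree_map(lambda leaf: leaf[row], buffer)


task = sample_task(jax.random.PRNGKey(0), buffer, subset_indices)
print(task["input_grids_examples"].shape)  # (4, 30, 30)

# The same function vmaps over a batch of keys, e.g. for parallel resets.
keys = jax.random.split(jax.random.PRNGKey(1), 8)
batch = jax.vmap(sample_task, in_axes=(0, None, None))(keys, buffer, subset_indices)
print(batch["input_grids_examples"].shape)  # (8, 4, 30, 30)
```

Because the row index is a traced value rather than a Python integer, no retracing happens when a different task is drawn.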
- class jaxarc.types.Grid(data: GridArray, mask: MaskArray)
Bases: Module
Represents a grid in the ARC challenge as an Equinox Module.
Equinox registers the class as a PyTree automatically, so Grid instances can be passed directly through JAX transformations such as jit and vmap.
- data
The grid colors as an integer array of shape (*batch, height, width)
- Type:
jaxtyping.Int[Array, '*batch height width']
- mask
Boolean mask marking the valid cells of the grid, with the same shape annotation as data
- Type:
jaxtyping.Bool[Array, '*batch height width']
- data: Int[Array, '*batch height width']
- mask: Bool[Array, '*batch height width']
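The mask matters because padding cells hold ordinary integer values that can collide with real colors. The sketch below pads a 3x2 grid into a 5x5 canvas with a hypothetical `pad_grid` helper. Padding with 0 is an assumption made for illustration (this section does not state JaxARC's padding value), and it shows that counting color-0 cells is only correct once the mask is applied.

```python
import jax.numpy as jnp


def pad_grid(grid, max_h, max_w, pad_value=0):
    """Hypothetical helper: place `grid` in the top-left of a fixed canvas."""
    h, w = grid.shape
    data = jnp.full((max_h, max_w), pad_value, dtype=jnp.int32).at[:h, :w].set(grid)
    mask = jnp.zeros((max_h, max_w), dtype=bool).at[:h, :w].set(True)
    return data, mask


data, mask = pad_grid(jnp.array([[1, 2], [3, 4], [5, 6]], dtype=jnp.int32), 5, 5)

valid_cells = int(mask.sum())                    # 6
zeros_unmasked = int(jnp.sum(data == 0))         # 19: every padding cell counted
zeros_masked = int(jnp.sum((data == 0) & mask))  # 0: the real grid has no color 0
print(valid_cells, zeros_unmasked, zeros_masked)
```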
- jaxarc.types.GridArray
Integer array representing ARC grid(s) with color values 0-9.
- jaxarc.types.GridHeight
Scalar integer representing grid height.
- jaxarc.types.GridWidth
Scalar integer representing grid width.
- class jaxarc.types.JaxArcTask(input_grids_examples: TaskInputGrids, input_masks_examples: TaskInputMasks, output_grids_examples: TaskOutputGrids, output_masks_examples: TaskOutputMasks, num_train_pairs: int, test_input_grids: TaskInputGrids, test_input_masks: TaskInputMasks, true_test_output_grids: TaskOutputGrids, true_test_output_masks: TaskOutputMasks, num_test_pairs: int, task_index: TaskIndex)
Bases: Module
JAX-compatible ARC task data, stored as fixed-size arrays in an Equinox Module.
All task data is held in fixed-size arrays padded to the maximum dimensions, so tasks can be batched efficiently and passed through JAX transformations. Every array field carries a jaxtyping shape annotation.
Training examples:
- input_grids_examples
Training input grids
- Type:
jaxtyping.Int[Array, 'max_pairs max_height max_width']
- input_masks_examples
Masks for the training inputs
- Type:
jaxtyping.Bool[Array, 'max_pairs max_height max_width']
- output_grids_examples
Training output grids
- Type:
jaxtyping.Int[Array, 'max_pairs max_height max_width']
- output_masks_examples
Masks for the training outputs
- Type:
jaxtyping.Bool[Array, 'max_pairs max_height max_width']
Test examples:
- test_input_grids
Test input grids
- Type:
jaxtyping.Int[Array, 'max_pairs max_height max_width']
- test_input_masks
Masks for the test inputs
- Type:
jaxtyping.Bool[Array, 'max_pairs max_height max_width']
- true_test_output_grids
Ground-truth test output grids
- Type:
jaxtyping.Int[Array, 'max_pairs max_height max_width']
- true_test_output_masks
Masks for the ground-truth test outputs
- Type:
jaxtyping.Bool[Array, 'max_pairs max_height max_width']
Task metadata:
- task_index
Integer index identifying the task (a JAX-compatible scalar)
- Type:
jaxtyping.Int[Array, '']
- get_available_demo_pairs() → Bool[Array, ...]
Get mask of available training pairs.
- Returns:
JAX boolean array indicating which training pairs are available (based on num_train_pairs)
- get_available_test_pairs() → Bool[Array, ...]
Get mask of available test pairs.
- Returns:
JAX boolean array indicating which test pairs are available (based on num_test_pairs)
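One plausible way to derive these masks is to compare each slot index against the pair count; padded slots can then be excluded after a computation has been vmapped over every slot. This is a sketch under that assumption, not JaxARC's implementation. `availability_mask`, `changed_cells`, and `total_changed` are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def availability_mask(num_pairs, max_pairs):
    # True for slots [0, num_pairs), False for the padded remainder.
    return jnp.arange(max_pairs) < num_pairs


def changed_cells(inp, out, in_mask, out_mask):
    # Cells valid in both grids whose color differs between input and output.
    return jnp.sum((inp != out) & in_mask & out_mask)


@jax.jit
def total_changed(inputs, outputs, in_masks, out_masks, num_pairs):
    per_pair = jax.vmap(changed_cells)(inputs, outputs, in_masks, out_masks)
    available = jnp.arange(inputs.shape[0]) < num_pairs
    return jnp.sum(jnp.where(available, per_pair, 0))


print(availability_mask(3, 5).tolist())  # [True, True, True, False, False]

# Three slots of 4x4 grids; only the first two pairs are real.
inputs = jnp.zeros((3, 4, 4), jnp.int32)
outputs = inputs.at[0, 0, 0].set(1).at[1, :2, 0].set(5).at[2].set(7)
masks = jnp.ones((3, 4, 4), dtype=bool)
print(int(total_changed(inputs, outputs, masks, masks, 2)))  # 3: slot 2 is ignored
```

Here the pair count is a traced argument while the maximum stays static, which keeps array shapes fixed under jit.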
- get_demo_pair_data(pair_idx: int) → tuple[Int[Array, '*batch height width'], Int[Array, '*batch height width'], Bool[Array, '*batch height width'], Bool[Array, '*batch height width']]
Get training pair data by index.
- Parameters:
pair_idx – Index of the training pair to retrieve
- Returns:
Tuple of (input_grid, output_grid, input_mask, output_mask)
- get_grid_shape() → tuple[int, int]
Get the grid dimensions for this task.
- Returns:
Tuple of (height, width) for the grid dimensions
- get_max_test_pairs() → int
Get the maximum number of test pairs this task can hold.
- Returns:
Maximum number of test pairs (array dimension)
- get_max_train_pairs() → int
Get the maximum number of training pairs this task can hold.
- Returns:
Maximum number of training pairs (array dimension)
- get_task_id() → str | None
Get the task ID for this task.
This is a convenience method that looks up the task ID in the global task manager using the stored task_index.
Note: This method is NOT JAX-compatible and should not be used within JAX transformations (jit, vmap, etc.). Use it only for debugging, logging, visualization, or other non-JAX code.
- Returns:
String task ID if found in the global task manager, None otherwise
Example
```python
task = parser.get_task_by_id("some_task")
task_id = task.get_task_id()  # Returns "some_task"
```
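The restriction exists because a lookup keyed on task_index needs a concrete Python value, while inside jit the index is an abstract tracer. The sketch below reproduces the failure with a hypothetical dict registry standing in for the global task manager. It does not call JaxARC itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the global task manager's index -> ID table.
TASK_IDS = {0: "task_a", 1: "task_b"}


def lookup_task_id(task_index):
    # int() forces a concrete value, which only exists outside JAX transformations.
    return TASK_IDS.get(int(task_index))


eager_result = lookup_task_id(jnp.int32(1))
print(eager_result)  # task_b (eager execution works)

try:
    jax.jit(lookup_task_id)(jnp.int32(1))
    failed_under_jit = False
except TypeError:
    # JAX raises a concretization error (a TypeError subclass) when int()
    # is applied to a traced value.
    failed_under_jit = True
print(failed_under_jit)  # True
```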
- get_task_summary() → dict
Get a summary of task information.
- Returns:
Dictionary containing task metadata
- get_test_pair_data(pair_idx: int) → tuple[Int[Array, '*batch height width'], Bool[Array, '*batch height width']]
Get test pair input data by index (no target during evaluation).
- Parameters:
pair_idx – Index of the test pair to retrieve
- Returns:
Tuple of (input_grid, input_mask)
- input_grids_examples: Int[Array, 'max_pairs max_height max_width']
- input_masks_examples: Bool[Array, 'max_pairs max_height max_width']
- is_demo_pair_available(pair_idx: int) → Array
Check if a specific demonstration pair is available.
- Parameters:
pair_idx – Index of the demonstration pair to check
- Returns:
JAX boolean scalar array indicating if the pair is available
- is_test_pair_available(pair_idx: int) → Array
Check if a specific test pair is available.
- Parameters:
pair_idx – Index of the test pair to check
- Returns:
JAX boolean scalar array indicating if the pair is available
- output_grids_examples: Int[Array, 'max_pairs max_height max_width']
- output_masks_examples: Bool[Array, 'max_pairs max_height max_width']
- task_index: Int[Array, '']
- test_input_grids: Int[Array, 'max_pairs max_height max_width']
- test_input_masks: Bool[Array, 'max_pairs max_height max_width']
- true_test_output_grids: Int[Array, 'max_pairs max_height max_width']
- true_test_output_masks: Bool[Array, 'max_pairs max_height max_width']
- jaxarc.types.MaskArray
Boolean array representing valid/invalid cells in grid(s).
- jaxarc.types.ObservationArray
Integer array representing observation(s) from the environment.
- jaxarc.types.OperationId
Scalar integer representing an ARC operation (0-34).
- jaxarc.types.OperationMask
Boolean mask indicating which operations are currently allowed.
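A common way to consume such a mask in a policy (an assumption about usage, not code from JaxARC) is to push the logits of disallowed operations to -inf before taking an argmax or sampling, so a forbidden operation can never be chosen. The mask length of 35 follows from the valid IDs 0-34; `masked_greedy` and `masked_sample` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

NUM_OPERATIONS = 35  # valid operation ids are 0-34


def masked_greedy(logits, op_mask):
    # Disallowed operations get -inf, so argmax can never select them.
    return jnp.argmax(jnp.where(op_mask, logits, -jnp.inf))


def masked_sample(key, logits, op_mask):
    # Categorical sampling assigns zero probability to -inf logits.
    return jax.random.categorical(key, jnp.where(op_mask, logits, -jnp.inf))


# The policy strongly prefers SUBMIT (34), but only FILL_0..FILL_9 are allowed.
logits = jnp.zeros(NUM_OPERATIONS).at[34].set(5.0).at[3].set(1.0)
op_mask = jnp.arange(NUM_OPERATIONS) < 10

greedy = int(masked_greedy(logits, op_mask))
samples = jax.vmap(masked_sample, in_axes=(0, None, None))(
    jax.random.split(jax.random.PRNGKey(0), 100), logits, op_mask
)
print(greedy)                        # 3
print(bool(jnp.all(samples < 10)))   # True
```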
- jaxarc.types.PRNGKey
JAX PRNG key array with shape (2,).
- jaxarc.types.PairIndex
Scalar integer representing current demonstration/test pair.
- jaxarc.types.RewardValue
Float array representing reward value(s).
- jaxarc.types.SelectionArray
Boolean array representing selected cells for operations.
- jaxarc.types.SimilarityScore
Float array representing grid similarity score(s).
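This section does not describe the exact similarity metric JaxARC uses. As one hypothetical example, a score can be defined as the fraction of valid target cells whose color the prediction matches. Reducing over the last two axes keeps the score compatible with the `*batch` convention. `masked_accuracy` is an illustrative name.

```python
import jax.numpy as jnp


def masked_accuracy(pred, target, mask):
    """Hypothetical score: share of valid cells whose color matches the target."""
    matches = jnp.sum((pred == target) & mask, axis=(-2, -1))
    valid = jnp.maximum(jnp.sum(mask, axis=(-2, -1)), 1)  # avoid division by zero
    return matches / valid


# A 2x2 target padded to 4x4.
target = jnp.zeros((4, 4), jnp.int32).at[:2, :2].set(jnp.array([[1, 2], [3, 4]]))
mask = jnp.zeros((4, 4), dtype=bool).at[:2, :2].set(True)
# One wrong valid cell, plus a change in the padding that the mask ignores.
pred = target.at[0, 0].set(9).at[3, 3].set(7)

score = float(masked_accuracy(pred, target, mask))
print(score)  # 0.75
# Batched inputs of shape (batch, height, width) give one score per grid.
scores = masked_accuracy(jnp.stack([pred, target]), jnp.stack([target, target]), mask)
print(scores.tolist())  # [0.75, 1.0]
```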
- jaxarc.types.StepCount
Scalar integer representing current step count.
- jaxarc.types.TaskIndex
Scalar integer representing task identifier.
- jaxarc.types.TaskInputGrids
Training/test input grids padded to maximum dimensions.
- jaxarc.types.TaskInputMasks
Training/test input masks padded to maximum dimensions.
- jaxarc.types.TaskOutputGrids
Training/test output grids padded to maximum dimensions.
- jaxarc.types.TaskOutputMasks
Training/test output masks padded to maximum dimensions.
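Because ARC grids vary in size, building these fixed-size arrays amounts to padding every pair into a shared (max_pairs, max_height, max_width) canvas and recording the valid region in the matching mask. The following is a minimal host-side sketch using NumPy. It assumes padding with 0 and top-left alignment. `stack_padded` is a hypothetical helper, and JaxARC's own parser may pad differently.

```python
import numpy as np
import jax.numpy as jnp


def stack_padded(grids, max_pairs, max_h, max_w, pad_value=0):
    """Hypothetical helper: pad ragged grids into fixed-size grid and mask arrays."""
    if len(grids) > max_pairs:
        raise ValueError(f"{len(grids)} grids exceed max_pairs={max_pairs}")
    data = np.full((max_pairs, max_h, max_w), pad_value, dtype=np.int32)
    mask = np.zeros((max_pairs, max_h, max_w), dtype=bool)
    for i, grid in enumerate(grids):
        grid = np.asarray(grid, dtype=np.int32)
        h, w = grid.shape
        if h > max_h or w > max_w:
            raise ValueError(f"grid {i} of shape {grid.shape} exceeds ({max_h}, {max_w})")
        data[i, :h, :w] = grid
        mask[i, :h, :w] = True
    # Convert once on the host; the results can then be used inside jit/vmap.
    return jnp.asarray(data), jnp.asarray(mask)


grids = [[[1]], [[1, 2], [3, 4]], [[5], [6], [7]]]
task_grids, task_masks = stack_padded(grids, max_pairs=4, max_h=3, max_w=3)
print(task_grids.shape)                       # (4, 3, 3)
print(task_masks.sum(axis=(1, 2)).tolist())   # [1, 4, 3, 0]
```

The loop runs once at load time, outside any JAX transformation, which is why plain NumPy with in-place assignment is used here.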
Usage Example
```python
import jax
from jaxarc import make, EnvParams, TimeStep

# Create environment
env, env_params = make("Mini")
# env_params is of type EnvParams
assert isinstance(env_params, EnvParams)

# Reset returns State and TimeStep
key = jax.random.PRNGKey(42)
state, timestep = env.reset(key, env_params)

# timestep is of type TimeStep
assert isinstance(timestep, TimeStep)

# TimeStep contains:
# - observation: jax.Array
# - reward: float
# - discount: float
# - step_type: StepType (FIRST, MID, LAST)
# - extras: dict
print(f"Observation shape: {timestep.observation.shape}")
print(f"Reward: {timestep.reward}")
print(f"Done: {timestep.step_type.last()}")
```
JAX Typing
JaxARC uses jaxtyping for array shape annotations:
```python
from jaxtyping import Array, Float, Int, Bool

# Example type hints
observation: Float[Array, "height width channels"]
selection: Bool[Array, "height width"]
operation: Int[Array, ""]
```