Source code for jaxarc.state

"""
Centralized environment state definition using Equinox.

This module defines the simplified `State` used throughout JaxARC.
Static configuration has been removed from state and moved to EnvParams.

Key properties:
- Equinox Module for automatic PyTree registration
- JAXTyping annotations for precise type safety
- Purely dynamic fields that change during episodes
- JAX transformation compatibility (jit, vmap, pmap)
"""

from __future__ import annotations

import chex
import equinox as eqx
import jax.numpy as jnp

from jaxarc.types import (
    GridArray,
    MaskArray,
    OperationMask,
    PairIndex,
    PRNGKey,
    SelectionArray,
    SimilarityScore,
    StepCount,
    TaskIndex,
)


[docs] class State(eqx.Module): """Environment state. Contains only truly dynamic variables that change during episodes. Static configuration is moved to EnvParams. """ # Core dynamic grid state working_grid: GridArray # Current grid being modified working_grid_mask: MaskArray # Valid cells mask input_grid: GridArray # Original input grid for current pair input_grid_mask: MaskArray # Valid cells mask for input grid target_grid: GridArray # Goal grid for current example target_grid_mask: MaskArray # Valid cells mask for target grid # Grid operations state selected: SelectionArray # Selection mask for operations clipboard: GridArray # For copy/paste operations # Episode progress tracking step_count: StepCount # Current step number # Dynamic control state allowed_operations_mask: OperationMask # Dynamic operation filtering # Similarity tracking score (required array) similarity_score: SimilarityScore # PRNG key for environment randomness (auto-reset compatibility) key: PRNGKey # Task/pair tracking (link into EnvParams.buffer) task_idx: TaskIndex # Index into EnvParams.buffer identifying active task pair_idx: PairIndex # Index of current demonstration/test pair within task def __check_init__(self) -> None: """Validate dynamic state structure.""" # During tracing, attributes may be placeholders; guard accordingly if not hasattr(self.working_grid, "shape"): return try: # Validate grid ranks chex.assert_rank(self.working_grid, 2) chex.assert_rank(self.working_grid_mask, 2) chex.assert_rank(self.input_grid, 2) chex.assert_rank(self.input_grid_mask, 2) chex.assert_rank(self.target_grid, 2) chex.assert_rank(self.target_grid_mask, 2) chex.assert_rank(self.selected, 2) chex.assert_rank(self.clipboard, 2) chex.assert_rank(self.allowed_operations_mask, 1) # Validate consistent shapes chex.assert_shape(self.working_grid_mask, self.working_grid.shape) chex.assert_shape(self.input_grid, self.working_grid.shape) chex.assert_shape(self.input_grid_mask, self.working_grid.shape) chex.assert_shape(self.target_grid, self.working_grid.shape) chex.assert_shape(self.target_grid_mask, self.working_grid.shape) chex.assert_shape(self.selected, self.working_grid.shape) chex.assert_shape(self.clipboard, self.working_grid.shape) # Validate scalars/types chex.assert_shape(self.step_count, ()) chex.assert_type(self.step_count, jnp.integer) # Task/pair indices should be scalar integer arrays chex.assert_shape(self.task_idx, ()) chex.assert_type(self.task_idx, jnp.int32) chex.assert_shape(self.pair_idx, ()) chex.assert_type(self.pair_idx, jnp.int32) except (AttributeError, TypeError): # Gracefully skip during tracing pass