# Source code for jaxarc.envs.environment

"""
JaxARC environment following Stoa API patterns.

Concrete implementation that delegates to functional.py with Stoa-compatible interface.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import stoa.environment

from jaxarc.configs.main_config import JaxArcConfig
from jaxarc.envs.actions import Action, create_action
from jaxarc.envs.spaces import ARCActionSpace, BoundedArraySpace, DictSpace, GridSpace
from jaxarc.types import EnvParams, TimeStep
from jaxarc.utils.visualization.core import draw_grid_svg, render_rgb
from jaxarc.utils.visualization.display import render_ansi

from .functional import reset as functional_reset
from .functional import step as functional_step

if TYPE_CHECKING:
    from jaxarc.types import State


class Environment(stoa.environment.Environment):
    """JaxARC environment implementing Stoa API patterns.

    Thin object-oriented facade over the functional API: ``reset`` and
    ``step`` delegate to ``functional_reset`` / ``functional_step`` and then
    enrich the returned timestep's ``extras`` with a fixed-schema record of
    the action that was applied.

    Attributes:
        params: Environment parameters built via ``EnvParams.from_config``.
        render_mode: Default render mode read from
            ``config.environment.render_mode``.
    """

    def __init__(
        self,
        config: JaxArcConfig,
        buffer: Any,
        episode_mode: int = 0,
        subset_indices: Any | None = None,
    ):
        """Build the environment parameters from ``config``.

        Args:
            config: Top-level JaxARC configuration.
            buffer: Forwarded unchanged to ``EnvParams.from_config``.
            episode_mode: Forwarded unchanged to ``EnvParams.from_config``.
            subset_indices: Forwarded unchanged to ``EnvParams.from_config``.
        """
        self.params = EnvParams.from_config(
            config=config,
            episode_mode=episode_mode,
            buffer=buffer,
            subset_indices=subset_indices,
        )
        self.render_mode = config.environment.render_mode

    def observation_shape(self) -> tuple[int, int, int]:
        """Return the observation shape as ``(height, width, 1)``."""
        # Config values extracted as Python ints (not traced — pre-JIT)
        return (
            int(self.params.dataset.max_grid_height),
            int(self.params.dataset.max_grid_width),
            1,
        )

    @staticmethod
    def _sentinel_action(params: EnvParams) -> tuple[jax.Array, jax.Array]:
        """Return the ``(operation, selection)`` placeholders meaning "no action".

        The operation is the int32 scalar ``-1`` and the selection is an
        all-False boolean mask of shape ``(max_grid_height, max_grid_width)``.
        """
        height = int(params.dataset.max_grid_height)
        width = int(params.dataset.max_grid_width)
        op_sentinel = jnp.array(-1, dtype=jnp.int32)
        zero_sel = jnp.zeros((height, width), dtype=jnp.bool_)
        return op_sentinel, zero_sel

    @staticmethod
    def _with_extras(timestep: TimeStep, extras: dict) -> TimeStep:
        """Return a copy of ``timestep`` whose ``extras`` field is ``extras``."""
        return TimeStep(
            step_type=timestep.step_type,
            reward=timestep.reward,
            discount=timestep.discount,
            observation=timestep.observation,
            extras=extras,
        )

    def reset(
        self, rng_key: jax.Array, env_params: EnvParams | None = None
    ) -> tuple[State, TimeStep]:
        """Reset using the functional API.

        Args:
            rng_key: PRNG key forwarded to ``functional_reset``.
            env_params: Optional per-call override of ``self.params``.

        Returns:
            ``(state, timestep)`` where ``timestep.extras`` is guaranteed to
            contain ``"canonical_action"`` and ``"operation_id"``.
        """
        p = self.params if env_params is None else env_params
        state, timestep = functional_reset(p, rng_key)

        # Add canonical action and operation_id to extras for
        # logging/visualization. The schema is kept stable and JAX-friendly
        # for JIT/scan compatibility.
        op_sentinel, zero_sel = self._sentinel_action(p)
        extras = dict(timestep.extras) if timestep.extras is not None else {}

        # setdefault: values already supplied by the functional reset win.
        extras.setdefault(
            "canonical_action",
            {
                "operation": op_sentinel,
                "selection": zero_sel,
            },
        )
        # Convenience scalar mirroring canonical_action.operation.
        extras.setdefault("operation_id", op_sentinel)

        return state, self._with_extras(timestep, extras)

    def step(
        self, state: State, action: Action | dict, env_params: EnvParams | None = None
    ) -> tuple[State, TimeStep]:
        """Step using the functional API.

        Args:
            state: Current environment state.
            action: Either an ``Action`` or a dict. A dict containing both
                ``"operation"`` and ``"selection"`` is converted with
                ``create_action``; anything else is passed through as given.
            env_params: Optional per-call override of ``self.params``.

        Returns:
            ``(next_state, timestep)`` where ``timestep.extras`` has
            ``"canonical_action"`` and ``"operation_id"`` overwritten with
            the action that was applied.
        """
        p = self.params if env_params is None else env_params

        # Accept canonical dict-form mask actions from ARCActionSpace and
        # convert them to the internal Action.
        if (
            isinstance(action, dict)
            and ("operation" in action)
            and ("selection" in action)
        ):
            op = jnp.asarray(action["operation"], dtype=jnp.int32)
            sel = jnp.asarray(action["selection"], dtype=jnp.bool_)
            action = create_action(op, sel)

        next_state, timestep = functional_step(p, state, action)

        # Populate a stable, JAX-friendly extras schema based on the actual
        # mask action.
        op_sentinel, zero_sel = self._sentinel_action(p)
        extras = dict(timestep.extras) if timestep.extras is not None else {}

        # Fall back to the sentinels when the action object lacks a field, so
        # both entries are always JAX arrays.
        op_attr = getattr(action, "operation", None)
        sel_attr = getattr(action, "selection", None)
        op = op_sentinel if op_attr is None else jnp.asarray(op_attr, dtype=jnp.int32)
        sel = zero_sel if sel_attr is None else jnp.asarray(sel_attr, dtype=jnp.bool_)

        extras["canonical_action"] = {"operation": op, "selection": sel}
        extras["operation_id"] = op

        return next_state, self._with_extras(timestep, extras)

    def state_space(self, _env_params: EnvParams | None = None) -> stoa.spaces.Space:
        """Return the state space of the environment as a ``DictSpace``."""
        # observation_shape() returns (height, width, channels); the channel
        # count is not needed here but must still be unpacked.
        height, width, _ = self.observation_shape()
        int32_max = int(jnp.iinfo(jnp.int32).max)

        def grid() -> GridSpace:
            return GridSpace(max_height=height, max_width=width)

        def bool_mask() -> BoundedArraySpace:
            return BoundedArraySpace(
                shape=(height, width), dtype=jnp.bool_, minimum=False, maximum=True
            )

        def index() -> BoundedArraySpace:
            return BoundedArraySpace(
                shape=(), dtype=jnp.int32, minimum=0, maximum=int32_max
            )

        return DictSpace(
            {
                "working_grid": grid(),
                "working_grid_mask": bool_mask(),
                "input_grid": grid(),
                "input_grid_mask": bool_mask(),
                "target_grid": grid(),
                "target_grid_mask": bool_mask(),
                "selected": bool_mask(),
                "clipboard": grid(),
                "step_count": BoundedArraySpace(
                    shape=(),
                    dtype=jnp.int32,
                    minimum=0,
                    maximum=self.params.max_episode_steps,
                ),
                "task_idx": index(),
                "pair_idx": index(),
                # NOTE(review): 35 is presumably the number of supported
                # operations — confirm against the operation registry.
                "allowed_operations_mask": BoundedArraySpace(
                    shape=(35,), dtype=jnp.bool_, minimum=False, maximum=True
                ),
                "similarity_score": BoundedArraySpace(
                    shape=(), dtype=jnp.float32, minimum=0.0, maximum=1.0
                ),
                # NOTE(review): shape (2,) uint32 looks like a raw JAX PRNG
                # key — confirm.
                "key": BoundedArraySpace(
                    shape=(2,),
                    dtype=jnp.uint32,
                    minimum=0,
                    maximum=int(jnp.iinfo(jnp.uint32).max),
                ),
            }
        )

    def observation_space(
        self, _env_params: EnvParams | None = None
    ) -> BoundedArraySpace:
        """Return the ARC observation space (int32 values, shape ``(H, W, 1)``)."""
        height, width, channels = self.observation_shape()
        return BoundedArraySpace(
            shape=(height, width, channels),
            dtype=jnp.int32,
            minimum=0,
            maximum=self.params.dataset.max_colors,
            name="observation",
        )

    def action_space(self, _env_params: EnvParams | None = None) -> ARCActionSpace:
        """Return the ARC action space sized to the maximum grid dimensions."""
        height, width, _ = self.observation_shape()
        return ARCActionSpace(max_height=height, max_width=width)

    def reward_space(self, _env_params: EnvParams | None = None) -> BoundedArraySpace:
        """Return the reward space: a float32 scalar bounded to ``[0.0, 1.0]``."""
        return BoundedArraySpace(
            shape=(), dtype=jnp.float32, minimum=0.0, maximum=1.0
        )

    def discount_space(self, _env_params: EnvParams | None = None) -> BoundedArraySpace:
        """Return the discount space: a float32 scalar bounded to ``[0.0, 1.0]``."""
        return BoundedArraySpace(
            shape=(), dtype=jnp.float32, minimum=0.0, maximum=1.0
        )

    @property
    def unwrapped(self) -> Environment:
        """Return the unwrapped environment (this instance itself)."""
        return self

    def close(self) -> None:
        """Close the environment (no-op)."""
        return

    def render(self, state: State, mode: str | None = None) -> Any:
        """Render the environment state.

        Args:
            state: The current environment state.
            mode: The rendering mode (``"rgb_array"``, ``"ansi"`` or
                ``"svg"``). If None, ``self.render_mode`` is used.

        Returns:
            The result of ``render_rgb`` or ``render_ansi`` on the working
            grid, or the SVG markup produced by ``Drawing.as_svg()``.

        Raises:
            ValueError: If the resolved render mode is not supported.
        """
        render_mode = mode if mode is not None else self.render_mode

        if render_mode == "rgb_array":
            return render_rgb(state.working_grid)

        if render_mode == "ansi":
            return render_ansi(state.working_grid)

        if render_mode == "svg":
            drawing = draw_grid_svg(
                state.working_grid,
                state.working_grid_mask,
                label=f"Step {int(state.step_count)}",
                show_size=True,
            )
            # draw_grid_svg may return a tuple; the Drawing is its first item.
            if isinstance(drawing, tuple):
                drawing = drawing[0]
            return drawing.as_svg()

        raise ValueError(f"Unsupported render mode: {render_mode}")
# Public API of this module: only the Environment class is exported.
__all__ = ["Environment"]