Source code for jaxarc.wrappers.visualization_wrapper

"""
Visualization wrapper for JaxARC environments.
"""

from __future__ import annotations

from typing import Any

import jax.numpy as jnp
from stoa.core_wrappers.wrapper import Wrapper

from jaxarc.state import State
from jaxarc.types import EnvParams, Grid
from jaxarc.utils.task_manager import get_task_id_globally
from jaxarc.utils.visualization.core import detect_changed_cells
from jaxarc.utils.visualization.rl_display import (
    draw_rl_step_svg_enhanced,
    get_operation_display_name,
)


class StepVisualizationWrapper(Wrapper):
    """
    Environment wrapper that remembers the latest transition for rendering.

    Each `step()` call caches `(state, action, next_state, timestep)` on the
    wrapper instance and `reset()` clears that cache. Calling
    `env.render(state, mode="detailed")` returns an SVG describing the cached
    transition; every other mode is forwarded to the wrapped environment.

    NOTE(review): the cache lives on the Python object and the detailed
    renderer converts cached values with `int()`/`float()`, so this presumably
    expects concrete (non-traced) values -- confirm before use under `jax.jit`.
    """

    def __init__(self, env):
        """Wrap `env` and start with no cached transition."""
        super().__init__(env)
        # (state, action, next_state, timestep) from the most recent step().
        self._last_transition: tuple[Any, Any, Any, Any] | None = None

    def reset(  # type: ignore[override]
        self, key, env_params: EnvParams | None = None
    ) -> tuple[Any, Any]:
        """
        Reset the wrapped environment and discard any cached transition.

        Args:
            key: Forwarded unchanged to the wrapped environment's `reset`.
            env_params: Optional parameters, forwarded unchanged.

        Returns:
            The `(state, timestep)` pair produced by the wrapped environment.
        """
        # Typed as Any so the call is not checked against stoa.Environment.
        wrapped: Any = self._env
        initial_state, first_timestep = wrapped.reset(key, env_params)
        self._last_transition = None
        return initial_state, first_timestep

    def step(  # type: ignore[override]
        self, state: State, action: Any, env_params: EnvParams | None = None
    ) -> tuple[Any, Any]:
        """
        Step the wrapped environment and cache the resulting transition.

        Args:
            state: State the action is applied to.
            action: Action forwarded to the wrapped environment.
            env_params: Optional parameters, forwarded unchanged.

        Returns:
            The `(next_state, timestep)` pair produced by the wrapped
            environment.
        """
        # Typed as Any so the call is not checked against stoa.Environment.
        wrapped: Any = self._env
        following_state, result_timestep = wrapped.step(state, action, env_params)
        # Keep everything the detailed renderer needs for this transition.
        self._last_transition = (state, action, following_state, result_timestep)
        return following_state, result_timestep

    def render(self, state: State, mode: str | None = None) -> Any:  # type: ignore[override]
        """
        Render the environment.

        With `mode="detailed"` the SVG for the cached transition is returned
        and `state` is not used; any other mode is delegated to the wrapped
        environment's `render(state, mode)`.
        """
        if mode != "detailed":
            # Typed as Any so the call is not checked against stoa.Environment.
            wrapped: Any = self._env
            return wrapped.render(state, mode)
        return self._render_detailed()

    @staticmethod
    def _snapshot_working_grid(state: Any) -> Grid:
        """Build a `Grid` from a state's working grid and its mask."""
        # jnp.array is used to satisfy Equinox checks.
        return Grid(
            data=jnp.array(state.working_grid),
            mask=jnp.array(state.working_grid_mask),
        )

    @staticmethod
    def _lookup_operation_id(info: Any, action: Any) -> int:
        """
        Find the operation id for a transition, or return -1 if none is found.

        The lookup order is: `info["operation_id"]`, then an `operation`
        attribute on `action`, then an `"operation"` key when `action` is a
        dict.
        """
        if "operation_id" in info:
            # int() also unwraps a JAX array scalar.
            return int(info["operation_id"])
        if hasattr(action, "operation"):
            return int(action.operation)
        if isinstance(action, dict) and "operation" in action:
            return int(action["operation"])
        return -1

    def _render_detailed(self) -> str:
        """
        Produce the SVG for the cached transition.

        Returns a small placeholder SVG with an error message when no
        transition has been cached yet; otherwise returns the result of
        `draw_rl_step_svg_enhanced` for the cached transition.
        """
        cached = self._last_transition
        if cached is None:
            return (
                '<svg width="200" height="50">'
                '<text x="10" y="30" font-family="sans-serif" fill="red">'
                "No transition available (call step() first)"
                "</text></svg>"
            )

        state_before, action, state_after, timestep = cached

        grid_before = self._snapshot_working_grid(state_before)
        grid_after = self._snapshot_working_grid(state_after)

        # Timestep extras may be absent; fall back to an empty mapping.
        if timestep.extras is not None:
            extras = timestep.extras
        else:
            extras = {}

        operation_id = self._lookup_operation_id(extras, action)
        if operation_id >= 0:
            operation_label = get_operation_display_name(operation_id)
        else:
            operation_label = "Unknown"

        cells_changed = detect_changed_cells(grid_before, grid_after)

        # Prefer the task name from the lookup; otherwise label by index.
        task_index = int(state_after.task_idx)
        task_label = get_task_id_globally(task_index) or f"Task {task_index}"

        reward_value = float(timestep.reward)
        step_number = int(state_after.step_count)
        pair_index = int(state_after.pair_idx)

        return draw_rl_step_svg_enhanced(
            before_grid=grid_before,
            after_grid=grid_after,
            action=action,
            reward=reward_value,
            info=extras,
            step_num=step_number,
            operation_name=operation_label,
            changed_cells=cells_changed,
            task_id=task_label,
            task_pair_index=pair_index,
        )