"""
RL-specific visualization and episode management.
Handles training visualization, episode tracking, and RL metrics display.
"""
from __future__ import annotations

import json
import os
import re
import shutil
import string
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional

import chex
import jax.numpy as jnp
import numpy as np
from loguru import logger

from jaxarc.envs.actions import Action
from jaxarc.envs.grid_operations import get_operation_display_text

from .core import (
    ARC_COLOR_PALETTE,
    _extract_grid_data,
    _extract_valid_region,
    add_change_highlighting,
    add_selection_visualization_overlay,
    get_info_metric,
)

if TYPE_CHECKING:
    from jaxarc.types import Grid
# ============================================================================
# SECTION: Episode Management (from episode_manager.py)
# ============================================================================
@chex.dataclass
class EpisodeConfig:
    """Configuration for episode management and storage.

    Describes how episode-based visualization output is laid out on disk
    (directory and file naming), the storage limits consulted by cleanup, and
    provides helpers to serialize the settings to and from JSON. Every value
    is validated in ``__post_init__``.
    """

    # Directory structure settings
    base_output_dir: str = "outputs/episodes"
    run_name: str | None = None  # Auto-generated if None
    episode_dir_format: str = "episode_{episode:04d}"
    step_file_format: str = "step_{step:03d}"

    # Storage limits and policies
    max_episodes_per_run: int = 1000
    cleanup_policy: Literal["oldest_first", "size_based", "manual"] = "size_based"
    max_storage_gb: float = 10.0

    # File management settings
    create_run_subdirs: bool = True
    preserve_empty_dirs: bool = False
    compress_old_episodes: bool = False

    def __post_init__(self) -> None:
        """Validate configuration after initialization."""
        self._validate_config()

    def _validate_config(self) -> None:
        """Validate all configuration parameters.

        Raises:
            ValueError: If any configuration parameter is invalid
        """
        # Validate directory paths
        if not self.base_output_dir or not isinstance(self.base_output_dir, str):
            raise ValueError("base_output_dir must be a non-empty string")

        # Validate format strings. Besides KeyError/ValueError, str.format can
        # raise IndexError ("{}"), AttributeError ("{episode.x}") and TypeError
        # ("{episode[0]}"); all of them mean the format string is unusable.
        format_errors = (KeyError, IndexError, ValueError, AttributeError, TypeError)
        try:
            self.episode_dir_format.format(episode=1)
        except format_errors as e:
            raise ValueError(f"Invalid episode_dir_format: {e}") from e
        try:
            self.step_file_format.format(step=1)
        except format_errors as e:
            raise ValueError(f"Invalid step_file_format: {e}") from e

        # Validate numeric limits
        if self.max_episodes_per_run <= 0:
            raise ValueError("max_episodes_per_run must be positive")
        if self.max_storage_gb <= 0:
            raise ValueError("max_storage_gb must be positive")

        # Validate cleanup policy
        valid_policies = {"oldest_first", "size_based", "manual"}
        if self.cleanup_policy not in valid_policies:
            raise ValueError(f"cleanup_policy must be one of {valid_policies}")

        # Validate run_name if provided
        if self.run_name is not None:
            if not isinstance(self.run_name, str) or not self.run_name.strip():
                raise ValueError("run_name must be a non-empty string if provided")
            # A plain string (not a set) keeps the error message deterministic.
            invalid_chars = '<>:"/\\|?*'
            if any(char in self.run_name for char in invalid_chars):
                raise ValueError(
                    f"run_name contains invalid characters: {invalid_chars}"
                )

    def to_dict(self) -> dict[str, Any]:
        """Convert configuration to dictionary for serialization.

        Returns:
            Dictionary representation of the configuration
        """
        return {
            "base_output_dir": self.base_output_dir,
            "run_name": self.run_name,
            "episode_dir_format": self.episode_dir_format,
            "step_file_format": self.step_file_format,
            "max_episodes_per_run": self.max_episodes_per_run,
            "cleanup_policy": self.cleanup_policy,
            "max_storage_gb": self.max_storage_gb,
            "create_run_subdirs": self.create_run_subdirs,
            "preserve_empty_dirs": self.preserve_empty_dirs,
            "compress_old_episodes": self.compress_old_episodes,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> EpisodeConfig:
        """Create configuration from dictionary.

        Args:
            data: Dictionary containing configuration parameters

        Returns:
            EpisodeConfig instance

        Raises:
            ValueError: If ``data`` is not a mapping or holds invalid values
        """
        # Extract known fields, ignoring unknown ones for forward compatibility
        known_fields = {
            "base_output_dir",
            "run_name",
            "episode_dir_format",
            "step_file_format",
            "max_episodes_per_run",
            "cleanup_policy",
            "max_storage_gb",
            "create_run_subdirs",
            "preserve_empty_dirs",
            "compress_old_episodes",
        }
        try:
            items = data.items()
        except AttributeError as e:
            # e.g. a JSON file whose top-level value is a list or a string
            raise ValueError(
                f"Invalid configuration data: expected a mapping, "
                f"got {type(data).__name__}"
            ) from e
        filtered_data = {k: v for k, v in items if k in known_fields}
        try:
            return cls(**filtered_data)
        except TypeError as e:
            raise ValueError(f"Invalid configuration data: {e}") from e

    def save_to_file(self, file_path: Path | str) -> None:
        """Save configuration to JSON file.

        Args:
            file_path: Path where to save the configuration

        Raises:
            OSError: If file cannot be written
        """
        file_path = Path(file_path)
        # Ensure parent directory exists
        file_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with file_path.open("w", encoding="utf-8") as f:
                json.dump(self.to_dict(), f, indent=2, sort_keys=True)
        except OSError as e:
            logger.error(f"Failed to save config to {file_path}: {e}")
            raise

    @classmethod
    def load_from_file(cls, file_path: Path | str) -> EpisodeConfig:
        """Load configuration from JSON file.

        Args:
            file_path: Path to the configuration file

        Returns:
            EpisodeConfig instance

        Raises:
            FileNotFoundError: If file doesn't exist
            ValueError: If file cannot be read/parsed or holds an invalid
                configuration
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {file_path}")
        try:
            with file_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            raise ValueError(f"Failed to load config from {file_path}: {e}") from e
        return cls.from_dict(data)

    def get_base_path(self) -> Path:
        """Get the base output directory as a Path object.

        Returns:
            Absolute, user-expanded Path for the base output directory
        """
        return Path(self.base_output_dir).expanduser().resolve()

    def generate_run_name(self) -> str:
        """Generate a timestamped run name if none is provided.

        Returns:
            The configured ``run_name`` when set, otherwise ``run_<timestamp>``
        """
        if self.run_name is not None:
            return self.run_name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"run_{timestamp}"

    def validate_storage_path(self, path: Path) -> bool:
        """Validate that a storage path is accessible and writable.

        The directory is created if missing, and a small probe file is written
        and removed to confirm write access.

        Args:
            path: Path to validate

        Returns:
            True if path is valid and writable
        """
        try:
            # Check if path exists or can be created
            path.mkdir(parents=True, exist_ok=True)
            # Test write permissions
            test_file = path / ".write_test"
            test_file.write_text("test", encoding="utf-8")
            test_file.unlink()
            return True
        except OSError:  # PermissionError is a subclass of OSError
            return False

    def estimate_storage_usage(self, path: Path) -> float:
        """Estimate storage usage in GB for a given path.

        Args:
            path: Path to analyze

        Returns:
            Total size of all readable files below ``path`` in GB
            (0.0 if the path does not exist)
        """
        if not path.exists():
            return 0.0
        total_size = 0
        try:
            for dirpath, _dirnames, filenames in os.walk(path):
                for filename in filenames:
                    try:
                        total_size += (Path(dirpath) / filename).stat().st_size
                    except OSError:
                        # Skip files that vanished or can't be accessed
                        continue
        except OSError:
            logger.warning(f"Could not access some files in {path}")
        return total_size / (1024**3)  # Convert bytes to GB
class EpisodeManager:
    """Manages episode-based storage and organization.

    This class handles directory creation, file organization, and cleanup
    for episode-based visualization data storage. Runs live directly below
    the configured base directory; episodes live below the active run.
    """

    def __init__(self, config: EpisodeConfig):
        """Initialize episode manager with configuration.

        Args:
            config: Episode configuration settings

        Raises:
            ValueError: If the base directory cannot be created or written to
        """
        self.config = config
        self.current_run_dir: Path | None = None
        self.current_episode_dir: Path | None = None
        self.current_run_name: str | None = None
        self.current_episode_num: int | None = None
        # Validate base directory on initialization
        base_path = self.config.get_base_path()
        if not self.config.validate_storage_path(base_path):
            raise ValueError(f"Cannot access or write to base directory: {base_path}")

    def _resolve_run_dir(self, run_name: str) -> Path:
        """Map a run name to its directory, refusing paths outside the base.

        The check is purely lexical (no symlink resolution), so names such as
        ``""``, ``"."``, ``".."``, ``"../x"`` or absolute paths are rejected
        while ordinary names are returned unchanged as ``base / run_name``.

        Raises:
            ValueError: If the name does not denote a directory strictly
                inside the base output directory
        """
        base_path = self.config.get_base_path()
        candidate = Path(os.path.normpath(base_path / run_name))
        if candidate == base_path or base_path not in candidate.parents:
            raise ValueError(
                f"run_name {run_name!r} does not resolve to a directory inside {base_path}"
            )
        return candidate

    def _is_current_run(self, run_dir: Path) -> bool:
        """Return True if ``run_dir`` is the active run or one of its parents."""
        current = self.current_run_dir
        if current is None:
            return False
        return run_dir == current or run_dir in current.parents

    def _episode_dir_pattern(self) -> re.Pattern[str] | None:
        """Build a regex matching directory names made by ``episode_dir_format``.

        Literal text of the format is escaped; every replacement field is
        matched as a run of digits and the first one is captured as the
        episode number. Returns None when the format has no replacement field.
        """
        pieces: list[str] = []
        has_group = False
        parsed = string.Formatter().parse(self.config.episode_dir_format)
        for literal, field, _spec, _conversion in parsed:
            pieces.append(re.escape(literal))
            if field is None:
                continue
            pieces.append(r"\d+" if has_group else r"(\d+)")
            has_group = True
        return re.compile("".join(pieces)) if has_group else None

    def start_new_run(self, run_name: str | None = None) -> Path:
        """Start a new training run with timestamped directory.

        Args:
            run_name: Optional custom run name. If None, uses config or generates one.

        Returns:
            Path to the created run directory

        Raises:
            OSError: If directory cannot be created
            ValueError: If run_name is empty or would place the run outside
                the base output directory
        """
        # Use provided name, config name, or generate one
        if run_name is not None:
            if not isinstance(run_name, str) or not run_name.strip():
                raise ValueError("run_name must be a non-empty string")
            chosen_name = run_name.strip()
        else:
            chosen_name = self.config.generate_run_name()

        # Validate before touching any state so a bad name leaves the manager unchanged
        run_dir = self._resolve_run_dir(chosen_name)
        self.current_run_name = chosen_name
        self.current_run_dir = run_dir

        # Create run directory
        try:
            self.current_run_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create run directory {self.current_run_dir}: {e}")
            raise

        # Save configuration to run directory
        config_path = self.current_run_dir / "episode_config.json"
        self.config.save_to_file(config_path)

        # Reset episode tracking
        self.current_episode_dir = None
        self.current_episode_num = None
        logger.info(
            f"Started new run: {self.current_run_name} at {self.current_run_dir}"
        )
        return self.current_run_dir

    def start_new_episode(self, episode_num: int) -> Path:
        """Start a new episode within the current run.

        Args:
            episode_num: Episode number (must be non-negative)

        Returns:
            Path to the created episode directory

        Raises:
            ValueError: If no run is active or episode_num is invalid
            OSError: If directory cannot be created
        """
        if self.current_run_dir is None:
            raise ValueError("No active run. Call start_new_run() first.")
        if episode_num < 0:
            raise ValueError("episode_num must be non-negative")
        if episode_num >= self.config.max_episodes_per_run:
            raise ValueError(
                f"episode_num {episode_num} exceeds max_episodes_per_run {self.config.max_episodes_per_run}"
            )

        # Create episode directory
        episode_dir_name = self.config.episode_dir_format.format(episode=episode_num)
        self.current_episode_dir = self.current_run_dir / episode_dir_name
        try:
            self.current_episode_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(
                f"Failed to create episode directory {self.current_episode_dir}: {e}"
            )
            raise
        self.current_episode_num = episode_num
        logger.debug(f"Started episode {episode_num} at {self.current_episode_dir}")
        return self.current_episode_dir

    def get_step_path(self, step_num: int, file_type: str = "svg") -> Path:
        """Get file path for a specific step visualization.

        Args:
            step_num: Step number (must be non-negative)
            file_type: File extension (without dot)

        Returns:
            Path for the step file

        Raises:
            ValueError: If no episode is active or step_num is invalid
        """
        if self.current_episode_dir is None:
            raise ValueError("No active episode. Call start_new_episode() first.")
        if step_num < 0:
            raise ValueError("step_num must be non-negative")
        step_filename = self.config.step_file_format.format(step=step_num)
        return self.current_episode_dir / f"{step_filename}.{file_type}"

    def get_episode_summary_path(self, file_type: str = "svg") -> Path:
        """Get file path for episode summary visualization.

        Args:
            file_type: File extension (without dot)

        Returns:
            Path for the episode summary file

        Raises:
            ValueError: If no episode is active
        """
        if self.current_episode_dir is None:
            raise ValueError("No active episode. Call start_new_episode() first.")
        return self.current_episode_dir / f"summary.{file_type}"

    def get_current_run_info(self) -> dict[str, Any]:
        """Get information about the current run.

        Returns:
            Dictionary with run name/dir and episode number/dir (None when unset)
        """
        return {
            "run_name": self.current_run_name,
            "run_dir": str(self.current_run_dir) if self.current_run_dir else None,
            "episode_num": self.current_episode_num,
            "episode_dir": str(self.current_episode_dir)
            if self.current_episode_dir
            else None,
        }

    def list_episodes_in_run(
        self, run_dir: Path | None = None
    ) -> list[tuple[int, Path]]:
        """List all episodes in a run directory.

        Directory names are matched against the configured
        ``episode_dir_format`` so custom formats are recognized too.

        Args:
            run_dir: Run directory to scan. Uses current run if None.

        Returns:
            List of (episode_number, episode_path) tuples, sorted by episode number
        """
        if run_dir is None:
            run_dir = self.current_run_dir
        if run_dir is None or not run_dir.exists():
            return []
        pattern = self._episode_dir_pattern()
        if pattern is None:
            # The format has no episode field, so no number can be recovered.
            return []
        episodes = []
        for item in run_dir.iterdir():
            if not item.is_dir():
                continue
            match = pattern.fullmatch(item.name)
            if match is not None:
                episodes.append((int(match.group(1)), item))
        return sorted(episodes)

    def cleanup_old_data(self) -> None:
        """Clean up old data based on configured policy.

        Does nothing for the ``manual`` policy or while the base directory is
        within ``max_storage_gb``; otherwise removes run directories according
        to the configured policy until usage is back under the limit.
        """
        if self.config.cleanup_policy == "manual":
            logger.debug("Cleanup policy is manual - skipping automatic cleanup")
            return
        base_path = self.config.get_base_path()
        if not base_path.exists():
            return
        current_usage = self.config.estimate_storage_usage(base_path)
        if current_usage <= self.config.max_storage_gb:
            logger.debug(
                f"Storage usage {current_usage:.2f}GB is within limit {self.config.max_storage_gb}GB"
            )
            return
        logger.info(
            f"Storage usage {current_usage:.2f}GB exceeds limit {self.config.max_storage_gb}GB - starting cleanup"
        )
        if self.config.cleanup_policy == "oldest_first":
            self._cleanup_oldest_first(base_path)
        elif self.config.cleanup_policy == "size_based":
            self._cleanup_size_based(base_path)

    def _cleanup_oldest_first(self, base_path: Path) -> None:
        """Clean up oldest runs first until under storage limit.

        Args:
            base_path: Base directory to clean up
        """
        # Get all run directories with their modification times
        runs = []
        for item in base_path.iterdir():
            if item.is_dir():
                try:
                    runs.append((item.stat().st_mtime, item))
                except OSError:
                    continue
        # Sort by modification time (oldest first)
        runs.sort()
        for _mtime, run_dir in runs:
            current_usage = self.config.estimate_storage_usage(base_path)
            if current_usage <= self.config.max_storage_gb:
                break
            # Don't delete the current run (or a directory containing it)
            if self._is_current_run(run_dir):
                continue
            logger.info(f"Removing old run directory: {run_dir}")
            try:
                shutil.rmtree(run_dir)
            except OSError as e:
                logger.error(f"Failed to remove {run_dir}: {e}")

    def _cleanup_size_based(self, base_path: Path) -> None:
        """Clean up largest runs first until under storage limit.

        Args:
            base_path: Base directory to clean up
        """
        # Get all run directories with their sizes
        runs = []
        for item in base_path.iterdir():
            if item.is_dir():
                try:
                    runs.append((self.config.estimate_storage_usage(item), item))
                except OSError:
                    continue
        # Sort by size (largest first)
        runs.sort(reverse=True)
        for size, run_dir in runs:
            current_usage = self.config.estimate_storage_usage(base_path)
            if current_usage <= self.config.max_storage_gb:
                break
            # Don't delete the current run (or a directory containing it)
            if self._is_current_run(run_dir):
                continue
            logger.info(f"Removing large run directory ({size:.2f}GB): {run_dir}")
            try:
                shutil.rmtree(run_dir)
            except OSError as e:
                logger.error(f"Failed to remove {run_dir}: {e}")

    def force_cleanup_run(self, run_name: str) -> bool:
        """Force cleanup of a specific run directory.

        Names that do not denote a directory strictly inside the base output
        directory (empty, ``..``, absolute paths, ...) are refused so that the
        base directory itself or unrelated paths can never be removed.

        Args:
            run_name: Name of the run to clean up

        Returns:
            True if cleanup was successful, False otherwise
        """
        try:
            run_dir = self._resolve_run_dir(run_name)
        except ValueError as e:
            logger.warning(f"Refusing to remove run {run_name!r}: {e}")
            return False
        if not run_dir.exists():
            logger.warning(f"Run directory does not exist: {run_dir}")
            return False
        # Don't delete the current run (or a directory containing it)
        if self._is_current_run(run_dir):
            logger.warning(f"Cannot delete current active run: {run_name}")
            return False
        try:
            shutil.rmtree(run_dir)
            logger.info(f"Successfully removed run directory: {run_dir}")
            return True
        except OSError as e:
            logger.error(f"Failed to remove run directory {run_dir}: {e}")
            return False
# ============================================================================
# SECTION: RL Visualization (from rl_visualization.py)
# ============================================================================
def get_operation_display_name(
    operation_id: int, action_data: Optional[Dict[str, Any]] = None
) -> str:
    """Get human-readable operation name from operation ID.

    Args:
        operation_id: Numeric operation identifier
        action_data: Optional action context. Accepted for interface
            compatibility; it is not consulted by the current implementation.

    Returns:
        The text produced by ``get_operation_display_text`` for the ID
    """
    return get_operation_display_text(operation_id)
def draw_rl_step_svg_enhanced(
    before_grid: Grid,
    after_grid: Grid,
    action: Any,  # Can be Action or dict
    reward: float,
    info: Dict[str, Any],
    step_num: int,
    operation_name: str = "",
    changed_cells: Optional[jnp.ndarray] = None,
    config: Optional[Any] = None,
    max_width: float = 1400.0,
    max_height: float = 700.0,
    task_id: str = "",
    task_pair_index: int = 0,
    total_task_pairs: int = 1,
) -> str:
    """Render one RL step as an SVG showing the grid before and after the action.

    The picture contains a title (step number and optional operation name),
    an optional task-context line, the colored reward, the "before" grid with
    the action's selection overlaid, an arrow, and the "after" grid with
    changed cells highlighted.

    Args:
        before_grid: Grid state before the action
        after_grid: Grid state after the action
        action: ``Action`` instance, ``(op, row, col)`` / ``(op, r1, c1, r2, c2)``
            tuple, or legacy dict with a ``"selection"`` entry
        reward: Reward received for this step
        info: Additional information dictionary (not read by this function)
        step_num: Step number in the episode
        operation_name: Human-readable operation name appended to the title
        changed_cells: Optional mask of cells that changed
        config: Optional object exposing ``get_color_palette()``
        max_width: Accepted for interface compatibility; the layout below uses
            fixed dimensions and does not read it
        max_height: Accepted for interface compatibility; not read either
        task_id: Task identifier for context
        task_pair_index: Zero-based index of the current task pair
        total_task_pairs: Total number of task pairs

    Returns:
        SVG string containing the visualization
    """
    import drawsvg as draw

    # Palette: prefer the one supplied by the config object, if any.
    if config and hasattr(config, "get_color_palette"):
        color_palette = config.get_color_palette()
    else:
        color_palette = ARC_COLOR_PALETTE

    # Fixed layout geometry
    top_padding, bottom_padding, side_padding = 100, 50, 50
    grid_spacing = 180
    grid_max_width = 280
    grid_max_height = 280
    total_width = 2 * grid_max_width + grid_spacing + 2 * side_padding
    total_height = grid_max_height + top_padding + bottom_padding

    # Canvas with a light background
    drawing = draw.Drawing(total_width, total_height)
    drawing.append(draw.Rectangle(0, 0, total_width, total_height, fill="#f8f9fa"))

    def add_header_line(text: str, size: int, y: int, weight: str, color: str) -> None:
        """Append one horizontally centered header text line."""
        drawing.append(
            draw.Text(
                text,
                font_size=size,
                x=total_width / 2,
                y=y,
                text_anchor="middle",
                font_family="Anuphan",
                font_weight=weight,
                fill=color,
            )
        )

    # Title: step number plus optional operation name
    title_text = f"Step {step_num}"
    if operation_name:
        title_text = f"{title_text} - {operation_name}"
    add_header_line(title_text, 28, 40, "600", "#2c3e50")

    # Task context: task id and/or pair counter, joined by " | "
    context_parts = []
    if task_id:
        context_parts.append(f"Task: {task_id}")
    if total_task_pairs > 1:
        context_parts.append(f"Pair {task_pair_index + 1}/{total_task_pairs}")
    task_context_text = " | ".join(context_parts)
    if task_context_text:
        add_header_line(task_context_text, 16, 65, "400", "#6c757d")

    # Reward line; it moves down when the task-context line is present
    if reward > 0:
        reward_color = "#27ae60"
    elif reward < 0:
        reward_color = "#e74c3c"
    else:
        reward_color = "#95a5a6"
    add_header_line(
        f"Reward: {reward:.3f}",
        20,
        85 if task_context_text else 70,
        "500",
        reward_color,
    )

    # Top-left corners of the two grid slots
    before_x = side_padding
    after_x = side_padding + grid_max_width + grid_spacing
    grids_y = top_padding

    def render_grid_panel(
        grid: Grid,
        x: float,
        y: float,
        grid_label: str,
        selection_mask: Optional[np.ndarray] = None,
        highlight_changes: bool = False,
        changed_cells: Optional[np.ndarray] = None,
    ) -> tuple[float, float]:
        """Draw one grid with overlays, border and label; return its drawn size."""
        grid_data, grid_mask = _extract_grid_data(grid)
        if grid_mask is not None:
            grid_mask = np.asarray(grid_mask)

        valid_grid, (start_row, start_col), (height, width) = _extract_valid_region(
            grid_data, grid_mask
        )
        if height == 0 or width == 0:
            return 0, 0

        # Largest square cell that keeps the grid inside its slot, then center it
        cell_size = min(grid_max_width / width, grid_max_height / height)
        actual_width = width * cell_size
        actual_height = height * cell_size
        grid_x = x + (grid_max_width - actual_width) / 2
        grid_y = y + (grid_max_height - actual_height) / 2

        # White backdrop slightly larger than the grid
        drawing.append(
            draw.Rectangle(
                grid_x - 5,
                grid_y - 5,
                actual_width + 10,
                actual_height + 10,
                fill="white",
                stroke="#dee2e6",
                stroke_width=1,
                rx=5,
            )
        )

        # Cells: palette color when valid and in range, grey otherwise
        palette_size = len(color_palette.keys())
        for row_idx in range(height):
            mask_row = start_row + row_idx
            for col_idx in range(width):
                mask_col = start_col + col_idx
                color_val = int(valid_grid[row_idx, col_idx])
                cell_ok = True
                if (
                    grid_mask is not None
                    and mask_row < grid_mask.shape[0]
                    and mask_col < grid_mask.shape[1]
                ):
                    cell_ok = grid_mask[mask_row, mask_col]
                if cell_ok and 0 <= color_val < palette_size:
                    fill_color = color_palette.get(color_val, "white")
                else:
                    fill_color = "#CCCCCC"
                drawing.append(
                    draw.Rectangle(
                        grid_x + col_idx * cell_size,
                        grid_y + row_idx * cell_size,
                        cell_size,
                        cell_size,
                        fill=fill_color,
                        stroke="#6c757d",
                        stroke_width=0.5,
                    )
                )

        # Overlays go on top of the cells: change highlighting first ...
        if highlight_changes and changed_cells is not None:
            add_change_highlighting(
                drawing,
                changed_cells,
                grid_x,
                grid_y,
                cell_size,
                start_row,
                start_col,
                height,
                width,
            )
        # ... then the selection, only when at least one cell is selected
        if selection_mask is not None and selection_mask.any():
            add_selection_visualization_overlay(
                drawing,
                selection_mask,
                grid_x,
                grid_y,
                cell_size,
                start_row,
                start_col,
                height,
                width,
                selection_color="#00FFFF",  # Bright cyan - very visible
                selection_opacity=0.4,
                border_width=3,
            )

        # Outer border
        drawing.append(
            draw.Rectangle(
                grid_x - 3,
                grid_y - 3,
                actual_width + 6,
                actual_height + 6,
                fill="none",
                stroke="#495057",
                stroke_width=2,
                rx=3,
            )
        )

        # Label on a pill-shaped background below the grid
        label_top = grid_y + actual_height
        drawing.append(
            draw.Rectangle(
                grid_x - 5,
                label_top + 15,
                len(grid_label) * 12 + 20,
                25,
                fill="#e9ecef",
                stroke="#dee2e6",
                stroke_width=1,
                rx=3,
            )
        )
        drawing.append(
            draw.Text(
                grid_label,
                font_size=16,
                x=grid_x + 5,
                y=label_top + 32,
                text_anchor="start",
                font_family="Anuphan",
                font_weight="600",
                fill="#495057",
            )
        )
        return actual_width, actual_height

    def resolve_selection_mask() -> Optional[np.ndarray]:
        """Derive the selection mask from whichever action format was given."""
        if isinstance(action, Action):
            return np.asarray(action.selection)

        if isinstance(action, tuple) and len(action) >= 3:
            # Tuple actions produced by the action wrappers
            grid_height, grid_width = before_grid.shape
            mask = np.zeros((grid_height, grid_width), dtype=bool)

            def clamp(value: Any, upper: int) -> int:
                return max(0, min(int(value), upper))

            if len(action) == 3:
                # PointActionWrapper: (operation, row, col)
                row = clamp(action[1], grid_height - 1)
                col = clamp(action[2], grid_width - 1)
                mask[row, col] = True
            elif len(action) == 5:
                # BboxActionWrapper: (operation, r1, c1, r2, c2)
                r1 = clamp(action[1], grid_height - 1)
                c1 = clamp(action[2], grid_width - 1)
                r2 = clamp(action[3], grid_height - 1)
                c2 = clamp(action[4], grid_width - 1)
                # Corners may come in any order; bounds are inclusive
                mask[min(r1, r2) : max(r1, r2) + 1, min(c1, c2) : max(c1, c2) + 1] = True
            return mask

        # Fallback for old dictionary format
        if isinstance(action, dict) and "selection" in action:
            return np.asarray(action["selection"])
        return None

    selection_mask = resolve_selection_mask()

    # Left: state before the action with the selection overlay
    render_grid_panel(
        before_grid, before_x, grids_y, "Before State", selection_mask=selection_mask
    )
    # Right: state after the action with changed cells highlighted
    render_grid_panel(
        after_grid,
        after_x,
        grids_y,
        "After State",
        highlight_changes=True,
        changed_cells=changed_cells,
    )

    # Arrow between the two grid slots, vertically centered
    arrow_y = grids_y + grid_max_height / 2
    arrow_start_x = before_x + grid_max_width + 30
    arrow_end_x = after_x - 30
    drawing.append(
        draw.Line(
            arrow_start_x,
            arrow_y,
            arrow_end_x,
            arrow_y,
            stroke="#6c757d",
            stroke_width=3,
        )
    )
    drawing.append(
        draw.Lines(
            arrow_end_x - 15,
            arrow_y - 10,
            arrow_end_x - 15,
            arrow_y + 10,
            arrow_end_x,
            arrow_y,
            close=True,
            fill="#6c757d",
        )
    )
    return drawing.as_svg()
def draw_rl_step_svg(
    before_grid: Grid,
    after_grid: Grid,
    action: Dict[str, Any],
    reward: float,
    info: Dict[str, Any],
    step_num: int,
    operation_name: str = "",
    changed_cells: Optional[jnp.ndarray] = None,
    config: Optional[Any] = None,
    **kwargs,
) -> str:
    """Backward-compatible entry point that delegates to ``draw_rl_step_svg_enhanced``.

    All named arguments are forwarded unchanged; any extra keyword arguments
    are passed through as well.

    Returns:
        SVG string produced by ``draw_rl_step_svg_enhanced``
    """
    forwarded: Dict[str, Any] = {
        "before_grid": before_grid,
        "after_grid": after_grid,
        "action": action,
        "reward": reward,
        "info": info,
        "step_num": step_num,
        "operation_name": operation_name,
        "changed_cells": changed_cells,
        "config": config,
    }
    # kwargs can never collide with the named parameters above
    forwarded.update(kwargs)
    return draw_rl_step_svg_enhanced(**forwarded)
def save_rl_step_visualization(
    state: Any,  # ArcEnvState
    action: dict,
    next_state: Any,  # ArcEnvState
    output_dir: str = "output/rl_steps",
) -> None:
    """JAX callback function to save RL step visualization.

    This function is designed to be used with jax.debug.callback. It writes
    ``step_<NNN>.svg`` (zero-padded ``state.step_count``) into ``output_dir``.

    Args:
        state: Environment state before the action
        action: Structured action with an ``operation`` attribute, a tuple
            whose first element is the operation, or a legacy dictionary with
            'selection' and 'operation' keys
        next_state: Environment state after the action
        output_dir: Directory to save visualization files
    """
    # Imported here because the module-level import is type-checking only.
    from jaxarc.types import Grid

    # Ensure output directory exists
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Create Grid objects (convert JAX arrays to numpy)
    before_grid = Grid(
        data=np.asarray(state.working_grid),
        mask=np.asarray(state.working_grid_mask),
    )
    after_grid = Grid(
        data=np.asarray(next_state.working_grid),
        mask=np.asarray(next_state.working_grid_mask),
    )

    # Extract the operation from structured, tuple, or legacy dict actions
    if hasattr(action, "operation"):
        operation_id = int(action.operation)
    elif isinstance(action, tuple) and len(action) >= 1:
        # Tuple actions from PointActionWrapper: (operation, row, col)
        operation_id = int(action[0])
    else:
        operation_id = int(action["operation"])  # Legacy format for visualization only

    step_number = int(state.step_count)

    # Reward is not available in this callback, so a placeholder is shown.
    reward = 0.0
    info = {"step_count": step_number}  # Basic info

    # Generate visualization; the operation name goes into the title so the
    # extracted operation id is actually shown.
    svg_content = draw_rl_step_svg(
        before_grid=before_grid,
        after_grid=after_grid,
        action=action,
        reward=reward,
        info=info,
        step_num=step_number,
        operation_name=get_operation_display_name(operation_id),
    )

    # Save to file with zero-padded step number
    filepath = output_path / f"step_{step_number:03d}.svg"
    filepath.write_text(svg_content, encoding="utf-8")

    # Log the save (will appear in console during execution)
    logger.info(f"Saved RL step visualization: {filepath}")
def create_action_summary_panel(
    action: Dict[str, Any],
    reward: float,
    info: Dict[str, Any],
    operation_name: str = "",
    width: float = 400,
    height: float = 100,
) -> str:
    """Build a small SVG panel summarizing one action.

    The panel lists, top to bottom: a title, the operation name (when given),
    the reward colored by sign, and the ``similarity`` metric when
    ``get_info_metric`` finds one in ``info``.

    Args:
        action: Action dictionary (not read by this function)
        reward: Reward received
        info: Additional information searched for a ``similarity`` metric
        operation_name: Human-readable operation name
        width: Panel width
        height: Panel height

    Returns:
        SVG string for the action summary panel
    """
    import drawsvg as draw

    panel = draw.Drawing(width, height)

    def add_line(text: str, size: int, y: int, weight: str, color: str) -> None:
        """Append one left-aligned text line at the panel's left margin."""
        panel.append(
            draw.Text(
                text,
                font_size=size,
                x=10,
                y=y,
                text_anchor="start",
                font_family="Anuphan",
                font_weight=weight,
                fill=color,
            )
        )

    # Rounded white card as background
    panel.append(
        draw.Rectangle(
            0,
            0,
            width,
            height,
            fill="#ffffff",
            stroke="#dee2e6",
            stroke_width=1,
            rx=8,
        )
    )

    add_line("Action Summary", 16, 25, "600", "#2c3e50")

    if operation_name:
        add_line(f"Operation: {operation_name}", 14, 45, "400", "#495057")

    # Green for positive, red for negative, grey otherwise
    if reward > 0:
        reward_color = "#27ae60"
    elif reward < 0:
        reward_color = "#e74c3c"
    else:
        reward_color = "#95a5a6"
    add_line(f"Reward: {reward:.3f}", 14, 65, "500", reward_color)

    # Similarity may live directly in info or in nested metrics
    similarity_val = get_info_metric(info, "similarity")
    if similarity_val is not None:
        add_line(f"Similarity: {similarity_val:.3f}", 12, 85, "400", "#6c757d")

    return panel.as_svg()
def create_metrics_visualization(
    metrics: Dict[str, float],
    width: float = 300,
    height: float = 200,
) -> str:
    """Build an SVG panel listing metric names and their values.

    Each metric takes one 20-unit row, name on the left and value (three
    decimals) right-aligned. Rows stop once the next one would start below
    ``height - 20``.

    Args:
        metrics: Dictionary of metric names to values
        width: Panel width
        height: Panel height

    Returns:
        SVG string for the metrics panel
    """
    import drawsvg as draw

    panel = draw.Drawing(width, height)

    def add_text(text: str, size: int, x: float, y: float, anchor: str, weight: str, color: str) -> None:
        """Append one text element using the panel's font family."""
        panel.append(
            draw.Text(
                text,
                font_size=size,
                x=x,
                y=y,
                text_anchor=anchor,
                font_family="Anuphan",
                font_weight=weight,
                fill=color,
            )
        )

    # Rounded white card as background
    panel.append(
        draw.Rectangle(
            0,
            0,
            width,
            height,
            fill="#ffffff",
            stroke="#dee2e6",
            stroke_width=1,
            rx=8,
        )
    )

    add_text("Step Metrics", 16, 10, 25, "start", "600", "#2c3e50")

    row_y = 50
    row_step = 20
    last_allowed_y = height - 20
    for metric_name, metric_value in metrics.items():
        # Name on the left, value right-aligned on the same baseline
        add_text(f"{metric_name}:", 12, 10, row_y, "start", "500", "#495057")
        add_text(f"{metric_value:.3f}", 12, width - 10, row_y, "end", "400", "#6c757d")
        row_y += row_step
        if row_y > last_allowed_y:
            # No room left for another row
            break

    return panel.as_svg()