Utilities API
Helper functions and utilities for visualization, data loading, and more.
Visualization
Enhanced visualization and logging system for JaxARC.
This module provides comprehensive visualization capabilities for ARC grids, tasks, and RL training episodes with support for multiple output formats and performance optimization.
Public API:
- Core visualization functions:
  - log_grid_to_console: Console logging with Rich formatting
  - draw_grid_svg: SVG generation for single grids
  - visualize_grid_rich: Rich table visualization for grids
  - visualize_task_pair_rich: Rich visualization for input-output pairs
  - draw_task_pair_svg: SVG generation for task pairs
  - visualize_parsed_task_data_rich: Complete task visualization
  - draw_parsed_task_data_svg: SVG generation for complete tasks
- RL-specific functions:
  - draw_rl_step_svg: Visualization of RL step transitions
  - save_rl_step_visualization: Save step visualizations to disk
- Utility functions:
  - save_svg_drawing: Save SVG drawings to files
- Constants:
  - ARC_COLOR_PALETTE: Standard ARC color mapping
- class jaxarc.utils.visualization.EpisodeConfig(base_output_dir: str = 'outputs/episodes', run_name: str | None = None, episode_dir_format: str = 'episode_{episode:04d}', step_file_format: str = 'step_{step:03d}', max_episodes_per_run: int = 1000, cleanup_policy: Literal['oldest_first', 'size_based', 'manual'] = 'size_based', max_storage_gb: float = 10.0, create_run_subdirs: bool = True, preserve_empty_dirs: bool = False, compress_old_episodes: bool = False)
Bases: Mapping
Configuration for episode management and storage.
This dataclass defines all settings for organizing and managing episode-based visualization storage with validation and serialization.
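The default episode_dir_format and step_file_format values above are ordinary Python format strings. The following is a minimal sketch of the resulting on-disk layout. It assumes, hypothetically, that paths are composed as base directory / run name / episode directory / step file when create_run_subdirs=True. The helper episode_step_path and the run name run_demo are illustrative and not part of JaxARC:

```python
from pathlib import Path


def episode_step_path(
    episode: int,
    step: int,
    base_output_dir: str = "outputs/episodes",
    run_name: str = "run_demo",  # hypothetical run name
    episode_dir_format: str = "episode_{episode:04d}",
    step_file_format: str = "step_{step:03d}",
    file_type: str = "svg",
) -> Path:
    # Assumed layout: <base>/<run>/<episode dir>/<step file>.<ext>
    episode_dir = episode_dir_format.format(episode=episode)
    step_file = step_file_format.format(step=step)
    return Path(base_output_dir) / run_name / episode_dir / f"{step_file}.{file_type}"


print(episode_step_path(episode=3, step=7).as_posix())
# outputs/episodes/run_demo/episode_0003/step_007.svg
```

Zero-padding in the format strings keeps directory listings sorted in episode and step order.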
- estimate_storage_usage(path: Path) -> float
Estimate storage usage in GB for a given path.
- Parameters:
path – Path to analyze
- Returns:
Storage usage in GB
- classmethod from_dict(data: dict[str, Any]) -> EpisodeConfig
Create configuration from dictionary.
- Parameters:
data – Dictionary containing configuration parameters
- Returns:
EpisodeConfig instance
- Raises:
ValueError – If required keys are missing or invalid
- from_tuple()
- generate_run_name() -> str
Generate a timestamped run name if none is provided.
- Returns:
Generated run name with timestamp
- get_base_path() -> Path
Get the base output directory as a Path object.
- Returns:
Path object for the base output directory
- items() -> a set-like object providing a view on D's items
- keys() -> a set-like object providing a view on D's keys
- classmethod load_from_file(file_path: Path | str) -> EpisodeConfig
Load configuration from JSON file.
- Parameters:
file_path – Path to the configuration file
- Returns:
EpisodeConfig instance
- Raises:
FileNotFoundError – If the file does not exist
ValueError – If the file contains an invalid configuration
- replace(**kwargs)
- save_to_file(file_path: Path | str) -> None
Save configuration to JSON file.
- Parameters:
file_path – Path at which to save the configuration
- Raises:
OSError – If file cannot be written
- to_dict() -> dict[str, Any]
Convert configuration to dictionary for serialization.
- Returns:
Dictionary representation of the configuration
- to_tuple()
- validate_storage_path(path: Path) -> bool
Validate that a storage path is accessible and writable.
- Parameters:
path – Path to validate
- Returns:
True if path is valid and writable
- values() -> an object providing a view on D's values
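The to_dict/from_dict and save_to_file/load_from_file pairs describe a JSON round trip. The sketch below illustrates that pattern with a hypothetical MiniEpisodeConfig that holds a subset of the documented fields and defaults. It is built on the standard-library dataclasses and json modules, not on JaxARC's implementation, and its rule of rejecting unknown keys is an assumption:

```python
from __future__ import annotations

import json
import tempfile
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Any

CLEANUP_POLICIES = ("oldest_first", "size_based", "manual")


@dataclass(frozen=True)
class MiniEpisodeConfig:
    """Hypothetical subset of EpisodeConfig, using its documented defaults."""

    base_output_dir: str = "outputs/episodes"
    max_episodes_per_run: int = 1000
    cleanup_policy: str = "size_based"
    max_storage_gb: float = 10.0

    def __post_init__(self) -> None:
        # Enforce the documented Literal values for cleanup_policy.
        if self.cleanup_policy not in CLEANUP_POLICIES:
            raise ValueError(f"Invalid cleanup_policy: {self.cleanup_policy!r}")

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> MiniEpisodeConfig:
        unknown = set(data) - {f.name for f in fields(cls)}
        if unknown:  # assumed validation rule: reject unrecognized keys
            raise ValueError(f"Unknown keys: {sorted(unknown)}")
        return cls(**data)

    def save_to_file(self, file_path: Path | str) -> None:
        Path(file_path).write_text(json.dumps(self.to_dict(), indent=2))

    @classmethod
    def load_from_file(cls, file_path: Path | str) -> MiniEpisodeConfig:
        # read_text raises FileNotFoundError if the file is missing.
        return cls.from_dict(json.loads(Path(file_path).read_text()))


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "episode_config.json"
    cfg = MiniEpisodeConfig(max_storage_gb=2.5)
    cfg.save_to_file(path)
    restored = MiniEpisodeConfig.load_from_file(path)

print(restored == cfg)  # True
```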
- class jaxarc.utils.visualization.EpisodeManager(config: EpisodeConfig)
Bases: object
Manages episode-based storage and organization.
This class handles directory creation, file organization, and cleanup for episode-based visualization data storage.
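The methods below are intended to be called in order: start a run, start an episode within that run, then request per-step file paths. The following stand-in sketches that lifecycle and the documented ValueError checks. The class MiniEpisodeManager is hypothetical. Its naming scheme reuses the EpisodeConfig defaults, and its timestamp format for unnamed runs is an assumption:

```python
from __future__ import annotations

import tempfile
from datetime import datetime
from pathlib import Path


class MiniEpisodeManager:
    """Hypothetical stand-in mirroring the documented EpisodeManager call order."""

    def __init__(self, base_output_dir: Path) -> None:
        self.base = Path(base_output_dir)
        self.run_dir: Path | None = None
        self.episode_dir: Path | None = None

    def start_new_run(self, run_name: str | None = None) -> Path:
        # Fall back to a timestamped name (format is an assumption).
        name = run_name or datetime.now().strftime("run_%Y%m%d_%H%M%S")
        self.run_dir = self.base / name
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.episode_dir = None
        return self.run_dir

    def start_new_episode(self, episode_num: int) -> Path:
        if self.run_dir is None:
            raise ValueError("No active run; call start_new_run() first")
        if episode_num < 0:
            raise ValueError("episode_num must be non-negative")
        self.episode_dir = self.run_dir / f"episode_{episode_num:04d}"
        self.episode_dir.mkdir(parents=True, exist_ok=True)
        return self.episode_dir

    def get_step_path(self, step_num: int, file_type: str = "svg") -> Path:
        if self.episode_dir is None:
            raise ValueError("No active episode; call start_new_episode() first")
        if step_num < 0:
            raise ValueError("step_num must be non-negative")
        # Only builds the path; the caller writes the file.
        return self.episode_dir / f"step_{step_num:03d}.{file_type}"


with tempfile.TemporaryDirectory() as tmp:
    manager = MiniEpisodeManager(Path(tmp))
    manager.start_new_run("demo_run")
    episode_dir = manager.start_new_episode(0)
    step_path = manager.get_step_path(5)
    episode_exists = episode_dir.is_dir()
    relative = step_path.relative_to(tmp).as_posix()

print(relative)  # demo_run/episode_0000/step_005.svg
```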
- cleanup_old_data() -> None
Clean up old data based on configured policy.
This method implements the cleanup policy specified in the configuration to manage storage usage and maintain the episode limit.
- force_cleanup_run(run_name: str) -> bool
Force cleanup of a specific run directory.
- Parameters:
run_name – Name of the run to clean up
- Returns:
True if cleanup was successful, False otherwise
- get_current_run_info() -> dict[str, Any]
Get information about the current run.
- Returns:
Dictionary with run information
- get_episode_summary_path(file_type: str = 'svg') -> Path
Get file path for episode summary visualization.
- Parameters:
file_type – File extension (without dot)
- Returns:
Path for the episode summary file
- Raises:
ValueError – If no episode is active
- get_step_path(step_num: int, file_type: str = 'svg') -> Path
Get file path for a specific step visualization.
- Parameters:
step_num – Step number (must be non-negative)
file_type – File extension (without dot)
- Returns:
Path for the step file
- Raises:
ValueError – If no episode is active or step_num is invalid
- list_episodes_in_run(run_dir: Path | None = None) -> list[tuple[int, Path]]
List all episodes in a run directory.
- Parameters:
run_dir – Run directory to scan. Uses current run if None.
- Returns:
List of (episode_number, episode_path) tuples, sorted by episode number
- start_new_episode(episode_num: int) -> Path
Start a new episode within the current run.
- Parameters:
episode_num – Episode number (must be non-negative)
- Returns:
Path to the created episode directory
- Raises:
ValueError – If no run is active or episode_num is invalid
OSError – If directory cannot be created
- start_new_run(run_name: str | None = None) -> Path
Start a new training run with timestamped directory.
- Parameters:
run_name – Optional custom run name. If None, the configured run_name is used; if that is also unset, a timestamped name is generated.
- Returns:
Path to the created run directory
- Raises:
OSError – If directory cannot be created
ValueError – If run_name is invalid
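list_episodes_in_run returns (episode_number, episode_path) tuples, which are a natural input for count-based cleanup. The sketch below shows one possible 'oldest_first' rule. It assumes that "oldest" means the lowest episode number and that the limit is max_episodes_per_run. select_episodes_to_delete is a hypothetical helper; it only selects paths and deletes nothing:

```python
from __future__ import annotations

from pathlib import Path


def select_episodes_to_delete(
    episodes: list[tuple[int, Path]], max_episodes: int
) -> list[tuple[int, Path]]:
    """Hypothetical 'oldest_first' rule: keep only the newest max_episodes."""
    if max_episodes < 0:
        raise ValueError("max_episodes must be non-negative")
    ordered = sorted(episodes, key=lambda item: item[0])  # ascending episode number
    excess = len(ordered) - max_episodes
    return ordered[:excess] if excess > 0 else []


episodes = [(n, Path(f"run/episode_{n:04d}")) for n in (4, 0, 2, 1, 3)]
to_delete = select_episodes_to_delete(episodes, max_episodes=3)
print([n for n, _ in to_delete])  # [0, 1]
```

A 'size_based' policy would instead keep deleting from the front of the same ordering until estimated usage falls below max_storage_gb.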
- jaxarc.utils.visualization.create_episode_comparison_visualization(episodes_data: List[Any], comparison_type: str = 'reward_progression', width: float = 1200.0, height: float = 600.0) -> str
Create comparison visualization across multiple episodes.
- Parameters:
episodes_data – List of episode summary data
comparison_type – Type of comparison (“reward_progression”, “similarity”, “performance”)
width – Width of the visualization
height – Height of the visualization
- Returns:
SVG string containing the comparison visualization
- jaxarc.utils.visualization.draw_episode_summary_svg(summary_data: Any, step_data: List[Any], config: Any | None = None, width: float = 1400.0, height: float = 1000.0) -> str
Generate episode summary visualization (enhanced version).
- jaxarc.utils.visualization.draw_grid_svg(grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, max_width: float = 10.0, max_height: float = 10.0, padding: float = 0.5, extra_bottom_padding: float = 0.5, label: str = '', border_color: str = '#111111ff', show_size: bool = True, as_group: bool = False) -> drawsvg.Drawing | tuple[drawsvg.Group, tuple[float, float], tuple[float, float]]
Draw a single grid as an SVG.
- Parameters:
grid_input – Grid data (JAX array, numpy array, or Grid object)
mask – Optional boolean mask indicating valid cells
max_width – Maximum width for the drawing
max_height – Maximum height for the drawing
padding – Padding around the grid
extra_bottom_padding – Extra padding at bottom for labels
label – Label to display below the grid
border_color – Color for the grid border
show_size – Whether to show grid dimensions
as_group – If True, return as a group for inclusion in larger drawings
- Returns:
A Drawing object, or a tuple of (Group, origin, size) if as_group=True
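The following rough, dependency-free sketch shows what rendering a grid to SVG involves. It chooses one square cell size that fits both max_width and max_height after padding, skips masked-out cells, and fills each cell from a color palette. The palette follows the colors commonly used by the ARC web interface and is an assumption; JaxARC's ARC_COLOR_PALETTE may differ. The helper name grid_to_svg and the omission of labels and size text are also simplifications:

```python
from __future__ import annotations

# Assumed palette (ARC web-interface colors); not necessarily ARC_COLOR_PALETTE.
PALETTE = {
    0: "#000000", 1: "#0074D9", 2: "#FF4136", 3: "#2ECC40", 4: "#FFDC00",
    5: "#AAAAAA", 6: "#F012BE", 7: "#FF851B", 8: "#7FDBFF", 9: "#870C25",
}


def grid_to_svg(grid, mask=None, max_width=10.0, max_height=10.0, padding=0.5):
    """Render a 2-D list of color indices as an SVG string (illustrative only)."""
    rows, cols = len(grid), len(grid[0])
    # One square cell size that respects both the width and height budgets.
    cell = min((max_width - 2 * padding) / cols, (max_height - 2 * padding) / rows)
    width = cols * cell + 2 * padding
    height = rows * cell + 2 * padding
    parts = [f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">']
    for r, row in enumerate(grid):
        for c, value in enumerate(row):
            if mask is not None and not mask[r][c]:
                continue  # skip cells outside the valid region
            x, y = padding + c * cell, padding + r * cell
            parts.append(
                f'<rect x="{x}" y="{y}" width="{cell}" height="{cell}" '
                f'fill="{PALETTE[value]}" stroke="#555555" stroke-width="0.02"/>'
            )
    parts.append("</svg>")
    return "\n".join(parts)


svg = grid_to_svg([[0, 1, 2], [3, 4, 5]])
```

For this 2x3 grid, width is the binding constraint, so each cell is 3.0 units and the drawing is 10.0 by 7.0.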
- jaxarc.utils.visualization.draw_parsed_task_data_svg(task_data: JaxArcTask, width: float = 30.0, height: float = 20.0, include_test: bool | str = False, border_colors: list[str] | None = None) -> drawsvg.Drawing
Draw a complete JaxArcTask as an SVG with strict height and flexible width.
- Parameters:
task_data – The parsed task data to visualize
width – Maximum width for the drawing (actual width may be less)
height – Strict height constraint - all content must fit within this height
include_test – Whether to include test examples. If ‘all’, show test outputs too.
border_colors – Custom border colors [input_color, output_color]
- Returns:
SVG Drawing object
- jaxarc.utils.visualization.draw_rl_step_svg(before_grid: Grid, after_grid: Grid, action: Dict[str, Any], reward: float, info: Dict[str, Any], step_num: int, operation_name: str = '', changed_cells: jnp.ndarray | None = None, config: Any | None = None, **kwargs) -> str
Backward-compatible wrapper that delegates to draw_rl_step_svg_enhanced.
- jaxarc.utils.visualization.draw_task_pair_svg(input_grid: jnp.ndarray | np.ndarray | Grid, output_grid: jnp.ndarray | np.ndarray | Grid | None = None, input_mask: jnp.ndarray | np.ndarray | None = None, output_mask: jnp.ndarray | np.ndarray | None = None, width: float = 15.0, height: float = 8.0, label: str = '', show_unknown_output: bool = True) -> drawsvg.Drawing
Draw an input-output task pair as SVG with strict height and flexible width.
- Parameters:
input_grid – Input grid data
output_grid – Output grid data (optional)
input_mask – Optional mask for input grid
output_mask – Optional mask for output grid
width – Maximum width for the drawing (actual width may be less)
height – Strict height constraint - all content must fit within this height
label – Label for the pair
show_unknown_output – Whether to show “?” for missing output
- Returns:
SVG Drawing object
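"Strict height and flexible width" can be read as follows: the cell size is bounded first by the height budget, and the drawing becomes only as wide as the grids need. The sketch below shows that calculation for a pair of grids. It is hypothetical; fit_pair, gap, and label_space are illustrative names, not JaxARC parameters:

```python
def fit_pair(input_shape, output_shape, width=15.0, height=8.0, gap=1.0, label_space=1.0):
    """Return (cell_size, actual_width) for two grids placed side by side."""
    (in_rows, in_cols), (out_rows, out_cols) = input_shape, output_shape
    total_cols = in_cols + out_cols
    # Height is a hard limit: the taller grid plus the label area must fit.
    cell_from_height = (height - label_space) / max(in_rows, out_rows)
    # Width is only an upper bound: it can shrink the cells but never enlarge them.
    cell_from_width = (width - gap) / total_cols
    cell = min(cell_from_height, cell_from_width)
    actual_width = total_cols * cell + gap
    return cell, actual_width


cell, actual_width = fit_pair((5, 2), (5, 2))
print(round(cell, 3), round(actual_width, 3))  # 1.4 6.6
```

Here the height budget binds, so the pair uses only 6.6 of the 15.0 available width units.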
- jaxarc.utils.visualization.log_grid_to_console(grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, title: str = 'Grid', show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True) -> None
Log a grid visualization to the console using Rich.
This function is designed to be used with jax.debug.callback for logging during JAX transformations.
- Parameters:
grid_input – Grid data (JAX array, numpy array, or Grid object)
mask – Optional boolean mask indicating valid cells
title – Title for the grid display
show_coordinates – Whether to show row/column coordinates
show_numbers – If True, show colored numbers; if False, show colored blocks
double_width – If True and show_numbers=False, use double-width blocks for square appearance
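The jax.debug.callback pattern mentioned above lets a host-side function such as log_grid_to_console run inside jitted code. The sketch uses a stand-in print_grid callback so it runs without JaxARC. With JaxARC installed, log_grid_to_console could be passed in its place, because it accepts a NumPy array as its first argument:

```python
import jax
import jax.numpy as jnp
import numpy as np

captured = []


def print_grid(grid) -> None:
    # Stand-in for log_grid_to_console: receives a concrete array on the host.
    arr = np.asarray(grid)
    captured.append(arr)
    for row in arr:
        print(" ".join(str(int(v)) for v in row))


@jax.jit
def fill_background(grid):
    new_grid = jnp.where(grid == 0, 1, grid)
    # Runs on the host when the compiled function executes, even under jit.
    jax.debug.callback(print_grid, new_grid)
    return new_grid


result = fill_background(jnp.array([[0, 2], [3, 0]]))
jax.effects_barrier()  # wait for pending debug callbacks to finish
```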
- jaxarc.utils.visualization.save_rl_step_visualization(state: Any, action: dict, next_state: Any, output_dir: str = 'output/rl_steps') -> None
JAX callback that saves a visualization of a single RL step.
This function is designed to be used with jax.debug.callback.
- Parameters:
state – Environment state before the action
action – Action dictionary with ‘selection’ and ‘operation’ keys
next_state – Environment state after the action
output_dir – Directory to save visualization files
- jaxarc.utils.visualization.save_svg_drawing(drawing: Drawing, filename: str, context: Any | None = None) -> None
Save an SVG drawing to file with support for multiple formats.
- Parameters:
drawing – The SVG drawing to save
filename – Output filename (extension determines format: .svg, .png, .pdf)
context – Optional context for PDF conversion
- Raises:
ValueError – If file extension is not supported
ImportError – If required dependencies are missing for PNG/PDF output
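The extension-based dispatch described for save_svg_drawing can be sketched as a small lookup. The mapping, the case-insensitive comparison, and the helper name output_format are assumptions:

```python
from pathlib import Path

# Hypothetical mapping from file extension to output format.
_FORMATS = {".svg": "svg", ".png": "png", ".pdf": "pdf"}


def output_format(filename: str) -> str:
    """Return the output format implied by the filename's extension."""
    suffix = Path(filename).suffix.lower()  # case-insensitive (assumption)
    if suffix not in _FORMATS:
        raise ValueError(
            f"Unsupported extension {suffix!r}; expected one of {sorted(_FORMATS)}"
        )
    return _FORMATS[suffix]
```

Resolving the format before rendering means an unsupported extension fails fast, before any optional PNG/PDF dependency is imported.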
- jaxarc.utils.visualization.visualize_grid_rich(grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, title: str = 'Grid', show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True, border_style: str = 'default') -> Table | Panel
Create a Rich Table visualization of a single grid.
- Parameters:
grid_input – Grid data (JAX array, numpy array, or Grid object)
mask – Optional boolean mask indicating valid cells
title – Title for the table
show_coordinates – Whether to show row/column coordinates
show_numbers – If True, show colored numbers; if False, show colored blocks
double_width – If True and show_numbers=False, use double-width blocks for square appearance
border_style – Border style - ‘input’ for blue borders, ‘output’ for green borders, ‘default’ for normal
- Returns:
Rich Table or Panel object for display
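The show_numbers, double_width, and mask options determine how each cell is drawn. Below is a color-free, plain-text sketch of those rules. render_grid_text is hypothetical, and rendering masked cells as blanks is an assumption; the real function builds a colored Rich table:

```python
def render_grid_text(grid, mask=None, show_numbers=False, double_width=True):
    """Plain-text analogue of the cell modes above (no colors; illustrative only)."""
    # double_width only affects block mode, matching the documented parameter.
    block = "██" if double_width else "█"
    lines = []
    for r, row in enumerate(grid):
        cells = []
        for c, value in enumerate(row):
            cell = str(value) if show_numbers else block
            if mask is not None and not mask[r][c]:
                cell = " " * len(cell)  # blank out cells outside the valid region
            cells.append(cell)
        lines.append("".join(cells))
    return "\n".join(lines)


print(render_grid_text([[1, 2], [3, 4]], show_numbers=True))
```

Terminal character cells are roughly twice as tall as they are wide, so two block characters per cell give a roughly square appearance.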
- jaxarc.utils.visualization.visualize_parsed_task_data_rich(task_data: JaxArcTask, show_test: bool = True, show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True) -> None
Visualize a JaxArcTask object using Rich console output with enhanced layout and grouping.
- Parameters:
task_data – The parsed task data to visualize
show_test – Whether to show test pairs
show_coordinates – Whether to show grid coordinates
show_numbers – If True, show colored numbers; if False, show colored blocks
double_width – If True and show_numbers=False, use double-width blocks for square appearance
- jaxarc.utils.visualization.visualize_task_pair_rich(input_grid: jnp.ndarray | np.ndarray | Grid, output_grid: jnp.ndarray | np.ndarray | Grid | None = None, input_mask: jnp.ndarray | np.ndarray | None = None, output_mask: jnp.ndarray | np.ndarray | None = None, title: str = 'Task Pair', show_numbers: bool = False, double_width: bool = True, console: Console | None = None) -> None
Visualize an input-output pair using Rich tables with responsive layout.
- Parameters:
input_grid – Input grid data
output_grid – Output grid data (optional)
input_mask – Optional mask for input grid
output_mask – Optional mask for output grid
title – Title for the visualization
show_numbers – If True, show colored numbers; if False, show colored blocks
double_width – If True and show_numbers=False, use double-width blocks for square appearance
console – Optional Rich console (creates one if None)
Core Utilities
Hydra configuration utilities for the jaxarc project.
This module provides utilities for working with Hydra configurations, including loading configs and managing data paths. It focuses purely on Hydra integration without mixing in factory functions.
- jaxarc.utils.core.get_config(overrides: list[str] | None = None) -> DictConfig
Load the default Hydra configuration.
- Parameters:
overrides – List of configuration overrides in Hydra format
- Returns:
Loaded Hydra configuration
Example
```python
from jaxarc.utils.core import get_config

# Load default config
cfg = get_config()

# Load with overrides
cfg = get_config(["dataset.dataset_name=ConceptARC", "action.selection_format=point"])
```
- jaxarc.utils.core.get_path(path_type: str, create: bool = False) -> Path
Get a configured path by type.
- Parameters:
path_type – Type of path (‘data_raw’, ‘data_processed’, ‘data_interim’, ‘data_external’)
create – Whether to create the directory if it doesn’t exist
- Returns:
Path object for the requested path type
- Raises:
KeyError – If path_type is not found in configuration
Example
```python
from jaxarc.utils.core import get_path

# Get the raw data path, creating the directory if needed
raw_path = get_path("data_raw", create=True)
```
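The following minimal sketch illustrates the documented get_path contract: look up a path type, optionally create the directory, and raise KeyError for unknown types. The relative paths in PATHS, the explicit root argument, and the name get_path_sketch are hypothetical; the real function reads its paths from the Hydra configuration:

```python
from __future__ import annotations

import tempfile
from pathlib import Path

# Hypothetical path table; the real values come from the loaded configuration.
PATHS = {
    "data_raw": "data/raw",
    "data_processed": "data/processed",
    "data_interim": "data/interim",
    "data_external": "data/external",
}


def get_path_sketch(path_type: str, root: Path, create: bool = False) -> Path:
    if path_type not in PATHS:
        raise KeyError(f"Unknown path type: {path_type!r}")
    path = root / PATHS[path_type]
    if create:
        path.mkdir(parents=True, exist_ok=True)  # idempotent creation
    return path


root = Path(tempfile.mkdtemp())
raw = get_path_sketch("data_raw", root, create=True)
```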
Usage Examples
Visualization
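```python
import jax.numpy as jnp
from IPython.display import display  # built in to Jupyter; import explicitly elsewhere

from jaxarc.utils.visualization import draw_grid_svg, draw_task_pair_svg

# Create a simple 3x3 grid of ARC color indices
grid = jnp.array(
    [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
    ]
)

# Draw a single grid as SVG
svg = draw_grid_svg(grid)
display(svg)  # In a Jupyter notebook

# Draw an input-output pair (the transposed grid stands in for an output)
pair_svg = draw_task_pair_svg(grid, grid.T, label="Transpose example")
display(pair_svg)
```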
Configuration
```python
from jaxarc import JaxArcConfig
from jaxarc.utils.core import get_config

# Load the Hydra configuration, selecting the mini_arc dataset config
hydra_cfg = get_config(
    overrides=[
        "dataset=mini_arc",
    ]
)

# Convert the Hydra DictConfig into a JaxArcConfig
config = JaxArcConfig.from_hydra(hydra_cfg)
```
See Also
Visualizing ARC Tasks - Tutorial on visualization