Utilities API

Helper functions and utilities for visualization, data loading, and more.

Visualization

Enhanced visualization and logging system for JaxARC.

This module provides visualization for ARC grids, tasks, and RL training episodes. It supports several output formats (Rich console output, and SVG, PNG, or PDF files) and includes performance optimizations.

Public API:
Core visualization functions:
  • log_grid_to_console: Console logging with Rich formatting

  • draw_grid_svg: SVG generation for single grids

  • visualize_grid_rich: Rich table visualization for grids

  • visualize_task_pair_rich: Rich visualization for input-output pairs

  • draw_task_pair_svg: SVG generation for task pairs

  • visualize_parsed_task_data_rich: Complete task visualization

  • draw_parsed_task_data_svg: SVG generation for complete tasks

RL-specific functions:
  • draw_rl_step_svg: Visualization of RL step transitions

  • save_rl_step_visualization: Save step visualizations to disk

Utility functions:
  • save_svg_drawing: Save SVG drawings to files

Constants:
  • ARC_COLOR_PALETTE: Standard ARC color mapping

class jaxarc.utils.visualization.EpisodeConfig(base_output_dir: str = 'outputs/episodes', run_name: str | None = None, episode_dir_format: str = 'episode_{episode:04d}', step_file_format: str = 'step_{step:03d}', max_episodes_per_run: int = 1000, cleanup_policy: Literal['oldest_first', 'size_based', 'manual'] = 'size_based', max_storage_gb: float = 10.0, create_run_subdirs: bool = True, preserve_empty_dirs: bool = False, compress_old_episodes: bool = False)

Bases: Mapping

Configuration for episode management and storage.

This dataclass defines all settings for organizing and managing episode-based visualization storage with validation and serialization.

base_output_dir: str = 'outputs/episodes'
cleanup_policy: Literal['oldest_first', 'size_based', 'manual'] = 'size_based'
compress_old_episodes: bool = False
create_run_subdirs: bool = True
episode_dir_format: str = 'episode_{episode:04d}'
estimate_storage_usage(path: Path) → float

Estimate storage usage in GB for a given path.

Parameters:

path – Path to analyze

Returns:

Storage usage in GB

classmethod from_dict(data: dict[str, Any]) → EpisodeConfig

Create configuration from dictionary.

Parameters:

data – Dictionary containing configuration parameters

Returns:

EpisodeConfig instance

Raises:

ValueError – If required keys are missing or invalid

from_tuple()
generate_run_name() → str

Generate a timestamped run name if none is provided.

Returns:

Generated run name with timestamp

get_base_path() → Path

Get the base output directory as a Path object.

Returns:

Path object for the base output directory

items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
classmethod load_from_file(file_path: Path | str) → EpisodeConfig

Load configuration from JSON file.

Parameters:

file_path – Path to the configuration file

Returns:

EpisodeConfig instance

Raises:
max_episodes_per_run: int = 1000
max_storage_gb: float = 10.0
preserve_empty_dirs: bool = False
replace(**kwargs)
run_name: str | None = None
save_to_file(file_path: Path | str) → None

Save configuration to JSON file.

Parameters:

file_path – Path where to save the configuration

Raises:

OSError – If file cannot be written

step_file_format: str = 'step_{step:03d}'
to_dict() → dict[str, Any]

Convert configuration to dictionary for serialization.

Returns:

Dictionary representation of the configuration

to_tuple()
validate_storage_path(path: Path) → bool

Validate that a storage path is accessible and writable.

Parameters:

path – Path to validate

Returns:

True if path is valid and writable

values() → an object providing a view on D's values
class jaxarc.utils.visualization.EpisodeManager(config: EpisodeConfig)
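The *_format fields are ordinary Python format strings. As a sketch of how they could combine into a step file path, the hypothetical helper `sketch_step_path` below (not part of jaxarc) assumes the layout base_output_dir/run_name/episode_dir/step_file.ext, which matches create_run_subdirs=True but may differ from the paths EpisodeManager actually produces.

```python
from pathlib import Path


def sketch_step_path(
    base_output_dir: str,
    run_name: str,
    episode: int,
    step: int,
    file_type: str = "svg",
    episode_dir_format: str = "episode_{episode:04d}",
    step_file_format: str = "step_{step:03d}",
) -> Path:
    """Hypothetical helper: compose a step file path from EpisodeConfig-style fields.

    Assumes the layout base_output_dir/run_name/episode_dir/step_file.file_type.
    """
    episode_dir = episode_dir_format.format(episode=episode)  # e.g. "episode_0003"
    step_file = step_file_format.format(step=step)  # e.g. "step_012"
    return Path(base_output_dir) / run_name / episode_dir / f"{step_file}.{file_type}"


# On POSIX this prints outputs/episodes/demo_run/episode_0003/step_012.svg
print(sketch_step_path("outputs/episodes", "demo_run", episode=3, step=12))
```

Zero padding keeps directory listings in numeric order: with the default `{episode:04d}`, episode numbers up to 9999 sort correctly as strings, well above the default max_episodes_per_run of 1000.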

Bases: object

Manages episode-based storage and organization.

This class handles directory creation, file organization, and cleanup for episode-based visualization data storage.

cleanup_old_data() → None

Clean up old data based on configured policy.

This method implements the cleanup policy specified in the configuration to manage storage usage and maintain the episode limit.

current_episode_dir: Path | None
current_episode_num: int | None
current_run_dir: Path | None
current_run_name: str | None
force_cleanup_run(run_name: str) → bool

Force cleanup of a specific run directory.

Parameters:

run_name – Name of the run to clean up

Returns:

True if cleanup was successful, False otherwise

get_current_run_info() → dict[str, Any]

Get information about the current run.

Returns:

Dictionary with run information

get_episode_summary_path(file_type: str = 'svg') → Path

Get file path for episode summary visualization.

Parameters:

file_type – File extension (without dot)

Returns:

Path for the episode summary file

Raises:

ValueError – If no episode is active

get_step_path(step_num: int, file_type: str = 'svg') → Path

Get file path for a specific step visualization.

Parameters:
  • step_num – Step number (must be non-negative)

  • file_type – File extension (without dot)

Returns:

Path for the step file

Raises:

ValueError – If no episode is active or step_num is invalid

list_episodes_in_run(run_dir: Path | None = None) → list[tuple[int, Path]]

List all episodes in a run directory.

Parameters:

run_dir – Run directory to scan. Uses current run if None.

Returns:

List of (episode_number, episode_path) tuples, sorted by episode number

start_new_episode(episode_num: int) → Path

Start a new episode within the current run.

Parameters:

episode_num – Episode number (must be non-negative)

Returns:

Path to the created episode directory

Raises:
  • ValueError – If no run is active or episode_num is invalid

  • OSError – If directory cannot be created

start_new_run(run_name: str | None = None) → Path

Start a new training run with timestamped directory.

Parameters:

run_name – Optional custom run name. If None, uses config or generates one.

Returns:

Path to the created run directory

Raises:
jaxarc.utils.visualization.create_episode_comparison_visualization(episodes_data: List[Any], comparison_type: str = 'reward_progression', width: float = 1200.0, height: float = 600.0) → str

Create comparison visualization across multiple episodes.

Parameters:
  • episodes_data – List of episode summary data

  • comparison_type – Type of comparison (“reward_progression”, “similarity”, “performance”)

  • width – Width of the visualization

  • height – Height of the visualization

Returns:

SVG string containing the comparison visualization

jaxarc.utils.visualization.draw_episode_summary_svg(summary_data: Any, step_data: List[Any], config: Any | None = None, width: float = 1400.0, height: float = 1000.0) → str

Generate episode summary visualization (enhanced version).

jaxarc.utils.visualization.draw_grid_svg(grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, max_width: float = 10.0, max_height: float = 10.0, padding: float = 0.5, extra_bottom_padding: float = 0.5, label: str = '', border_color: str = '#111111ff', show_size: bool = True, as_group: bool = False) → drawsvg.Drawing | tuple[drawsvg.Group, tuple[float, float], tuple[float, float]]

Draw a single grid as an SVG.

Parameters:
  • grid_input – Grid data (JAX array, numpy array, or Grid object)

  • mask – Optional boolean mask indicating valid cells

  • max_width – Maximum width for the drawing

  • max_height – Maximum height for the drawing

  • padding – Padding around the grid

  • extra_bottom_padding – Extra padding at bottom for labels

  • label – Label to display below the grid

  • border_color – Color for the grid border

  • show_size – Whether to show grid dimensions

  • as_group – If True, return as a group for inclusion in larger drawings

Returns:

Either a Drawing object or tuple of (Group, origin, size) if as_group=True

jaxarc.utils.visualization.draw_parsed_task_data_svg(task_data: JaxArcTask, width: float = 30.0, height: float = 20.0, include_test: bool | str = False, border_colors: list[str] | None = None) → drawsvg.Drawing

Draw a complete JaxArcTask as an SVG with strict height and flexible width.

Parameters:
  • task_data – The parsed task data to visualize

  • width – Maximum width for the drawing (actual width may be less)

  • height – Strict height constraint: all content must fit within this height

  • include_test – Whether to include test examples. If ‘all’, show test outputs too.

  • border_colors – Custom border colors [input_color, output_color]

Returns:

SVG Drawing object

jaxarc.utils.visualization.draw_rl_step_svg(before_grid: Grid, after_grid: Grid, action: Dict[str, Any], reward: float, info: Dict[str, Any], step_num: int, operation_name: str = '', changed_cells: jnp.ndarray | None = None, config: Any | None = None, **kwargs) → str

Enhanced wrapper for draw_rl_step_svg_enhanced with backward compatibility.

jaxarc.utils.visualization.draw_task_pair_svg(input_grid: jnp.ndarray | np.ndarray | Grid, output_grid: jnp.ndarray | np.ndarray | Grid | None = None, input_mask: jnp.ndarray | np.ndarray | None = None, output_mask: jnp.ndarray | np.ndarray | None = None, width: float = 15.0, height: float = 8.0, label: str = '', show_unknown_output: bool = True) → drawsvg.Drawing

Draw an input-output task pair as SVG with strict height and flexible width.

Parameters:
  • input_grid – Input grid data

  • output_grid – Output grid data (optional)

  • input_mask – Optional mask for input grid

  • output_mask – Optional mask for output grid

  • width – Maximum width for the drawing (actual width may be less)

  • height – Strict height constraint: all content must fit within this height

  • label – Label for the pair

  • show_unknown_output – Whether to show “?” for missing output

Returns:

SVG Drawing object

jaxarc.utils.visualization.log_grid_to_console(grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, title: str = 'Grid', show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True) → None

Log a grid visualization to the console using Rich.

This function is designed to be used with jax.debug.callback for logging during JAX transformations.

Parameters:
  • grid_input – Grid data (JAX array, numpy array, or Grid object)

  • mask – Optional boolean mask indicating valid cells

  • title – Title for the grid display

  • show_coordinates – Whether to show row/column coordinates

  • show_numbers – If True, show colored numbers; if False, show colored blocks

  • double_width – If True and show_numbers=False, use double-width blocks for square appearance

jaxarc.utils.visualization.save_rl_step_visualization(state: Any, action: dict, next_state: Any, output_dir: str = 'output/rl_steps') → None

JAX callback function that saves an RL step visualization.

This function is designed to be used with jax.debug.callback.

Parameters:
  • state – Environment state before the action

  • action – Action dictionary with ‘selection’ and ‘operation’ keys

  • next_state – Environment state after the action

  • output_dir – Directory to save visualization files

jaxarc.utils.visualization.save_svg_drawing(drawing: Drawing, filename: str, context: Any | None = None) → None

Save an SVG drawing to file with support for multiple formats.

Parameters:
  • drawing – The SVG drawing to save

  • filename – Output filename (extension determines format: .svg, .png, .pdf)

  • context – Optional context for PDF conversion

Raises:
  • ValueError – If file extension is not supported

  • ImportError – If required dependencies are missing for PNG/PDF output

jaxarc.utils.visualization.visualize_grid_rich(grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, title: str = 'Grid', show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True, border_style: str = 'default') → Table | Panel

Create a Rich Table visualization of a single grid.

Parameters:
  • grid_input – Grid data (JAX array, numpy array, or Grid object)

  • mask – Optional boolean mask indicating valid cells

  • title – Title for the table

  • show_coordinates – Whether to show row/column coordinates

  • show_numbers – If True, show colored numbers; if False, show colored blocks

  • double_width – If True and show_numbers=False, use double-width blocks for square appearance

  • border_style – Border style: ‘input’ for blue borders, ‘output’ for green borders, ‘default’ for normal borders

Returns:

Rich Table object for display

jaxarc.utils.visualization.visualize_parsed_task_data_rich(task_data: JaxArcTask, show_test: bool = True, show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True) → None

Visualize a JaxArcTask object using Rich console output with enhanced layout and grouping.

Parameters:
  • task_data – The parsed task data to visualize

  • show_test – Whether to show test pairs

  • show_coordinates – Whether to show grid coordinates

  • show_numbers – If True, show colored numbers; if False, show colored blocks

  • double_width – If True and show_numbers=False, use double-width blocks for square appearance

jaxarc.utils.visualization.visualize_task_pair_rich(input_grid: jnp.ndarray | np.ndarray | Grid, output_grid: jnp.ndarray | np.ndarray | Grid | None = None, input_mask: jnp.ndarray | np.ndarray | None = None, output_mask: jnp.ndarray | np.ndarray | None = None, title: str = 'Task Pair', show_numbers: bool = False, double_width: bool = True, console: Console | None = None) → None

Visualize an input-output pair using Rich tables with responsive layout.

Parameters:
  • input_grid – Input grid data

  • output_grid – Output grid data (optional)

  • input_mask – Optional mask for input grid

  • output_mask – Optional mask for output grid

  • title – Title for the visualization

  • show_numbers – If True, show colored numbers; if False, show colored blocks

  • double_width – If True and show_numbers=False, use double-width blocks for square appearance

  • console – Optional Rich console (creates one if None)

Core Utilities

Hydra configuration utilities for the jaxarc project.

This module provides utilities for working with Hydra configurations, including loading configs and managing data paths. It focuses purely on Hydra integration without mixing in factory functions.

jaxarc.utils.core.get_config(overrides: list[str] | None = None) → DictConfig

Load the default Hydra configuration.

Parameters:

overrides – List of configuration overrides in Hydra format

Returns:

Loaded Hydra configuration

Example

```python
from jaxarc.utils.core import get_config

# Load default config
cfg = get_config()

# Load with overrides
cfg = get_config(["dataset.dataset_name=ConceptARC", "action.selection_format=point"])
```

jaxarc.utils.core.get_path(path_type: str, create: bool = False) → Path

Get a configured path by type.

Parameters:
  • path_type – Type of path (‘data_raw’, ‘data_processed’, ‘data_interim’, ‘data_external’)

  • create – Whether to create the directory if it doesn’t exist

Returns:

Path object for the requested path type

Raises:

KeyError – If path_type is not found in configuration

Example

```python
from jaxarc.utils.core import get_path

# Get raw data path
raw_path = get_path("data_raw", create=True)
```

jaxarc.utils.core.get_raw_path(create: bool = False) → Path

Get the raw data path.

Parameters:

create – Whether to create the directory if it doesn’t exist

Returns:

Path to raw data directory

Usage Examples

Visualization

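```python
import jax.numpy as jnp
from IPython.display import display  # Jupyter display helper

from jaxarc.utils.visualization import draw_grid_svg, draw_task_pair_svg

# Create a simple 3x3 grid of ARC color indices
grid = jnp.array(
    [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
    ]
)

# Draw a single grid as SVG
svg = draw_grid_svg(grid)
display(svg)  # In a Jupyter notebook

# Draw an input-output pair (here the output is the transposed grid)
pair_svg = draw_task_pair_svg(grid, grid.T, label="Transpose")
display(pair_svg)
```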

Configuration

```python
from jaxarc import JaxArcConfig
from jaxarc.utils.core import get_config

# Load the Hydra configuration with overrides
hydra_cfg = get_config(
    overrides=[
        "dataset=mini_arc",
    ]
)

# Build a JaxArcConfig from the Hydra configuration
config = JaxArcConfig.from_hydra(hydra_cfg)
```

See Also