Source code for jaxarc.utils.visualization.display

"""
Display functionality for tasks and episodes.

Handles both terminal (Rich) and file-based (SVG) visualization.
"""

from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any, List, Optional

import drawsvg  # type: ignore[import-untyped]
import jax.numpy as jnp
import numpy as np
from rich import box
from rich.columns import Columns
from rich.console import Console, Group
from rich.padding import Padding
from rich.panel import Panel
from rich.rule import Rule
from rich.table import Table
from rich.text import Text

from jaxarc.utils.serialization_utils import serialize_jax_array
from jaxarc.utils.task_manager import extract_task_id_from_index

from .core import (
    ARC_COLOR_PALETTE,
    _draw_dotted_squircle,
    _extract_grid_data,
    _extract_valid_region,
    draw_grid_svg,
)

if TYPE_CHECKING:
    from jaxarc.types import Grid, GridArray, JaxArcTask


# ============================================================================
# SECTION: Rich Terminal Display (from rich_display.py)
# ============================================================================


def _get_panel_border_style(border_style: str) -> str:
    """Get panel border style based on border type."""
    if border_style == "input":
        return "blue"
    if border_style == "output":
        return "green"
    return "blue"


def _get_title_style(border_style: str) -> str:
    """Get title style based on border type."""
    if border_style == "input":
        return "bold blue"
    if border_style == "output":
        return "bold green"
    return "bold"


[docs] def visualize_grid_rich( grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, title: str = "Grid", show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True, border_style: str = "default", ) -> Table | Panel: """Create a Rich Table visualization of a single grid. Args: grid_input: Grid data (JAX array, numpy array, or Grid object) mask: Optional boolean mask indicating valid cells title: Title for the table show_coordinates: Whether to show row/column coordinates show_numbers: If True, show colored numbers; if False, show colored blocks double_width: If True and show_numbers=False, use double-width blocks for square appearance border_style: Border style - 'input' for blue borders, 'output' for green borders, 'default' for normal Returns: Rich Table object for display """ grid, grid_mask = _extract_grid_data(grid_input) if mask is None: mask = grid_mask if mask is not None: mask = serialize_jax_array(mask) if grid.size == 0: table = Table(show_header=False, show_edge=False, show_lines=False, box=None) table.add_column("Empty") table.add_row("[grey23]Empty grid[/]") panel_style = _get_panel_border_style(border_style) title_style = _get_title_style(border_style) return Panel( table, title=Text(f"{title} (Empty)", style=title_style), border_style=panel_style, box=box.ROUNDED if border_style == "input" else box.HEAVY, padding=(0, 0), ) # Extract valid region valid_grid, (start_row, start_col), (height, width) = _extract_valid_region( grid, mask ) if height == 0 or width == 0: table = Table(show_header=False, show_edge=False, show_lines=False, box=None) table.add_column("Empty") table.add_row("[grey23]No valid data[/]") panel_style = _get_panel_border_style(border_style) title_style = _get_title_style(border_style) return Panel( table, title=Text(f"{title} (No valid data)", style=title_style), border_style=panel_style, box=box.ROUNDED if border_style == "input" else box.HEAVY, padding=(0, 0), ) # Create table without borders (will be wrapped in panel) table = Table( show_header=show_coordinates, show_edge=False, show_lines=False, box=None, padding=0, pad_edge=False, ) # Add columns if show_coordinates: table.add_column("", justify="center", width=3) # Row numbers for j in range(width): col_header = str(start_col + j) if show_coordinates else "" # Adjust column width based on display mode col_width = 2 # Single blocks table.add_column(col_header, justify="center", width=col_width, no_wrap=True) # Add rows for i in range(height): row_items = [] if show_coordinates: row_items.append(str(start_row + i)) for j in range(width): color_val = int(valid_grid[i, j]) # Check if this cell is valid (if mask is provided) is_valid = True if mask is not None: actual_row = start_row + i actual_col = start_col + j if actual_row < mask.shape[0] and actual_col < mask.shape[1]: is_valid = mask[actual_row, actual_col] if not is_valid: if show_numbers: row_items.append("[grey23]·[/]") else: placeholder = "·" if not double_width else "··" row_items.append(f"[grey23]{placeholder}[/]") elif show_numbers: # Show colored numbers rich_color = ARC_COLOR_PALETTE.get(color_val, "white") row_items.append(f"[{rich_color}]{color_val}[/]") elif double_width: # Use double-width blocks for more square appearance rich_color = ARC_COLOR_PALETTE.get(color_val, "white") row_items.append(f"[{rich_color}]██[/]") else: # Use single block character rich_color = ARC_COLOR_PALETTE.get(color_val, "white") row_items.append(f"[{rich_color}]█[/]") table.add_row(*row_items) # Wrap table in panel with appropriate border style panel_style = _get_panel_border_style(border_style) title_style = _get_title_style(border_style) return Panel( table, title=Text(f"{title} ({height}x{width})", style=title_style), border_style=panel_style, box=box.ROUNDED if border_style == "input" else box.HEAVY, padding=(0, 0), )
[docs] def log_grid_to_console( grid_input: jnp.ndarray | np.ndarray | Grid, mask: jnp.ndarray | np.ndarray | None = None, title: str = "Grid", show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True, ) -> None: """Log a grid visualization to the console using Rich. This function is designed to be used with jax.debug.callback for logging during JAX transformations. Args: grid_input: Grid data (JAX array, numpy array, or Grid object) mask: Optional boolean mask indicating valid cells title: Title for the grid display show_coordinates: Whether to show row/column coordinates show_numbers: If True, show colored numbers; if False, show colored blocks double_width: If True and show_numbers=False, use double-width blocks for square appearance """ console = Console() table = visualize_grid_rich( grid_input, mask, title, show_coordinates, show_numbers, double_width ) console.print(table)
[docs] def visualize_task_pair_rich( input_grid: jnp.ndarray | np.ndarray | Grid, output_grid: jnp.ndarray | np.ndarray | Grid | None = None, input_mask: jnp.ndarray | np.ndarray | None = None, output_mask: jnp.ndarray | np.ndarray | None = None, title: str = "Task Pair", show_numbers: bool = False, double_width: bool = True, console: Console | None = None, ) -> None: """Visualize an input-output pair using Rich tables with responsive layout. Args: input_grid: Input grid data output_grid: Output grid data (optional) input_mask: Optional mask for input grid output_mask: Optional mask for output grid title: Title for the visualization show_numbers: If True, show colored numbers; if False, show colored blocks double_width: If True and show_numbers=False, use double-width blocks for square appearance console: Optional Rich console (creates one if None) """ if console is None: console = Console() # Create input table with blue border input_table = visualize_grid_rich( input_grid, input_mask, f"{title} - Input", show_numbers=show_numbers, double_width=double_width, border_style="input", ) # Create output table or placeholder if output_grid is not None: output_table = visualize_grid_rich( output_grid, output_mask, f"{title} - Output", show_numbers=show_numbers, double_width=double_width, border_style="output", ) else: # Create placeholder for unknown output output_table = Table( show_header=False, show_edge=False, show_lines=False, box=None, ) output_table.add_column("Unknown", justify="center") question_text = Text("?", style="bold yellow") output_table.add_row(question_text) output_table = Panel( output_table, title=Text(f"{title} - Output", style="bold green"), border_style="green", box=box.HEAVY, padding=(0, 0), ) # Responsive layout based on terminal width terminal_width = console.size.width # If terminal is wide enough, show side-by-side if terminal_width >= 120: columns = Columns([input_table, output_table], equal=True, expand=True) console.print(columns) else: # Stack vertically with clear separation console.print(input_table) arrow_text = Text("↓", justify="center", style="bold") console.print(arrow_text) console.print(output_table)
[docs] def visualize_parsed_task_data_rich( task_data: JaxArcTask, show_test: bool = True, show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True, ) -> None: """Visualize a JaxArcTask object using Rich console output with enhanced layout and grouping. Args: task_data: The parsed task data to visualize show_test: Whether to show test pairs show_coordinates: Whether to show grid coordinates show_numbers: If True, show colored numbers; if False, show colored blocks double_width: If True and show_numbers=False, use double-width blocks for square appearance """ console = Console() terminal_width = console.size.width # Enhanced task header with Panel task_id = extract_task_id_from_index(task_data.task_index) task_title = f"Task: {task_id}" # Create properly styled text for task info task_info = Text(justify="center") task_info.append("Training Examples: ", style="bold") task_info.append(str(task_data.num_train_pairs)) task_info.append(" ") task_info.append("Test Examples: ", style="bold") task_info.append(str(task_data.num_test_pairs)) header_panel = Panel( task_info, title=task_title, title_align="left", border_style="bright_blue", box=box.ROUNDED, padding=(0, 1), ) console.print(header_panel) console.print() # Training examples with visual grouping if task_data.num_train_pairs > 0: training_content = [] for i in range(task_data.num_train_pairs): # Create input table with input border style input_table = visualize_grid_rich( task_data.input_grids_examples[i], task_data.input_masks_examples[i], f"Input {i + 1}", show_coordinates, show_numbers, double_width, border_style="input", ) # Create output table with output border style output_table = visualize_grid_rich( task_data.output_grids_examples[i], task_data.output_masks_examples[i], f"Output {i + 1}", show_coordinates, show_numbers, double_width, border_style="output", ) # Responsive layout for each pair if terminal_width >= 120: # Side-by-side layout for wide terminals pair_layout = Columns( [input_table, output_table], equal=True, expand=True ) training_content.append(pair_layout) else: # Vertical layout for narrow terminals training_content.append(input_table) arrow_text = Text("↓", justify="center", style="bold") training_content.append(Padding(arrow_text, (0, 0, 1, 0))) training_content.append(output_table) # Add separator between examples if i < task_data.num_train_pairs - 1: training_content.append(Rule(style="dim")) # Wrap training examples in a blue panel training_group = Group(*training_content) training_panel = Panel( training_group, title=f"Training Examples ({task_data.num_train_pairs})", title_align="left", border_style="blue", box=box.ROUNDED, padding=(1, 1), ) console.print(training_panel) # Test examples with visual grouping if show_test and task_data.num_test_pairs > 0: console.print() # Space between groups test_content = [] for i in range(task_data.num_test_pairs): # Create test input table test_input_table = visualize_grid_rich( task_data.test_input_grids[i], task_data.test_input_masks[i], f"Test Input {i + 1}", show_coordinates, show_numbers, double_width, border_style="input", ) # Create test output table or placeholder if ( i < len(task_data.true_test_output_grids) and task_data.true_test_output_grids[i] is not None ): test_output_table = visualize_grid_rich( task_data.true_test_output_grids[i], task_data.true_test_output_masks[i], f"Test Output {i + 1}", show_coordinates, show_numbers, double_width, border_style="output", ) else: # Create placeholder for unknown test output test_output_table = Table( show_header=False, show_edge=False, show_lines=False, box=None, ) test_output_table.add_column("Unknown", justify="center") question_text = Text("?", style="bold yellow") test_output_table.add_row(question_text) test_output_table = Panel( test_output_table, title=Text(f"Test Output {i + 1}", style="bold green"), border_style="green", box=box.HEAVY, padding=(0, 0), ) # Responsive layout for each test pair if terminal_width >= 120: # Side-by-side layout for wide terminals pair_layout = Columns( [test_input_table, test_output_table], equal=True, expand=True ) test_content.append(pair_layout) else: # Vertical layout for narrow terminals test_content.append(test_input_table) arrow_text = Text("↓", justify="center", style="bold") test_content.append(Padding(arrow_text, (0, 0, 1, 0))) test_content.append(test_output_table) # Add separator between examples if i < task_data.num_test_pairs - 1: test_content.append(Rule(style="dim")) # Wrap test examples in a red panel test_group = Group(*test_content) test_panel = Panel( test_group, title=f"Test Examples ({task_data.num_test_pairs})", title_align="left", border_style="red", box=box.ROUNDED, padding=(1, 1), ) console.print(test_panel)
# ============================================================================ # SECTION: Task Visualization (from task_visualization.py) # ============================================================================
[docs] def draw_task_pair_svg( input_grid: jnp.ndarray | np.ndarray | Grid, output_grid: jnp.ndarray | np.ndarray | Grid | None = None, input_mask: jnp.ndarray | np.ndarray | None = None, output_mask: jnp.ndarray | np.ndarray | None = None, width: float = 15.0, height: float = 8.0, label: str = "", show_unknown_output: bool = True, ) -> drawsvg.Drawing: """Draw an input-output task pair as SVG with strict height and flexible width. Args: input_grid: Input grid data output_grid: Output grid data (optional) input_mask: Optional mask for input grid output_mask: Optional mask for output grid width: Maximum width for the drawing (actual width may be less) height: Strict height constraint - all content must fit within this height label: Label for the pair show_unknown_output: Whether to show "?" for missing output Returns: SVG Drawing object """ padding = 0.5 extra_bottom_padding = 0.25 io_gap = 0.4 # Calculate available space for grids - height is STRICT ymax = (height - padding - extra_bottom_padding - io_gap) / 2 # Calculate aspect ratios to determine width requirements input_grid_data, input_mask_data = _extract_grid_data(input_grid) if input_mask is not None: input_mask_data = np.asarray(input_mask) _, _, (input_h, input_w) = _extract_valid_region(input_grid_data, input_mask_data) input_ratio = input_w / input_h if input_h > 0 else 1.0 max_ratio = input_ratio if output_grid is not None: output_grid_data, output_mask_data = _extract_grid_data(output_grid) if output_mask is not None: output_mask_data = np.asarray(output_mask) _, _, (output_h, output_w) = _extract_valid_region( output_grid_data, output_mask_data ) output_ratio = output_w / output_h if output_h > 0 else 1.0 max_ratio = max(input_ratio, output_ratio) # Calculate required width based on height constraint and aspect ratio required_width = ymax * max_ratio + padding * 2 final_width = max(required_width, padding * 2 + 1.0) # Minimum width # Don't exceed specified width constraint final_width = min(final_width, width) max_grid_width = final_width - padding * 2 # Draw elements following two-pass approach drawlist = [] x_ptr = 0.0 y_ptr = 0.0 # First pass: Draw input grid and determine dimensions input_result = draw_grid_svg( input_grid, input_mask, max_width=max_grid_width, max_height=ymax, label=f"{label} Input" if label else "Input", padding=padding, extra_bottom_padding=extra_bottom_padding, as_group=True, ) if isinstance(input_result, tuple): input_group, input_origin, input_size = input_result else: msg = "Expected tuple result when as_group=True" raise ValueError(msg) # Calculate output dimensions for spacing actual_output_width = 0.0 output_y_total_height = 0.0 output_g = None output_origin_out = (-padding / 2, -padding / 2) if output_grid is not None: output_result = draw_grid_svg( output_grid, output_mask, max_width=max_grid_width, max_height=ymax, label=f"{label} Output" if label else "Output", padding=padding, extra_bottom_padding=extra_bottom_padding, as_group=True, ) if isinstance(output_result, tuple): output_g, output_origin_out, output_size = output_result actual_output_width = output_size[0] output_y_total_height = output_size[1] else: msg = "Expected tuple result when as_group=True" raise ValueError(msg) else: # Approximate height for '?' slot output_y_total_height = ymax + padding + extra_bottom_padding # Position input grid drawlist.append( drawsvg.Use( input_group, x=(max_grid_width + padding - input_size[0]) / 2 - input_origin[0], y=-input_origin[1], ) ) x_ptr += max(input_size[0], actual_output_width) y_ptr = max(y_ptr, input_size[1]) # Second pass: Draw arrow and output arrow_x_center = input_size[0] / 2 arrow_top_y = y_ptr + padding - 0.6 arrow_bottom_y = y_ptr + padding + io_gap - 0.6 drawlist.append( drawsvg.Line( arrow_x_center, arrow_top_y, arrow_x_center, arrow_bottom_y, stroke_width=0.05, stroke="#888888", ) ) drawlist.append( drawsvg.Line( arrow_x_center - 0.15, arrow_bottom_y - 0.2, arrow_x_center, arrow_bottom_y, stroke_width=0.05, stroke="#888888", ) ) drawlist.append( drawsvg.Line( arrow_x_center + 0.15, arrow_bottom_y - 0.2, arrow_x_center, arrow_bottom_y, stroke_width=0.05, stroke="#888888", ) ) # Position output y_content_top_output_area = y_ptr + io_gap if output_g is not None: drawlist.append( drawsvg.Use( output_g, x=(max_grid_width + padding - actual_output_width) / 2 - output_origin_out[0], y=y_ptr - output_origin_out[1] + io_gap, ) ) elif show_unknown_output: # Draw question mark for unknown output q_text_y_center = ( y_content_top_output_area + (ymax / 2) + extra_bottom_padding / 2 ) drawlist.append( drawsvg.Text( "?", x=(max_grid_width + padding) / 2, y=q_text_y_center, font_size=1.0, font_family="Anuphan", font_weight="700", fill="#333333", text_anchor="middle", alignment_baseline="middle", ) ) y_ptr2 = y_ptr + io_gap + output_y_total_height # Calculate final drawing dimensions final_drawing_width = max(x_ptr, final_width) final_drawing_height = max(y_ptr2, height) # Height is strict # Create final drawing drawing = drawsvg.Drawing( final_drawing_width, final_drawing_height + 0.3, origin=(0, 0) ) drawing.append(drawsvg.Rectangle(0, 0, "100%", "100%", fill="#eeeff6")) # Add all draw elements for item in drawlist: drawing.append(item) # Embed font and set scale drawing.embed_google_font( "Anuphan:wght@400;600;700", text=set( "Input Output 0123456789x Test Task ABCDEFGHIJ? abcdefghjklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ" ), ) drawing.set_pixel_scale(40) return drawing
[docs] def draw_parsed_task_data_svg( task_data: JaxArcTask, width: float = 30.0, height: float = 20.0, include_test: bool | str = False, border_colors: list[str] | None = None, ) -> drawsvg.Drawing: """Draw a complete JaxArcTask as an SVG with strict height and flexible width. Args: task_data: The parsed task data to visualize width: Maximum width for the drawing (actual width may be less) height: Strict height constraint - all content must fit within this height include_test: Whether to include test examples. If 'all', show test outputs too. border_colors: Custom border colors [input_color, output_color] Returns: SVG Drawing object """ from jaxarc.utils.task_manager import extract_task_id_from_index if border_colors is None: border_colors = ["#111111ff", "#111111ff"] padding = 0.5 extra_bottom_padding = 0.25 io_gap = 0.4 # Calculate available space for grids - height is STRICT ymax = (height - padding - extra_bottom_padding - io_gap) / 2 # Prepare examples list examples = [] # Add training examples for i in range(task_data.num_train_pairs): examples.append( ( task_data.input_grids_examples[i], task_data.output_grids_examples[i], task_data.input_masks_examples[i], task_data.output_masks_examples[i], f"{i + 1}", False, # is_test ) ) # Add test examples if include_test: for i in range(task_data.num_test_pairs): show_test_output = include_test == "all" output_grid = ( task_data.true_test_output_grids[i] if show_test_output else None ) output_mask = ( task_data.true_test_output_masks[i] if show_test_output else None ) examples.append( ( task_data.test_input_grids[i], output_grid, task_data.test_input_masks[i], output_mask, f"{i + 1}", True, # is_test ) ) if not examples: # Empty task drawing = drawsvg.Drawing(width, height, origin=(0, 0)) drawing.append(drawsvg.Rectangle(0, 0, "100%", "100%", fill="#eeeff6")) drawing.append( drawsvg.Text( f"Task {extract_task_id_from_index(task_data.task_index)} (No examples)", x=width / 2, y=height / 2, font_size=0.5, text_anchor="middle", fill="black", ) ) drawing.set_pixel_scale(40) return drawing # Prepare training examples train_examples = [] for i in range(task_data.num_train_pairs): train_examples.append( ( task_data.input_grids_examples[i], task_data.output_grids_examples[i], task_data.input_masks_examples[i], task_data.output_masks_examples[i], f"{i + 1}", False, # is_test ) ) # Prepare test examples test_examples = [] if include_test: for i in range(task_data.num_test_pairs): show_test_output = include_test == "all" output_grid = ( task_data.true_test_output_grids[i] if show_test_output else None ) output_mask = ( task_data.true_test_output_masks[i] if show_test_output else None ) test_examples.append( ( task_data.test_input_grids[i], output_grid, task_data.test_input_masks[i], output_mask, f"{i + 1}", True, # is_test ) ) # Combine all examples examples = train_examples + test_examples # Calculate ideal width for each example based on aspect ratio and height constraint max_widths = np.zeros(len(examples)) for i, ( input_grid, output_grid, input_mask, output_mask, _label, _is_test, ) in enumerate(examples): input_grid_data, _ = _extract_grid_data(input_grid) input_mask_data = np.asarray(input_mask) if input_mask is not None else None _, _, (input_h, input_w) = _extract_valid_region( input_grid_data, input_mask_data ) input_ratio = input_w / input_h if input_h > 0 else 1.0 max_ratio = input_ratio if output_grid is not None: output_grid_data, _ = _extract_grid_data(output_grid) output_mask_data = ( np.asarray(output_mask) if output_mask is not None else None ) _, _, (output_h, output_w) = _extract_valid_region( output_grid_data, output_mask_data ) output_ratio = output_w / output_h if output_h > 0 else 1.0 max_ratio = max(input_ratio, output_ratio) # Calculate ideal width based on height constraint and aspect ratio xmax_for_pair = ymax * max_ratio max_widths[i] = xmax_for_pair # Add extra spacing between training and test groups group_spacing = 0.5 if len(train_examples) > 0 and len(test_examples) > 0 else 0.0 # Proportional allocation algorithm - distribute width based on needs paddingless_width = width - padding * len(examples) - group_spacing allocation = np.zeros_like(max_widths) increment = 0.01 if paddingless_width > 0 and len(examples) > 0: if np.any(max_widths > 0): for _ in range(int(paddingless_width // increment)): incr_mask = (allocation + increment) <= max_widths if incr_mask.sum() > 0: allocation[incr_mask] += increment / incr_mask.sum() else: break # Fallback: equal distribution if no progress made if np.sum(allocation) == 0: allocation[:] = paddingless_width / len(examples) # Two-pass rendering following reference implementation pattern drawlist = [] # Account for squircle margins in positioning if we have grouping squircle_margin = 0.15 has_grouping = len(train_examples) > 0 and len(test_examples) > 0 x_offset = squircle_margin if has_grouping else 0.0 y_offset = squircle_margin if has_grouping else 0.0 # Calculate group boundaries train_width = ( sum(allocation[: len(train_examples)]) + padding * len(train_examples) if train_examples else 0 ) test_start_x = x_offset + train_width + (group_spacing if has_grouping else 0) x_ptr = x_offset y_ptr = y_offset # First pass: Draw input grids and calculate input row height for i, ( input_grid, output_grid, input_mask, output_mask, label, is_test, ) in enumerate(examples): input_result = draw_grid_svg( input_grid, input_mask, max_width=allocation[i], max_height=ymax, label=f"In #{label}", border_color=border_colors[0], padding=padding, extra_bottom_padding=extra_bottom_padding, as_group=True, ) if isinstance(input_result, tuple): input_group, input_origin, input_size = input_result else: msg = "Expected tuple result when as_group=True" raise ValueError(msg) # Calculate actual output width for spacing actual_output_width = 0.0 if output_grid is not None: output_result_for_spacing = draw_grid_svg( output_grid, output_mask, max_width=allocation[i], max_height=ymax, label=f"Out #{label}", border_color=border_colors[1], padding=padding, extra_bottom_padding=extra_bottom_padding, as_group=True, ) if isinstance(output_result_for_spacing, tuple): _, _, (actual_output_width, _) = output_result_for_spacing # Determine x position based on whether this is a test example if is_test and has_grouping: # For test examples, position relative to test start test_index = i - len(train_examples) test_x_offset = ( sum(allocation[len(train_examples) : len(train_examples) + test_index]) + padding * test_index ) current_x_ptr = test_start_x + test_x_offset else: # For training examples, use current x_ptr current_x_ptr = x_ptr # Position input grid drawlist.append( drawsvg.Use( input_group, x=current_x_ptr + (allocation[i] + padding - input_size[0]) / 2 - input_origin[0], y=y_offset - input_origin[1], ) ) # Only advance x_ptr for training examples or when not grouping if not is_test or not has_grouping: x_ptr += max(input_size[0], actual_output_width) y_ptr = max(y_ptr, input_size[1]) # Second pass: Draw arrows and outputs y_ptr2 = y_offset for i, ( input_grid, output_grid, input_mask, output_mask, label, is_test, ) in enumerate(examples): # Recalculate input for positioning input_result = draw_grid_svg( input_grid, input_mask, max_width=allocation[i], max_height=ymax, label=f"In #{label}", border_color=border_colors[0], padding=padding, extra_bottom_padding=extra_bottom_padding, as_group=True, ) if isinstance(input_result, tuple): input_group, input_origin, input_size = input_result else: msg = "Expected tuple result when as_group=True" raise ValueError(msg) output_g = None output_x_recalc = 0.0 output_y_total_height = 0.0 output_origin_recalc = (-padding / 2, -padding / 2) show_output = (not is_test) or (include_test == "all") if show_output and output_grid is not None: output_result = draw_grid_svg( output_grid, output_mask, max_width=allocation[i], max_height=ymax, label=f"Out #{label}", border_color=border_colors[1], padding=padding, extra_bottom_padding=extra_bottom_padding, as_group=True, ) if isinstance(output_result, tuple): output_g, output_origin_recalc, output_size = output_result output_x_recalc = output_size[0] output_y_total_height = output_size[1] else: msg = "Expected tuple result when as_group=True" raise ValueError(msg) else: # Approximate height for '?' slot output_y_total_height = ymax + padding + extra_bottom_padding # Determine x position based on whether this is a test example if is_test and has_grouping: # For test examples, position relative to test start test_index = i - len(train_examples) test_x_offset = ( sum(allocation[len(train_examples) : len(train_examples) + test_index]) + padding * test_index ) current_x_ptr = test_start_x + test_x_offset else: # For training examples, calculate position from start train_x_offset = sum(allocation[:i]) + padding * i current_x_ptr = x_offset + train_x_offset # Draw arrow arrow_x_center = current_x_ptr + input_size[0] / 2 arrow_top_y = y_ptr + padding - 0.6 arrow_bottom_y = y_ptr + padding + io_gap - 0.6 drawlist.append( drawsvg.Line( arrow_x_center, arrow_top_y, arrow_x_center, arrow_bottom_y, stroke_width=0.05, stroke="#888888", ) ) drawlist.append( drawsvg.Line( arrow_x_center - 0.15, arrow_bottom_y - 0.2, arrow_x_center, arrow_bottom_y, stroke_width=0.05, stroke="#888888", ) ) drawlist.append( drawsvg.Line( arrow_x_center + 0.15, arrow_bottom_y - 0.2, arrow_x_center, arrow_bottom_y, stroke_width=0.05, stroke="#888888", ) ) # Position output y_content_top_output_area = y_ptr + io_gap if show_output and output_g is not None: drawlist.append( drawsvg.Use( output_g, x=current_x_ptr + (allocation[i] + padding - output_x_recalc) / 2 - output_origin_recalc[0], y=y_ptr - output_origin_recalc[1] + io_gap, ) ) else: # Draw question mark q_text_y_center = ( y_content_top_output_area + (ymax / 2) + extra_bottom_padding / 2 ) drawlist.append( drawsvg.Text( "?", x=current_x_ptr + (allocation[i] + padding) / 2, y=q_text_y_center, font_size=1.0, font_family="Anuphan", font_weight="700", fill="#333333", text_anchor="middle", alignment_baseline="middle", ) ) y_ptr2 = max(y_ptr2, y_ptr + io_gap + output_y_total_height) # Calculate final drawing dimensions accounting for squircle margins if has_grouping: test_width = ( sum(allocation[len(train_examples) :]) + padding * len(test_examples) if test_examples else 0 ) final_drawing_width = round( x_offset + train_width + group_spacing + test_width + squircle_margin, 1 ) else: final_drawing_width = round(x_ptr, 1) final_drawing_height = round(y_ptr2 + (squircle_margin if has_grouping else 0), 1) # Ensure dimensions are not negative or too small final_drawing_width = max(final_drawing_width, 1.0) final_drawing_height = max(final_drawing_height, height) # Height is strict # Create final drawing with calculated dimensions drawing = drawsvg.Drawing( final_drawing_width, final_drawing_height + 0.3, origin=(0, 0) ) drawing.append(drawsvg.Rectangle(0, 0, "100%", "100%", fill="#eeeff6")) # Add all draw elements for item in drawlist: drawing.append(item) # Add grouping squircles if we have both training and test examples if len(train_examples) > 0 and len(test_examples) > 0: # Calculate training group bounds train_width = sum(allocation[: len(train_examples)]) + padding * len( train_examples ) # Training group squircle train_squircle_elements = _draw_dotted_squircle( x=0, y=0, width=train_width + squircle_margin * 2, height=y_ptr2 - y_offset + squircle_margin, label="Train", stroke_color="#4A90E2", ) for element in train_squircle_elements: drawing.append(element) # Test group squircle test_start_x = train_width + group_spacing + squircle_margin test_width = sum(allocation[len(train_examples) :]) + padding * len( test_examples ) test_squircle_elements = _draw_dotted_squircle( x=test_start_x, y=0, width=test_width + squircle_margin, height=y_ptr2 - y_offset + squircle_margin, label="Test", stroke_color="#E94B3C", ) for element in test_squircle_elements: drawing.append(element) # Add title font_size = 0.3 title_text = f"Task: {extract_task_id_from_index(task_data.task_index)}" drawing.append( drawsvg.Text( title_text, x=final_drawing_width - 0.1, y=final_drawing_height + 0.2, font_size=font_size, font_family="Anuphan", font_weight="600", fill="#666666", text_anchor="end", alignment_baseline="bottom", ) ) # Embed font and set scale drawing.embed_google_font( "Anuphan:wght@400;600;700", text=set( "Input Output 0123456789x Test Task ABCDEFGHIJ? abcdefghjklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ" ), ) drawing.set_pixel_scale(40) return drawing
# ============================================================================ # SECTION: Episode Visualization (from episode_visualization.py) # ============================================================================ def draw_enhanced_episode_summary_svg( summary_data: Any, step_data: List[Any], config: Optional[Any] = None, width: float = 1400.0, height: float = 1000.0, ) -> str: """Generate enhanced SVG visualization of episode summary with comprehensive metrics. This enhanced version includes: - Reward progression chart with key moments highlighted - Similarity progression chart - Grid state thumbnails at key moments - Performance metrics panel - Success/failure analysis Args: summary_data: Episode summary data step_data: List of step visualization data config: Optional visualization configuration width: Width of the visualization height: Height of the visualization Returns: SVG string containing the enhanced episode summary """ import drawsvg as draw # Create main drawing drawing = draw.Drawing(width, height) drawing.append(draw.Rectangle(0, 0, width, height, fill="#f8f9fa")) # Layout parameters padding = 40 title_height = 100 metrics_height = 80 chart_height = 250 thumbnails_height = 200 remaining_height = ( height - title_height - metrics_height - 2 * chart_height - thumbnails_height - 6 * padding ) # Add enhanced title section title_bg_height = title_height - 20 drawing.append( draw.Rectangle( padding, padding, width - 2 * padding, title_bg_height, fill="#ffffff", stroke="#dee2e6", stroke_width=1, rx=8, ) ) # Main title title_text = f"Episode {summary_data.episode_num} Summary" drawing.append( draw.Text( title_text, font_size=32, x=width / 2, y=padding + 40, text_anchor="middle", font_family="Anuphan", font_weight="600", fill="#2c3e50", ) ) # Success indicator success_color = "#27ae60" if summary_data.success else "#e74c3c" success_text = "SUCCESS" if summary_data.success else "FAILED" drawing.append( draw.Text( success_text, font_size=18, x=width - padding - 20, y=padding + 30, text_anchor="end", font_family="Anuphan", font_weight="700", fill=success_color, ) ) # Task ID drawing.append( draw.Text( f"Task: {summary_data.task_id}", font_size=16, x=padding + 20, y=padding + 70, text_anchor="start", font_family="Anuphan", font_weight="400", fill="#6c757d", ) ) # Metrics panel metrics_y = title_height + 2 * padding drawing.append( draw.Rectangle( padding, metrics_y, width - 2 * padding, metrics_height, fill="#ffffff", stroke="#dee2e6", stroke_width=1, rx=8, ) ) # Metrics grid metrics = [ ("Total Steps", summary_data.total_steps, ""), ("Total Reward", summary_data.total_reward, ".3f"), ("Final Similarity", summary_data.final_similarity, ".3f"), ( "Avg Reward/Step", summary_data.total_reward / max(summary_data.total_steps, 1), ".3f", ), ] metric_width = (width - 2 * padding - 60) / len(metrics) for i, (name, value, fmt) in enumerate(metrics): x_pos = padding + 20 + i * metric_width # Metric name drawing.append( draw.Text( name, font_size=14, x=x_pos, y=metrics_y + 25, text_anchor="start", font_family="Anuphan", font_weight="600", fill="#495057", ) ) # Metric value value_text = f"{value:{fmt}}" if fmt else str(value) drawing.append( draw.Text( value_text, font_size=18, x=x_pos, y=metrics_y + 50, text_anchor="start", font_family="Anuphan", font_weight="500", fill="#2c3e50", ) ) # Reward progression chart chart1_y = metrics_y + metrics_height + padding chart_width = (width - 3 * padding) / 2 drawing.append( draw.Rectangle( padding, chart1_y, chart_width, chart_height, fill="white", stroke="#dee2e6", stroke_width=1, rx=8, ) ) # Chart title drawing.append( draw.Text( "Reward Progression", font_size=18, x=padding + 20, y=chart1_y + 30, text_anchor="start", font_family="Anuphan", font_weight="600", fill="#2c3e50", ) ) # Draw reward progression line if summary_data.reward_progression and len(summary_data.reward_progression) > 1: rewards = summary_data.reward_progression chart_inner_width = chart_width - 40 chart_inner_height = chart_height - 80 max_reward = max(rewards) if max(rewards) > 0 else 1 min_reward = min(rewards) if min(rewards) < 0 else 0 reward_range = max_reward - min_reward if max_reward != min_reward else 1 # Draw grid lines for i in range(5): y_grid = chart1_y + 50 + i * (chart_inner_height / 4) drawing.append( draw.Line( padding + 20, y_grid, padding + 20 + chart_inner_width, y_grid, stroke="#e9ecef", stroke_width=1, ) ) # Draw reward line points = [] for i, reward in enumerate(rewards): x = padding + 20 + (i / (len(rewards) - 1)) * chart_inner_width y = ( chart1_y + 50 + chart_inner_height - ((reward - min_reward) / reward_range) * chart_inner_height ) points.append((x, y)) if len(points) > 1: path_data = f"M {points[0][0]} {points[0][1]}" for x, y in points[1:]: path_data += f" L {x} {y}" drawing.append( draw.Path( d=path_data, stroke="#3498db", stroke_width=3, fill="none", ) ) # Add points for i, (x, y) in enumerate(points): # Highlight key moments if i in summary_data.key_moments: drawing.append( draw.Circle( x, y, 6, fill="#e74c3c", stroke="white", stroke_width=2, ) ) else: drawing.append( draw.Circle( x, y, 4, fill="#3498db", stroke="white", stroke_width=1, ) ) # Similarity progression chart chart2_x = padding + chart_width + padding drawing.append( draw.Rectangle( chart2_x, chart1_y, chart_width, chart_height, fill="white", stroke="#dee2e6", stroke_width=1, rx=8, ) ) # Chart title drawing.append( draw.Text( "Similarity Progression", font_size=18, x=chart2_x + 20, y=chart1_y + 30, text_anchor="start", font_family="Anuphan", font_weight="600", fill="#2c3e50", ) ) # Draw similarity progression line if ( summary_data.similarity_progression and len(summary_data.similarity_progression) > 1 ): similarities = summary_data.similarity_progression chart_inner_width = chart_width - 40 chart_inner_height = chart_height - 80 # Draw grid lines for i in range(5): y_grid = chart1_y + 50 + i * (chart_inner_height / 4) drawing.append( draw.Line( chart2_x + 20, y_grid, chart2_x + 20 + chart_inner_width, y_grid, stroke="#e9ecef", stroke_width=1, ) ) # Draw similarity line points = [] for i, similarity in enumerate(similarities): x = chart2_x + 20 + (i / (len(similarities) - 1)) * chart_inner_width y = chart1_y + 50 + chart_inner_height - (similarity * chart_inner_height) points.append((x, y)) if len(points) > 1: path_data = f"M {points[0][0]} {points[0][1]}" for x, y in points[1:]: path_data += f" L {x} {y}" drawing.append( draw.Path( d=path_data, stroke="#27ae60", stroke_width=3, fill="none", ) ) # Add points for i, (x, y) in enumerate(points): # Highlight key moments if i in summary_data.key_moments: drawing.append( draw.Circle( x, y, 6, fill="#e74c3c", stroke="white", stroke_width=2, ) ) else: drawing.append( draw.Circle( x, y, 4, fill="#27ae60", stroke="white", stroke_width=1, ) ) # Key moments thumbnails section thumbnails_y = chart1_y + chart_height + padding if summary_data.key_moments and step_data: drawing.append( draw.Rectangle( padding, thumbnails_y, width - 2 * padding, thumbnails_height, fill="white", stroke="#dee2e6", stroke_width=1, rx=8, ) ) # Section title drawing.append( draw.Text( "Key Moments", font_size=18, x=padding + 20, y=thumbnails_y + 30, text_anchor="start", font_family="Anuphan", font_weight="600", fill="#2c3e50", ) ) # Draw thumbnails for key moments thumbnail_size = 120 thumbnail_spacing = 20 thumbnails_per_row = min(len(summary_data.key_moments), 8) for i, step_idx in enumerate(summary_data.key_moments[:thumbnails_per_row]): if step_idx < len(step_data): step = step_data[step_idx] thumb_x = padding + 20 + i * (thumbnail_size + thumbnail_spacing) thumb_y = thumbnails_y + 50 # Draw thumbnail background drawing.append( draw.Rectangle( thumb_x, thumb_y, thumbnail_size, thumbnail_size, fill="#f8f9fa", stroke="#dee2e6", stroke_width=1, rx=4, ) ) # Draw simplified grid representation if hasattr(step, "after_grid"): grid_data = np.asarray(step.after_grid.data) grid_size = min(thumbnail_size - 20, 80) cell_size = grid_size / max(grid_data.shape) for row in range(min(grid_data.shape[0], 8)): for col in range(min(grid_data.shape[1], 8)): color_val = int(grid_data[row, col]) if config and hasattr(config, "get_color_palette"): color_palette = config.get_color_palette() else: color_palette = ARC_COLOR_PALETTE fill_color = color_palette.get(color_val, "#CCCCCC") drawing.append( draw.Rectangle( thumb_x + 10 + col * cell_size, thumb_y + 10 + row * cell_size, cell_size, cell_size, fill=fill_color, stroke="#6c757d", stroke_width=0.5, ) ) # Add step label drawing.append( draw.Text( f"Step {step_idx}", font_size=12, x=thumb_x + thumbnail_size / 2, y=thumb_y + thumbnail_size + 15, text_anchor="middle", font_family="Anuphan", font_weight="500", fill="#495057", ) ) # Add footer with timing information footer_y = height - 40 if hasattr(summary_data, "start_time") and hasattr(summary_data, "end_time"): duration = summary_data.end_time - summary_data.start_time footer_text = f"Duration: {duration:.1f}s | Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}" else: footer_text = f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}" drawing.append( draw.Text( footer_text, font_size=12, x=width / 2, y=footer_y, text_anchor="middle", font_family="Anuphan", font_weight="300", fill="#adb5bd", ) ) return drawing.as_svg()
[docs] def draw_episode_summary_svg( summary_data: Any, step_data: List[Any], config: Optional[Any] = None, width: float = 1400.0, height: float = 1000.0, ) -> str: """Generate episode summary visualization (enhanced version).""" return draw_enhanced_episode_summary_svg( summary_data=summary_data, step_data=step_data, config=config, width=width, height=height, )
[docs] def create_episode_comparison_visualization( episodes_data: List[Any], comparison_type: str = "reward_progression", width: float = 1200.0, height: float = 600.0, ) -> str: """Create comparison visualization across multiple episodes. Args: episodes_data: List of episode summary data comparison_type: Type of comparison ("reward_progression", "similarity", "performance") width: Width of the visualization height: Height of the visualization Returns: SVG string containing the comparison visualization """ import drawsvg as draw # Create main drawing drawing = draw.Drawing(width, height) drawing.append(draw.Rectangle(0, 0, width, height, fill="#f8f9fa")) # Layout parameters padding = 40 title_height = 80 chart_height = height - title_height - 2 * padding - 60 # Add title title_text = f"Episode Comparison - {comparison_type.replace('_', ' ').title()}" drawing.append( draw.Text( title_text, font_size=28, x=width / 2, y=50, text_anchor="middle", font_family="Anuphan", font_weight="600", fill="#2c3e50", ) ) # Chart area chart_y = title_height + padding chart_width = width - 2 * padding drawing.append( draw.Rectangle( padding, chart_y, chart_width, chart_height, fill="white", stroke="#dee2e6", stroke_width=1, rx=8, ) ) # Colors for different episodes episode_colors = ["#3498db", "#e74c3c", "#27ae60", "#f39c12", "#9b59b6", "#1abc9c"] if comparison_type == "reward_progression": # Draw reward progression for each episode chart_inner_width = chart_width - 60 chart_inner_height = chart_height - 60 # Find global min/max for scaling all_rewards = [] for episode in episodes_data: if hasattr(episode, "reward_progression") and episode.reward_progression: all_rewards.extend(episode.reward_progression) if all_rewards: max_reward = max(all_rewards) min_reward = min(all_rewards) reward_range = max_reward - min_reward if max_reward != min_reward else 1 # Draw grid lines for i in range(5): y_grid = chart_y + 30 + i * (chart_inner_height / 4) drawing.append( draw.Line( padding + 30, y_grid, padding + 30 + chart_inner_width, y_grid, stroke="#e9ecef", stroke_width=1, ) ) # Draw each episode's progression for ep_idx, episode in enumerate(episodes_data[: len(episode_colors)]): if ( hasattr(episode, "reward_progression") and episode.reward_progression ): rewards = episode.reward_progression color = episode_colors[ep_idx] points = [] for i, reward in enumerate(rewards): x = padding + 30 + (i / (len(rewards) - 1)) * chart_inner_width y = ( chart_y + 30 + chart_inner_height - ((reward - min_reward) / reward_range) * chart_inner_height ) points.append((x, y)) if len(points) > 1: path_data = f"M {points[0][0]} {points[0][1]}" for x, y in points[1:]: path_data += f" L {x} {y}" drawing.append( draw.Path( d=path_data, stroke=color, stroke_width=2, fill="none", ) ) # Add points for x, y in points: drawing.append( draw.Circle( x, y, 3, fill=color, stroke="white", stroke_width=1, ) ) # Add legend entry legend_y = chart_y + chart_height + 20 + ep_idx * 20 drawing.append( draw.Line( padding + 20, legend_y, padding + 40, legend_y, stroke=color, stroke_width=3, ) ) drawing.append( draw.Text( f"Episode {episode.episode_num}", font_size=14, x=padding + 50, y=legend_y + 5, text_anchor="start", font_family="Anuphan", font_weight="400", fill="#495057", ) ) elif comparison_type == "performance": # Create bar chart comparing final performance metrics metrics = ["total_reward", "final_similarity", "total_steps"] metric_labels = ["Total Reward", "Final Similarity", "Steps"] chart_inner_width = chart_width - 60 chart_inner_height = chart_height - 60 bar_width = (chart_width - 100) / ( len(episodes_data) * len(metrics) + len(metrics) ) group_spacing = bar_width * 0.5 for metric_idx, (metric, label) in enumerate(zip(metrics, metric_labels)): # Get values for this metric values = [] for episode in episodes_data: if hasattr(episode, metric): values.append(getattr(episode, metric)) else: values.append(0) if values: max_val = max(values) if max(values) > 0 else 1 # Draw bars for this metric for ep_idx, (episode, value) in enumerate(zip(episodes_data, values)): x = ( padding + 30 + metric_idx * (len(episodes_data) * bar_width + group_spacing) + ep_idx * bar_width ) bar_height = (value / max_val) * (chart_inner_height - 40) y = chart_y + chart_height - 30 - bar_height color = episode_colors[ep_idx % len(episode_colors)] drawing.append( draw.Rectangle( x, y, bar_width * 0.8, bar_height, fill=color, stroke="white", stroke_width=1, ) ) # Add metric label label_x = ( padding + 30 + metric_idx * (len(episodes_data) * bar_width + group_spacing) + (len(episodes_data) * bar_width) / 2 ) drawing.append( draw.Text( label, font_size=12, x=label_x, y=chart_y + chart_height - 10, text_anchor="middle", font_family="Anuphan", font_weight="500", fill="#495057", ) ) return drawing.as_svg()
def display_grid( grid: GridArray | np.ndarray | Grid, title: str = "Grid", show_mask: bool = True ) -> None: """Display a single grid using Rich.""" console = Console() # Note: visualize_grid_rich currently uses the mask embedded in grid if present. # show_mask parameter is kept for API compatibility but currently ignored # until visualize_grid_rich supports forcing mask visibility. console.print(visualize_grid_rich(grid, title=title)) def render_ansi(grid_input: GridArray | np.ndarray | Grid) -> str: """Render grid as ANSI string. Args: grid_input: Grid data (JAX array, numpy array, or Grid object) Returns: A string containing the ANSI representation of the grid. """ # Create a temporary console for capturing output temp_console = Console(force_terminal=True, color_system="truecolor", width=1000) table = visualize_grid_rich(grid_input) with temp_console.capture() as capture: temp_console.print(table) return capture.get() def display_step( step_data: dict[str, Any], step_idx: int, console: Console | None = None, show_coordinates: bool = False, show_numbers: bool = False, double_width: bool = True, ) -> None: """Display a single step of the episode using Rich. Args: step_data: Step data dictionary step_idx: Index of the step console: Optional Rich console (creates one if None) show_coordinates: Whether to show grid coordinates show_numbers: If True, show colored numbers; if False, show colored blocks double_width: If True and show_numbers=False, use double-width blocks for square appearance """ if console is None: console = Console() # Extract relevant data input_grid = step_data["input_grid"] output_grid = step_data["output_grid"] input_mask = step_data["input_mask"] output_mask = step_data["output_mask"] # Create title title = f"Step {step_idx + 1}" # Display input-output pair visualize_task_pair_rich( input_grid, output_grid, input_mask, output_mask, title=title, show_numbers=show_numbers, double_width=double_width, console=console, )