# Source code for jaxarc.utils.visualization.core

"""
Visualization core utilities for JaxARC.

Contains constants, utilities, and SVG rendering engine.
"""

from __future__ import annotations

import io
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import drawsvg  # type: ignore[import-untyped]
import jax.numpy as jnp
import numpy as np
from loguru import logger

from jaxarc.utils.serialization_utils import serialize_jax_array

if TYPE_CHECKING:
    from jaxarc.types import Grid, GridArray


# ============================================================================
# SECTION: Constants and Color Definitions (from constants.py)
# ============================================================================

# ARC color palette - matches the provided color map.
# Maps an integer cell value to the hex color string used to draw that cell.
# Keys 0-9 are the named colors listed below; key 10 is white and is reserved
# for padding/invalid cells.
ARC_COLOR_PALETTE: dict[int, str] = {
    0: "#252525",  # 0: black
    1: "#0074D9",  # 1: blue
    2: "#FF4136",  # 2: red
    3: "#37D449",  # 3: green
    4: "#FFDC00",  # 4: yellow
    5: "#E6E6E6",  # 5: grey
    6: "#F012BE",  # 6: pink
    7: "#FF871E",  # 7: orange
    8: "#54D2EB",  # 8: light blue
    9: "#8D1D2C",  # 9: brown
    10: "#FFFFFF",  # 10: white (for padding/invalid)
}


# ============================================================================
# SECTION: Utility Functions (from utils.py)
# ============================================================================


def _extract_grid_data(
    grid_input: GridArray | np.ndarray | Grid,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Normalize a grid-like input into a ``(data, mask)`` pair.

    Args:
        grid_input: Either an object exposing ``data`` and ``mask``
            attributes (a Grid), or a bare JAX / NumPy array.

    Returns:
        Tuple of ``(data, mask)``. Both entries are passed through
        ``serialize_jax_array``; ``mask`` is ``None`` when a bare array
        was supplied.

    Raises:
        ValueError: If ``grid_input`` is neither Grid-like nor an array.
    """
    # Grid objects are recognized structurally (duck typing) rather than
    # with isinstance, so any object carrying both attributes is accepted.
    looks_like_grid = hasattr(grid_input, "data") and hasattr(grid_input, "mask")
    if looks_like_grid:
        data = serialize_jax_array(grid_input.data)
        mask = serialize_jax_array(grid_input.mask)
        return data, mask

    if not isinstance(grid_input, (jnp.ndarray, np.ndarray)):
        msg = f"Unsupported grid input type: {type(grid_input)}"
        raise ValueError(msg)

    # Bare arrays carry no validity information.
    return serialize_jax_array(grid_input), None


def _extract_valid_region(
    grid: np.ndarray, mask: np.ndarray | None = None
) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
    """Extract the valid (non-padded) region from a grid.

    The valid region is the bounding box of all truthy cells in ``mask``.

    Args:
        grid: The grid array.
        mask: Optional boolean mask indicating valid cells. When omitted,
            the whole grid is treated as valid.

    Returns:
        Tuple of ``(valid_grid, (start_row, start_col), (height, width))``.
        The offsets and sizes are plain Python ``int`` values. When the mask
        selects nothing, an empty ``(1, 0)`` array with zero offsets and
        zero size is returned.
    """
    if mask is None:
        # Assume all cells are valid if no mask is provided.
        return grid, (0, 0), (grid.shape[0], grid.shape[1])

    if not np.any(mask):
        # No valid cells at all.
        return np.array([[]], dtype=grid.dtype), (0, 0), (0, 0)

    # At least one cell is set, so both index arrays below are non-empty.
    valid_rows = np.nonzero(np.any(mask, axis=1))[0]
    valid_cols = np.nonzero(np.any(mask, axis=0))[0]

    # Convert NumPy integer scalars to Python ints so the returned tuples
    # match the annotated contract (and stay JSON/format friendly).
    start_row, end_row = int(valid_rows[0]), int(valid_rows[-1]) + 1
    start_col, end_col = int(valid_cols[0]), int(valid_cols[-1]) + 1

    valid_grid = grid[start_row:end_row, start_col:end_col]

    return (
        valid_grid,
        (start_row, start_col),
        (end_row - start_row, end_col - start_col),
    )


def get_color_name(color_id: int) -> str:
    """Return a display label for an ARC color ID.

    Args:
        color_id: Integer color ID.

    Returns:
        ``"<Name> (<id>)"`` for the ten named colors (IDs 0-9), otherwise
        the generic label ``"Color <id>"``.
    """
    # Position in this tuple is the color ID.
    base_names = (
        "Black",
        "Blue",
        "Red",
        "Green",
        "Yellow",
        "Grey",
        "Pink",
        "Orange",
        "Light Blue",
        "Brown",
    )
    labelled = {idx: f"{name} ({idx})" for idx, name in enumerate(base_names)}

    if color_id in labelled:
        return labelled[color_id]
    return f"Color {color_id}"


def detect_changed_cells(
    before_grid: Grid,
    after_grid: Grid,
) -> jnp.ndarray:
    """Compute a mask of the cells that differ between two grids.

    Grids of different shapes are compared on a common canvas: each one is
    copied into the top-left corner of a zero-filled array large enough to
    hold both, so cells present in only one grid are compared against 0.

    Args:
        before_grid: Grid state before the action.
        after_grid: Grid state after the action.

    Returns:
        Boolean mask (as a JAX array) marking the cells that changed.
    """
    prev = serialize_jax_array(before_grid.data)
    curr = serialize_jax_array(after_grid.data)

    rows = max(prev.shape[0], curr.shape[0])
    cols = max(prev.shape[1], curr.shape[1])

    def _zero_padded(cells):
        # Place the grid in the top-left corner of a (rows, cols) canvas.
        canvas = np.zeros((rows, cols), dtype=cells.dtype)
        canvas[: cells.shape[0], : cells.shape[1]] = cells
        return canvas

    return jnp.array(_zero_padded(prev) != _zero_padded(curr))


def infer_fill_color_from_grids(
    before_grid: Grid, after_grid: Grid, selection_mask: np.ndarray
) -> int:
    """Guess the fill color of an action by comparing grid states.

    Scans the region shared by both grids in row-major order and reports the
    new value of the first cell that is both selected and changed.

    Args:
        before_grid: Grid state before the action.
        after_grid: Grid state after the action.
        selection_mask: Boolean mask of selected cells.

    Returns:
        Color ID found in the first selected-and-changed cell, or -1 when no
        such cell exists or any error occurs during the comparison.
    """
    try:
        prev = serialize_jax_array(before_grid.data)
        curr = serialize_jax_array(after_grid.data)

        shared_rows = min(prev.shape[0], curr.shape[0])
        shared_cols = min(prev.shape[1], curr.shape[1])

        for row in range(shared_rows):
            if row >= selection_mask.shape[0]:
                continue
            for col in range(shared_cols):
                if col >= selection_mask.shape[1]:
                    continue
                if not selection_mask[row, col]:
                    continue
                if prev[row, col] != curr[row, col]:
                    # Selected and changed: its new value is the fill color.
                    return int(curr[row, col])
    except Exception:
        # Best-effort helper: any failure means "unknown".
        return -1

    return -1  # Couldn't determine


def get_info_metric(info: dict, key: str, default=None):
    """Extract a metric value from an info dict, supporting old and new layouts.

    Metrics may be stored either nested under ``info["metrics"]`` (new
    layout) or directly in ``info`` (old layout). The nested location wins
    when the key is present in both.

    Args:
        info: Info dictionary from an environment step.
        key: Metric key to extract.
        default: Value returned when the metric is not found.

    Returns:
        The metric value. Values exposing an ``item`` attribute (NumPy/JAX
        scalars and arrays) are converted with ``float``; if that conversion
        raises ``TypeError`` (e.g. a multi-element array) the value is
        returned unchanged. Any other value is returned as-is.
    """

    def _to_python(val):
        # Only array-like scalars need unwrapping; plain Python values pass
        # through untouched.
        if not hasattr(val, "item"):
            return val
        try:
            return float(val)
        except TypeError:
            # Non-scalar arrays cannot be collapsed to one float; hand them
            # back intact instead of crashing the caller.
            return val

    # First check info["metrics"] (higher priority - new format).
    if "metrics" in info and key in info["metrics"]:
        return _to_python(info["metrics"][key])

    # Then check info itself (lower priority - old format).
    if key in info:
        return _to_python(info[key])

    return default


def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """Convert a hex color string to an ``(r, g, b)`` tuple.

    Accepts ``#RRGGBB`` and ``#RRGGBBAA`` as well as the shorthand forms
    ``#RGB`` and ``#RGBA`` (each digit is doubled). Any alpha component is
    ignored. The leading ``#`` is optional.

    Args:
        hex_color: Hex color string.

    Returns:
        Tuple of red, green and blue components, each in ``0..255``.

    Raises:
        ValueError: If the string does not contain valid hex digits.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # Shorthand notation: "f80" means "ff8800".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) == 8:  # Handle RGBA (ignore alpha for now)
        digits = digits[:6]
    r = int(digits[0:2], 16)
    g = int(digits[2:4], 16)
    b = int(digits[4:6], 16)
    return (r, g, b)


def render_rgb(
    grid_input: GridArray | np.ndarray | Grid, cell_size: int = 20
) -> np.ndarray:
    """Rasterize a grid into an RGB image array.

    Every grid cell becomes a ``cell_size`` x ``cell_size`` square filled
    with its palette color. When ``cell_size`` exceeds 2, the top row and
    left column of each square are painted dark to act as grid lines.

    Args:
        grid_input: Grid data (JAX array, numpy array, or Grid object).
        cell_size: The size of each grid cell in pixels.

    Returns:
        A ``uint8`` numpy array of shape
        ``(height * cell_size, width * cell_size, 3)``.
    """
    cells, _ = _extract_grid_data(grid_input)
    n_rows, n_cols = cells.shape

    canvas = np.zeros((n_rows * cell_size, n_cols * cell_size, 3), dtype=np.uint8)

    # Resolve the palette to RGB triples once, up front.
    palette_rgb = {cid: hex_to_rgb(code) for cid, code in ARC_COLOR_PALETTE.items()}
    fallback_rgb = (128, 128, 128)  # gray for colors missing from the palette
    edge_rgb = (50, 50, 50)
    draw_edges = cell_size > 2

    for row, col in np.ndindex(n_rows, n_cols):
        top, left = row * cell_size, col * cell_size
        bottom, right = top + cell_size, left + cell_size

        # Fill the cell.
        canvas[top:bottom, left:right] = palette_rgb.get(
            int(cells[row, col]), fallback_rgb
        )

        # Thin 1px border along the left and top edges for visibility.
        if draw_edges:
            canvas[top:bottom, left] = edge_rgb
            canvas[top, left:right] = edge_rgb

    return canvas


def _clear_output_directory(output_dir: str) -> None:
    """Reset ``output_dir`` to an empty directory for a new episode.

    Any existing directory tree at that path is deleted, then the directory
    (including missing parents) is created again.
    """
    target = Path(output_dir)
    already_there = target.exists()
    if already_there:
        shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)


# ============================================================================
# SECTION: SVG Core (from svg_core.py)
# ============================================================================


def _empty_grid_svg(
    padding: float, extra_bottom_padding: float, as_group: bool
) -> drawsvg.Drawing | tuple[drawsvg.Group, tuple[float, float], tuple[float, float]]:
    """Build the placeholder result used when there is nothing to draw."""
    origin = (-0.5 * padding, -0.5 * padding)
    size = (padding, padding + extra_bottom_padding)
    if as_group:
        return drawsvg.Group(), origin, size

    drawing = drawsvg.Drawing(size[0], size[1], origin=origin)
    drawing.set_pixel_scale(40)
    return drawing


def draw_grid_svg(
    grid_input: jnp.ndarray | np.ndarray | Grid,
    mask: jnp.ndarray | np.ndarray | None = None,
    max_width: float = 10.0,
    max_height: float = 10.0,
    padding: float = 0.5,
    extra_bottom_padding: float = 0.5,
    label: str = "",
    border_color: str = "#111111ff",
    show_size: bool = True,
    as_group: bool = False,
) -> drawsvg.Drawing | tuple[drawsvg.Group, tuple[float, float], tuple[float, float]]:
    """Draw a single grid as an SVG.

    Only the bounding box of the valid (masked-in) cells is drawn. Cells are
    square and scaled so the grid fits inside ``max_width`` x ``max_height``.

    Args:
        grid_input: Grid data (JAX array, numpy array, or Grid object).
        mask: Optional boolean mask indicating valid cells. When omitted, the
            mask carried by ``grid_input`` (if any) is used.
        max_width: Maximum width for the drawing.
        max_height: Maximum height for the drawing.
        padding: Padding around the grid.
        extra_bottom_padding: Extra padding at bottom for labels.
        label: Label to display below the grid.
        border_color: Color for the grid border.
        show_size: Whether to show grid dimensions.
        as_group: If True, return as a group for inclusion in larger drawings.

    Returns:
        Either a Drawing object, or a tuple of ``(Group, origin, size)`` if
        ``as_group=True``. Empty grids yield an empty drawing/group sized by
        the padding alone.
    """
    grid, grid_mask = _extract_grid_data(grid_input)

    if mask is None:
        mask = grid_mask
    if mask is not None:
        mask = np.asarray(mask)

    # Handle empty grids.
    if grid.size == 0:
        return _empty_grid_svg(padding, extra_bottom_padding, as_group)

    # Extract valid region.
    valid_grid, (start_row, start_col), (height, width) = _extract_valid_region(
        grid, mask
    )
    if height == 0 or width == 0:
        return _empty_grid_svg(padding, extra_bottom_padding, as_group)

    # Square cells, limited by whichever dimension is tighter. Both width and
    # height are non-zero here, so the divisions are safe.
    cell_size = min(max_width / width, max_height / height)
    actual_width = width * cell_size
    actual_height = height * cell_size

    # Drawing setup.
    line_thickness = 0.01
    border_width = 0.08
    lt = line_thickness / 2
    invalid_fill = "#CCCCCC"  # Light gray for invalid/unknown colors

    origin = (-0.5 * padding, -0.5 * padding)
    total_size = (
        actual_width + padding,
        actual_height + padding + extra_bottom_padding,
    )

    if as_group:
        drawing = drawsvg.Group()
    else:
        drawing = drawsvg.Drawing(total_size[0], total_size[1], origin=origin)
        drawing.set_pixel_scale(40)

    # Draw grid cells.
    for i in range(height):
        for j in range(width):
            color_val = int(valid_grid[i, j])

            # A cell is valid unless the mask (indexed in original grid
            # coordinates) says otherwise.
            is_valid = True
            if mask is not None:
                actual_row = start_row + i
                actual_col = start_col + j
                if actual_row < mask.shape[0] and actual_col < mask.shape[1]:
                    is_valid = mask[actual_row, actual_col]

            if is_valid and 0 <= color_val < len(ARC_COLOR_PALETTE):
                fill_color = ARC_COLOR_PALETTE.get(color_val, "white")
            else:
                fill_color = invalid_fill

            drawing.append(
                drawsvg.Rectangle(
                    j * cell_size + lt,
                    i * cell_size + lt,
                    cell_size - lt,
                    cell_size - lt,
                    fill=fill_color,
                )
            )

    # Add border.
    border_margin = border_width / 3
    drawing.append(
        drawsvg.Rectangle(
            -border_margin,
            -border_margin,
            actual_width + border_margin * 2,
            actual_height + border_margin * 2,
            fill="none",
            stroke=border_color,
            stroke_width=border_width,
        )
    )

    if not as_group:
        # Embed the font, subset to the characters labels may use. The
        # lowercase alphabet is complete here so labels containing "i" are
        # rendered with the embedded font as well.
        cast(drawsvg.Drawing, drawing).embed_google_font(
            "Anuphan:wght@400;600;700",
            text=set(
                "Input Output 0123456789x Test Task ABCDEFGHIJ? "
                "abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            ),
        )

    # Add size and label text on a shared baseline below the grid.
    font_size = (padding / 2 + extra_bottom_padding) / 2
    text_y = actual_height + font_size * 1.25

    if show_size:
        drawing.append(
            drawsvg.Text(
                text=f"{width}x{height}",
                x=actual_width,
                y=text_y,
                font_size=font_size,
                fill="black",
                text_anchor="end",
                font_family="Anuphan",
            )
        )

    if label:
        drawing.append(
            drawsvg.Text(
                text=label,
                x=-0.1 * font_size,
                y=text_y,
                font_size=font_size,
                fill="black",
                text_anchor="start",
                font_family="Anuphan",
                font_weight="600",
            )
        )

    if as_group:
        return cast(drawsvg.Group, drawing), origin, total_size

    return cast(drawsvg.Drawing, drawing)


def save_svg_drawing(
    drawing: drawsvg.Drawing,
    filename: str | Path,
    context: Any | None = None,
) -> None:
    """Save an SVG drawing to file with support for multiple formats.

    The output format is chosen from the file extension, matched
    case-insensitively.

    Args:
        drawing: The SVG drawing to save.
        filename: Output filename or path (extension determines format:
            .svg, .png, .pdf).
        context: Optional context forwarded to ``drawing.as_svg`` for PDF
            conversion.

    Raises:
        ValueError: If the file extension is not supported.
        ImportError: If ``cairosvg`` is not installed and PDF output was
            requested.
    """
    # Accept path-like objects as well as strings; keep the text form for
    # the drawsvg/cairosvg calls and log messages.
    target = str(filename)
    lowered = target.lower()

    if lowered.endswith(".svg"):
        drawing.save_svg(target)
        logger.info(f"Saved SVG to {target}")
    elif lowered.endswith(".png"):
        drawing.save_png(target)
        logger.info(f"Saved PNG to {target}")
    elif lowered.endswith(".pdf"):
        buffer = io.StringIO()
        drawing.as_svg(output_file=buffer, context=context)

        # Keep the try block limited to the optional import so that errors
        # raised during conversion are not misreported as a missing package.
        try:
            import cairosvg  # type: ignore[import-untyped,import-not-found]
        except ImportError as e:
            error_msg = "cairosvg is required for PDF output. Please install it with: pip install cairosvg"
            logger.error(error_msg)
            raise ImportError(error_msg) from e

        cairosvg.svg2pdf(bytestring=buffer.getvalue(), write_to=target)
        logger.info(f"Saved PDF to {target}")
    else:
        error_msg = (
            f"Unknown file extension for {filename}. Supported: .svg, .png, .pdf"
        )
        raise ValueError(error_msg)


def _draw_dotted_squircle(
    x: float,
    y: float,
    width: float,
    height: float,
    label: str,
    stroke_color: str = "#666666",
    stroke_width: float = 0.05,
    corner_radius: float = 0.3,
    dash_array: str = "0.1,0.1",
) -> list[drawsvg.DrawingElement]:
    """Draw a dotted squircle (rounded rectangle) with label.

    Args:
        x: Left edge of the squircle.
        y: Top edge of the squircle.
        width: Width of the squircle.
        height: Height of the squircle.
        label: Label text to display.
        stroke_color: Color of the dotted border (also used for the label).
        stroke_width: Width of the border.
        corner_radius: Radius for rounded corners.
        dash_array: SVG dash pattern for dotted line.

    Returns:
        List of drawing elements: the squircle first, then its label.
    """
    squircle = drawsvg.Rectangle(
        x,
        y,
        width,
        height,
        rx=corner_radius,
        ry=corner_radius,
        fill="none",
        stroke=stroke_color,
        stroke_width=stroke_width,
        stroke_dasharray=dash_array,
        opacity=0.7,
    )

    # The label is right-aligned just inside the top-right corner.
    label_text = drawsvg.Text(
        text=label,
        x=x + width - 0.1,
        y=y + 0.3,
        font_size=0.25,
        font_family="Anuphan",
        font_weight="700",
        fill=stroke_color,
        text_anchor="end",
        opacity=0.8,
    )

    return [squircle, label_text]


def _iter_flagged_cells(
    flags: np.ndarray,
    start_row: int,
    start_col: int,
    display_height: int,
    display_width: int,
):
    """Yield (orig_row, orig_col, display_row, display_col) for flagged cells in view."""
    n_rows, n_cols = flags.shape[0], flags.shape[1]
    for display_row in range(display_height):
        orig_row = start_row + display_row
        if orig_row >= n_rows:
            continue
        for display_col in range(display_width):
            orig_col = start_col + display_col
            if orig_col < n_cols and flags[orig_row, orig_col]:
                yield orig_row, orig_col, display_row, display_col


def add_selection_visualization_overlay(
    drawing: Any,
    selection_mask: np.ndarray,
    grid_x: float,
    grid_y: float,
    cell_size: float,
    start_row: int,
    start_col: int,
    display_height: int,
    display_width: int,
    selection_color: str = "#3498db",
    selection_opacity: float = 0.3,
    border_width: float = 2,
) -> None:
    """Add selection visualization overlay to a grid.

    Selected cells get a translucent fill; an outline is then drawn only
    along edges where a selected cell borders an unselected (or out-of-mask)
    cell, so contiguous selections appear as one outlined shape.

    Args:
        drawing: DrawSVG drawing object to add overlay to.
        selection_mask: Boolean mask of selected cells, indexed in original
            grid coordinates.
        grid_x: X position of grid.
        grid_y: Y position of grid.
        cell_size: Size of each cell.
        start_row: Starting row of valid region.
        start_col: Starting column of valid region.
        display_height: Height of display region.
        display_width: Width of display region.
        selection_color: Color for selection highlight.
        selection_opacity: Opacity of selection fill.
        border_width: Width of selection border.
    """
    if not selection_mask.any():
        return

    selected = list(
        _iter_flagged_cells(
            selection_mask, start_row, start_col, display_height, display_width
        )
    )

    # First pass: draw filled rectangles. All fills are appended before any
    # outline so the outlines end up on top.
    for _, _, display_row, display_col in selected:
        drawing.append(
            drawsvg.Rectangle(
                grid_x + display_col * cell_size,
                grid_y + display_row * cell_size,
                cell_size,
                cell_size,
                fill=selection_color,
                fill_opacity=selection_opacity,
                stroke="none",
            )
        )

    n_rows, n_cols = selection_mask.shape[0], selection_mask.shape[1]

    def is_selected(row: int, col: int) -> bool:
        """Check if a cell is selected, treating out-of-bounds as unselected."""
        if row < 0 or row >= n_rows or col < 0 or col >= n_cols:
            return False
        return bool(selection_mask[row, col])

    # Second pass: draw boundary lines only on outer edges.
    for orig_row, orig_col, display_row, display_col in selected:
        left = grid_x + display_col * cell_size
        top = grid_y + display_row * cell_size
        right = left + cell_size
        bottom = top + cell_size

        # (neighbor across the edge, line endpoints) in the order
        # top, bottom, left, right.
        edges = (
            ((orig_row - 1, orig_col), (left, top, right, top)),
            ((orig_row + 1, orig_col), (left, bottom, right, bottom)),
            ((orig_row, orig_col - 1), (left, top, left, bottom)),
            ((orig_row, orig_col + 1), (right, top, right, bottom)),
        )
        for (n_row, n_col), (x1, y1, x2, y2) in edges:
            if is_selected(n_row, n_col):
                continue
            drawing.append(
                drawsvg.Line(
                    x1,
                    y1,
                    x2,
                    y2,
                    stroke=selection_color,
                    stroke_width=border_width,
                    stroke_opacity=0.9,
                )
            )


def add_change_highlighting(
    drawing: Any,
    changed_cells: np.ndarray,
    grid_x: float,
    grid_y: float,
    cell_size: float,
    start_row: int,
    start_col: int,
    display_height: int,
    display_width: int,
    change_color: str = "#ff6b6b",
    border_width: float = 3,
) -> None:
    """Add change highlighting overlay to a grid.

    Each changed cell gets a static outline drawn just outside the cell plus
    a faint tinted fill inset by one unit on every side.

    Args:
        drawing: DrawSVG drawing object to add overlay to.
        changed_cells: Boolean mask of changed cells, indexed in original
            grid coordinates.
        grid_x: X position of grid.
        grid_y: Y position of grid.
        cell_size: Size of each cell.
        start_row: Starting row of valid region.
        start_col: Starting column of valid region.
        display_height: Height of display region.
        display_width: Width of display region.
        change_color: Color for change highlight.
        border_width: Width of change border.
    """
    if not changed_cells.any():
        return

    half_border = border_width / 2
    # The inset fill only has a positive size for cells larger than 2 units;
    # smaller cells would produce a zero/negative-size rectangle, so skip it.
    draw_inner_tint = cell_size > 2

    for _, _, display_row, display_col in _iter_flagged_cells(
        changed_cells, start_row, start_col, display_height, display_width
    ):
        cell_x = grid_x + display_col * cell_size
        cell_y = grid_y + display_row * cell_size

        # Outline surrounding the cell.
        drawing.append(
            drawsvg.Rectangle(
                cell_x - half_border,
                cell_y - half_border,
                cell_size + border_width,
                cell_size + border_width,
                fill="none",
                stroke=change_color,
                stroke_width=border_width,
                stroke_opacity=0.8,
            )
        )

        # Faint inner tint.
        if draw_inner_tint:
            drawing.append(
                drawsvg.Rectangle(
                    cell_x + 1,
                    cell_y + 1,
                    cell_size - 2,
                    cell_size - 2,
                    fill=change_color,
                    fill_opacity=0.1,
                    stroke="none",
                )
            )