# Source code for jaxarc.registration.registry

"""Core registry implementation for JaxARC environments.

Contains the EnvSpec dataclass and EnvRegistry class with instance methods
that manage specs and subsets. Standalone config/subset functions are
imported from sibling modules.
"""

from __future__ import annotations

import importlib
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

from loguru import logger

from jaxarc.utils.buffer import stack_task_list

from .config_prep import (
    ensure_dataset_available,
    maybe_adjust_task_split,
    prepare_config,
    resolve_episode_mode,
)

# -----------------------------------------------------------------------------
# Data structures
# -----------------------------------------------------------------------------


@dataclass
class EnvSpec:
    """Environment specification for registration.

    Plain record stored by ``EnvRegistry.register`` under its ``id``.
    """

    # Unique identifier the spec is registered under.
    id: str
    # Entry point string ("package.module:object" or dotted path) of the
    # environment class/factory to instantiate.
    env_entry: str = "jaxarc.envs:Environment"
    # Default maximum number of steps per episode.
    max_episode_steps: int = 100
    # Extra metadata stored with the spec; a fresh dict per instance.
    kwargs: Dict[str, Any] = field(default_factory=dict)


# -----------------------------------------------------------------------------
# Registry implementation
# -----------------------------------------------------------------------------


class EnvRegistry:
    """Global environment registry with gym-like semantics.

    Holds two mappings:

    * ``_specs``: spec id -> ``EnvSpec``
    * ``_subsets``: normalized dataset key -> subset name -> tuple of task IDs
    """

    def __init__(self) -> None:
        self._specs: Dict[str, EnvSpec] = {}
        # Named subset registry: maps normalized dataset key -> subset name
        # -> tuple of task IDs
        self._subsets: Dict[str, Dict[str, tuple[str, ...]]] = {}

    def register(
        self,
        id: str,
        env_entry: str = "jaxarc.envs:Environment",
        max_episode_steps: int = 100,
        **kwargs: Any,
    ) -> None:
        """Register a new environment specification.

        An existing spec with the same ``id`` is overwritten.

        Args:
            id: Unique environment ID (e.g., "JaxARC-Mini-v0")
            env_entry: Dotted path or colon path to class/factory
                (e.g., "jaxarc.envs:Environment")
            max_episode_steps: Default max steps for this environment family
            **kwargs: Additional metadata stored with the spec
        """
        self._specs[id] = EnvSpec(
            id=id,
            env_entry=env_entry,
            max_episode_steps=int(max_episode_steps),
            kwargs=dict(kwargs),
        )

    def register_subset(
        self, dataset_key: str, name: str, task_ids: list[str] | tuple[str, ...]
    ) -> None:
        """Register a named subset (e.g., 'Mini-easy') that maps to specific task IDs.

        Args:
            dataset_key: Base dataset key (e.g., 'Mini', 'Concept', 'AGI1', 'AGI2'
                or synonyms)
            name: Subset name (e.g., 'easy', 'hard', 'my-benchmark'); stored
                stripped and lower-cased
            task_ids: Sequence of task IDs to include in this subset. A bare
                string is treated as a single task ID.

        Raises:
            ValueError: If ``name`` is empty after stripping whitespace.
        """
        key = self._normalize_dataset_key(dataset_key)
        sel = name.strip().lower()
        if not sel:
            raise ValueError("Subset name must be non-empty")
        if isinstance(task_ids, str):
            # tuple("abc") would silently produce ('a', 'b', 'c').
            ids_tuple: tuple[str, ...] = (task_ids,)
        else:
            ids_tuple = tuple(task_ids)
        self._subsets.setdefault(key, {})[sel] = ids_tuple

    def available_named_subsets(
        self, dataset_key: str, include_builtin: bool = True
    ) -> tuple[str, ...]:
        """Return names of available subsets for a dataset.

        Args:
            dataset_key: Dataset name (Mini, Concept, AGI1, AGI2)
            include_builtin: Include built-in selectors ('all', 'train', 'eval')
                and concept groups (default: True)

        Returns:
            Tuple of subset names, sorted alphabetically

        Examples:
            >>> available_named_subsets("Mini")
            ('all',)  # Mini doesn't have train/eval splits

            >>> available_named_subsets("Concept")
            ('AboveBelow', 'Center', 'all', ...)  # Includes concept groups

            >>> available_named_subsets("AGI1")
            ('all', 'eval', 'train')  # AGI has splits

            >>> available_named_subsets("Mini", include_builtin=False)
            ()  # Only custom subsets
        """
        key = self._normalize_dataset_key(dataset_key)

        # Start with manually registered subsets
        subsets = set(self._subsets.get(key, {}))

        if include_builtin:
            # Add 'all' for everyone
            subsets.add("all")

            # Only AGI datasets have train/eval splits
            if key in ("agi1", "agi2"):
                subsets.update(["train", "eval"])

            # Add concept groups for ConceptARC
            if key == "concept":
                # Best effort: listing subsets must not fail just because the
                # dataset is not on disk, so any error is logged and skipped.
                try:
                    spec_key = self._canonical_spec_key(dataset_key)
                    spec = self._specs.get(spec_key)
                    if spec is not None:
                        cfg = prepare_config(None, spec.max_episode_steps, spec_key)
                        cfg = ensure_dataset_available(
                            cfg, spec_key, auto_download=False
                        )
                        parser = self._create_parser(cfg)
                        if hasattr(parser, "get_concept_groups"):
                            subsets.update(parser.get_concept_groups())
                except Exception as exc:
                    logger.debug(
                        f"Skipping concept groups for '{dataset_key}': {exc}"
                    )

        return tuple(sorted(subsets))

    def get_subset_task_ids(
        self,
        dataset_key: str,
        selector: str = "all",
        config: Optional[Any] = None,
        auto_download: bool = False,
    ) -> list[str]:
        """Get task IDs for a specific subset without creating an environment.

        This allows users to query what tasks will be loaded before calling make().

        Args:
            dataset_key: Dataset name (Mini, Concept, AGI1, AGI2)
            selector: Subset selector ('all', 'train', 'easy', task_id, etc.)
            config: Optional config
            auto_download: Download dataset if missing

        Returns:
            List of task IDs that will be loaded

        Raises:
            ValueError: If the dataset is not registered or the selector
                cannot be resolved.

        Examples:
            >>> get_subset_task_ids("Mini", "all")
            ['Most_Common_color_l6ab0lf3xztbyxsu3p', ...]

            >>> get_subset_task_ids("Mini", "easy")
            ['task1', 'task2', 'task3']  # Only tasks in 'easy' subset

            >>> get_subset_task_ids("Concept", "Center")
            ['Center_001', 'Center_002', ...]  # Tasks in Center concept

            >>> get_subset_task_ids("Mini", "Most_Common_color_l6ab0lf3xztbyxsu3p")
            ['Most_Common_color_l6ab0lf3xztbyxsu3p']  # Single task
        """
        spec_key = self._canonical_spec_key(dataset_key)
        if spec_key not in self._specs:
            msg = f"Environment '{spec_key}' is not registered"
            raise ValueError(msg)

        spec = self._specs[spec_key]
        cfg = prepare_config(config, spec.max_episode_steps, spec_key)

        # Adjust split for AGI datasets (returns modified config)
        cfg = maybe_adjust_task_split(cfg, dataset_key, selector)

        # Ensure dataset available and create parser
        cfg = ensure_dataset_available(cfg, spec_key, auto_download)
        parser = self._create_parser(cfg)

        # Use unified resolution
        return self._resolve_selector_to_task_ids(dataset_key, selector, parser)

    def subset_task_ids(self, dataset_key: str, name: str) -> tuple[str, ...]:
        """Return the task IDs registered for a named subset (e.g., 'Mini', 'easy').

        Returns an empty tuple when no such subset is registered.
        """
        return self._get_named_subset_ids(dataset_key, name)

    def available_task_ids(
        self,
        dataset_key: str,
        config: Optional[Any] = None,
        auto_download: bool = False,
    ) -> list[str]:
        """Return all available task IDs for a dataset key after ensuring dataset availability.

        Args:
            dataset_key: Dataset name or synonym.
            config: Optional config passed to ``prepare_config``.
            auto_download: Download dataset if missing.

        Returns:
            The parser's available task IDs, or an empty list when the parser
            has no ``get_available_task_ids`` method.

        Raises:
            ValueError: If the dataset is not registered.
        """
        spec_key = self._canonical_spec_key(dataset_key)
        if spec_key not in self._specs:
            raise ValueError(f"Environment '{spec_key}' is not registered")
        spec = self._specs[spec_key]
        cfg = prepare_config(config, spec.max_episode_steps, spec_key)
        cfg = ensure_dataset_available(cfg, spec_key, auto_download)
        return self._get_all_task_ids(self._create_parser(cfg))

    def make(self, id: str, **kwargs: Any) -> Tuple[Any, Any]:
        """Create an environment instance and parameters for a registered spec.

        Expected kwargs:
            - params: EnvParams (preferred; buffer-based, JIT-friendly)
            - env_entry: str (optional) override of environment entry point
            - config: optional config object
            - auto_download: bool, download the dataset if missing
            - episode_mode: optional override passed to ``resolve_episode_mode``

        Returns:
            (env, params) tuple:
                env: Environment instance
                params: the EnvParams passed in, or ``env.params`` when the
                    buffer was built here

        Raises:
            ValueError: If the dataset part of ``id`` is not registered, the
                selector cannot be resolved, or it resolves to no tasks.
        """
        dataset_key, modifiers = self._parse_id(id)

        if dataset_key not in self._specs:
            raise ValueError(f"Environment '{dataset_key}' is not registered")

        spec = self._specs[dataset_key]

        # Instantiate environment (spec.env_entry or override)
        env_entry = kwargs.get("env_entry", spec.env_entry)
        env_obj = self._import_from_entry_point(env_entry)

        # If params explicitly provided, use them
        params = kwargs.get("params")
        if params is not None:
            config = kwargs.get("config")
            if config is None:
                # 'config' is optional; build the default rather than raising
                # KeyError on a missing key.
                config = prepare_config(None, spec.max_episode_steps, dataset_key)
            return env_obj(config=config, buffer=params.buffer), params

        # Prepare config and dataset availability
        config = prepare_config(
            kwargs.get("config"), spec.max_episode_steps, dataset_key
        )
        auto_download = bool(kwargs.get("auto_download", False))

        # Parse selector (may be empty)
        selector = modifiers.get("selector", "")

        # Adjust split for AGI datasets based on selector (returns modified config)
        config = maybe_adjust_task_split(config, dataset_key, selector)

        # Ensure dataset exists on disk (optionally download)
        config = ensure_dataset_available(config, dataset_key, auto_download)

        # Instantiate the dataset parser from config
        parser = self._create_parser(config)

        # For AGI datasets, we may need the parser_obj for cross-split lookups
        parser_obj = self._load_parser_entry(config)

        # Resolve episode mode (0=train, 1=eval)
        episode_mode = resolve_episode_mode(kwargs.get("episode_mode"), selector)

        # UNIFIED RESOLUTION - works for all selector types
        try:
            ids = self._resolve_selector_to_task_ids(
                dataset_key, selector if selector else "all", parser
            )
        except ValueError as e:
            msg = f"Failed to resolve '{id}': {e}"
            raise ValueError(msg) from e

        if not ids:
            msg = "No tasks resolved for the given selector."
            raise ValueError(msg)

        # Build stacked buffer using parser, handling cross-split lookups for
        # AGI datasets if needed
        tasks = self._get_tasks_for_ids(parser, parser_obj, config, dataset_key, ids)
        buf = stack_task_list(tasks)
        env = env_obj(config=config, buffer=buf, episode_mode=episode_mode)
        return env, env.params

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------

    def _resolve_selector_to_task_ids(
        self, dataset_key: str, selector: str, parser: Any
    ) -> list[str]:
        """Resolve any selector to a list of task IDs.

        Priority order:
        1. Named subset (e.g., 'easy' from register_subset)
        2. Built-in selectors ('all', 'train', 'eval')
        3. Concept groups (ConceptARC: 'AboveBelow', 'Center', etc.)
        4. Single task ID (e.g., 'Most_Common_color_l6ab0lf3xztbyxsu3p')

        Args:
            dataset_key: Dataset key (Mini, Concept, AGI1, AGI2)
            selector: Selector string from make("Dataset-{selector}")
            parser: Initialized parser instance

        Returns:
            List of resolved task IDs

        Raises:
            ValueError: If selector cannot be resolved
        """
        # 1. Check named subsets first (highest priority)
        named_ids = self._get_named_subset_ids(dataset_key, selector)
        if named_ids:
            return list(named_ids)

        # 2. Check built-in selectors. All of them return whatever the parser
        # exposes; this method does no split filtering of its own.
        builtin_selectors = (
            "",
            "all",
            "train",
            "training",
            "eval",
            "evaluation",
            "test",
            "corpus",
        )
        if selector.lower() in builtin_selectors:
            return self._get_all_task_ids(parser)

        # 3. Concept-specific: check concept groups
        if (
            self._normalize_dataset_key(dataset_key) == "concept"
            and hasattr(parser, "get_concept_groups")
            and hasattr(parser, "get_tasks_in_concept")
        ):
            if selector in parser.get_concept_groups():
                return list(parser.get_tasks_in_concept(selector))

        # 4. Try as single task ID
        if selector in self._get_all_task_ids(parser):
            return [selector]

        # 5. Failed to resolve - provide helpful error
        available_options = self._describe_available_selectors(dataset_key, parser)
        raise ValueError(
            f"Unknown selector '{selector}' for {dataset_key}.\n"
            f"Available options: {available_options}"
        )

    def _get_all_task_ids(self, parser: Any) -> list[str]:
        """Get all available task IDs from parser (empty list if unsupported)."""
        if hasattr(parser, "get_available_task_ids"):
            return parser.get_available_task_ids()
        return []

    def _describe_available_selectors(self, dataset_key: str, parser: Any) -> str:
        """Create a helpful description of valid selectors for error messages.

        ``parser`` is accepted for signature symmetry but is not used here.
        """
        # Get all available named subsets (includes built-ins, custom subsets,
        # and concepts)
        named = self.available_named_subsets(dataset_key, include_builtin=True)
        options = [f"'{n}'" for n in named]

        # Add note about task IDs
        options.append("or any valid task ID")

        return ", ".join(options)

    def _load_parser_entry(self, config: Any) -> Any:
        """Import the parser class/object named by ``config.dataset``.

        Uses ``config.dataset.parser_entry_point`` and falls back to
        "jaxarc.parsers:ArcAgiParser" when that attribute is absent.
        """
        parser_entry = getattr(
            config.dataset, "parser_entry_point", "jaxarc.parsers:ArcAgiParser"
        )
        return self._import_from_entry_point(parser_entry)

    def _create_parser(self, config: Any) -> Any:
        """Create parser instance from config.

        If the entry point resolves to a class it is instantiated with
        ``config.dataset``; any other object is returned as-is.
        """
        parser_obj = self._load_parser_entry(config)
        return parser_obj(config.dataset) if self._is_class(parser_obj) else parser_obj

    def _parse_id(self, id: str) -> tuple[str, dict[str, str]]:
        """Parse environment ID and extract modifiers.

        Conventions:
        - Accepts either:
            - DatasetID (no selector)
            - DatasetID-{Selector} (with selector)
        - When selector is present, all remaining tokens after DatasetID are
          joined to form the single Selector string.

        Returns:
            ``(dataset_key, modifiers)`` where ``modifiers`` contains a
            "selector" entry only when a non-empty selector was given.
        """
        dataset_key, _, selector = id.partition("-")
        modifiers: dict[str, str] = {}
        if selector:
            modifiers["selector"] = selector
        return dataset_key, modifiers

    @staticmethod
    def _normalize_dataset_key(dataset_key: str) -> str:
        """Normalize dataset key to a canonical lowercase token for internal mapping.

        Unknown keys are returned lower-cased.
        """
        key = dataset_key.lower()
        if key in ("mini", "miniarc", "mini-arc"):
            return "mini"
        if key in ("concept", "conceptarc", "concept-arc"):
            return "concept"
        if key in ("agi1", "arc-agi-1", "agi-1", "agi_1"):
            return "agi1"
        if key in ("agi2", "arc-agi-2", "agi-2", "agi_2"):
            return "agi2"
        return key

    @staticmethod
    def _canonical_spec_key(dataset_key: str) -> str:
        """Map a dataset key (including synonyms) to a registered spec key."""
        canonical = {
            "mini": "Mini",
            "concept": "Concept",
            "agi1": "AGI1",
            "agi2": "AGI2",
        }
        key = EnvRegistry._normalize_dataset_key(dataset_key)
        # Fallback: assume caller provided exact registered key
        return canonical.get(key, dataset_key)

    def _get_named_subset_ids(self, dataset_key: str, selector: str) -> tuple[str, ...]:
        """Fetch named subset IDs if registered for the dataset_key/selector pair.

        The selector is stripped and lower-cased, mirroring how
        ``register_subset`` stores names. Returns an empty tuple when absent.
        """
        key = self._normalize_dataset_key(dataset_key)
        subsets = self._subsets.get(key, {})
        return subsets.get(selector.strip().lower(), tuple())

    def _get_tasks_for_ids(
        self,
        parser: Any,
        parser_entry_obj: Any,
        config: Any,
        dataset_key: str,
        ids: list[str],
    ) -> list[Any]:
        """Load tasks by ID using the current parser.

        For AGI datasets, missing IDs are looked up in the opposite split.

        Args:
            parser: Parser instance for the currently configured split.
            parser_entry_obj: Imported parser class/object, used to build a
                parser for the opposite split.
            config: Config whose ``dataset.task_split`` names the current split.
            dataset_key: Dataset name or synonym.
            ids: Task IDs to load.

        Returns:
            Tasks in the same order as ``ids``.

        Raises:
            ValueError: If any task ID cannot be found.
        """
        loaded: dict[str, Any] = {}
        missing: list[str] = []
        for tid in ids:
            try:
                loaded[tid] = parser.get_task_by_id(tid)
            except Exception:
                missing.append(tid)

        if missing and self._normalize_dataset_key(dataset_key) in ("agi1", "agi2"):
            # Imported lazily: only the cross-split fallback needs it.
            import equinox as eqx

            try:
                ds = config.dataset
                current_split = getattr(ds, "task_split", "train")
                opposite = (
                    "evaluation" if current_split in ("train", "training") else "train"
                )

                logger.debug(
                    f"Looking for {len(missing)} missing tasks in opposite split '{opposite}'"
                )

                # Properly update immutable config using eqx.tree_at
                ds_opposite = eqx.tree_at(lambda d: d.task_split, ds, opposite)
                config_opposite = eqx.tree_at(lambda c: c.dataset, config, ds_opposite)

                # Create parser for opposite split
                parser2 = (
                    parser_entry_obj(config_opposite.dataset)
                    if self._is_class(parser_entry_obj)
                    else parser_entry_obj
                )

                still_missing: list[str] = []
                for tid in missing:
                    try:
                        loaded[tid] = parser2.get_task_by_id(tid)
                        logger.debug(
                            f"Found task '{tid}' in opposite split '{opposite}'"
                        )
                    except Exception:
                        still_missing.append(tid)
                missing = still_missing
            except Exception as e:
                logger.warning(f"Failed to lookup missing tasks in opposite split: {e}")
                # Fall through to error below

        if missing:
            raise ValueError(
                f"Some task ids were not found for dataset '{dataset_key}': {missing}"
            )
        # Rebuild in request order so positions line up with ``ids`` even
        # when some tasks came from the opposite split.
        return [loaded[tid] for tid in ids]

    @staticmethod
    def _is_class(obj: Any) -> bool:
        """Return True when ``obj`` is a class (an instance of ``type``)."""
        return isinstance(obj, type)

    @staticmethod
    def _import_from_entry_point(entry_point: str) -> Any:
        """Import an object from an entry point string.

        Supports:
        - "package.module:object"
        - "package.module.Object"

        Raises:
            ValueError: If the entry point format is invalid or import fails
        """
        if ":" in entry_point:
            module_name, attr_name = entry_point.split(":", 1)
        else:
            # Split by last dot to separate module and attribute
            module_name, dot, attr_name = entry_point.rpartition(".")
            if not dot:
                raise ValueError(f"Invalid entry_point '{entry_point}'")

        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except Exception as e:
            raise ValueError(f"Failed to import '{entry_point}': {e}") from e