# Source code for jaxarc.registration

"""
Registration system for JaxARC environments.

This package provides a lean registry that maps simple dataset keys to
environment specs. Dataset parsing and task loading are no longer part of this
module. Environments are expected to be constructed with buffer-based EnvParams
(JAX-native, JIT-friendly) and not depend on parsers at runtime.

Core ideas:
- A global registry maps dataset keys (e.g., "Mini", "Concept", "AGI1", "AGI2") to EnvSpec definitions.
- No parser entry points or subset inference live here anymore.
- `make(id, **kwargs)` only parses the dataset key and returns the environment and parameters
  built from provided kwargs (e.g., a prebuilt buffer in EnvParams or an explicit params).
- Named subsets can be registered (e.g., `register_subset("Mini", "easy", [...])`) and then
  selected via `make("Mini-easy")` to load exactly those tasks. This makes it easy to publish
  curated benchmarks and implement curriculum learning.

Typical usage:
    from jaxarc.registration import make
    # Build EnvParams with a pre-stacked task buffer outside this module.
    env, params = make("Mini", params=my_params)

Notes:
- This module keeps a single way of doing things: buffer-based, JIT-friendly EnvParams.
- Dataset downloading/parsing and subset handling should be done outside this module.
"""

from __future__ import annotations

from typing import Any

from .registry import EnvRegistry, EnvSpec
from .subset_loader import (
    load_all_subsets_for_dataset,
    load_subset,
    load_subset_if_needed,
)

# -----------------------------------------------------------------------------
# Module-level singleton API
# -----------------------------------------------------------------------------

# Single shared registry instance used by the module-level wrapper functions.
_registry = EnvRegistry()

# Default bootstrap: register the common dataset IDs, in this order, each with
# the same minimal spec (only an episode-step cap of 100).
for _dataset_key in ("Mini", "Concept", "AGI1", "AGI2"):
    _registry.register(id=_dataset_key, max_episode_steps=100)
# Drop the loop variable so the module namespace gains no extra name.
del _dataset_key


def register(
    id: str,
    entry_point: str | None = None,
    env_entry: str = "jaxarc.envs:Environment",
    max_episode_steps: int = 100,
    **kwargs: Any,
) -> None:
    """Register an environment spec in the global registry.

    Args:
        id: Key under which the spec is registered.
        entry_point: Accepted by the signature but not forwarded to the
            registry; passing it has no effect in this function.
        env_entry: Forwarded unchanged to the registry's ``register`` call.
        max_episode_steps: Forwarded unchanged to the registry's ``register``
            call.
        **kwargs: Any additional keyword arguments, forwarded unchanged.
    """
    # Collect everything except the id; ``entry_point`` is intentionally absent
    # because the original call never passed it on.
    spec_options: dict[str, Any] = {
        "env_entry": env_entry,
        "max_episode_steps": max_episode_steps,
        **kwargs,
    }
    _registry.register(id=id, **spec_options)
def make(id: str, **kwargs: Any) -> tuple[Any, Any]:
    """Create an environment instance and EnvParams using a registered spec.

    This is a thin wrapper over the global registry's ``make``; see
    ``EnvRegistry.make`` for details on supported kwargs.

    Args:
        id: Identifier passed straight through to the registry.
        **kwargs: Keyword arguments passed straight through to the registry.

    Returns:
        Whatever the registry's ``make`` returns (annotated as a 2-tuple).
    """
    env_and_params = _registry.make(id, **kwargs)
    return env_and_params
def register_subset(
    dataset_key: str,
    name: str,
    task_ids: list[str] | tuple[str, ...],
) -> None:
    """Register a named subset for a dataset key, enabling IDs like 'Mini-easy'.

    Args:
        dataset_key: Dataset the subset belongs to.
        name: Name of the subset.
        task_ids: Task IDs that make up the subset.
    """
    # Pure delegation: all three arguments go positionally to the registry.
    _registry.register_subset(dataset_key, name, task_ids)
def get_subset_task_ids(
    dataset_key: str,
    selector: str = "all",
    config: Any | None = None,
    auto_download: bool = False,
) -> list[str]:
    """Get task IDs for a specific subset without creating an environment.

    This allows users to query what tasks will be loaded before calling
    make().

    Args:
        dataset_key: Dataset name (Mini, Concept, AGI1, AGI2)
        selector: Subset selector ('all', 'train', 'easy', task_id, etc.)
        config: Optional config
        auto_download: Download dataset if missing

    Returns:
        List of task IDs that will be loaded

    Examples:
        >>> get_subset_task_ids("Mini", "all")
        ['Most_Common_color_l6ab0lf3xztbyxsu3p', ...]

        >>> get_subset_task_ids("Mini", "easy")
        ['task1', 'task2', 'task3']

        >>> get_subset_task_ids("Mini", "Most_Common_color_l6ab0lf3xztbyxsu3p")
        ['Most_Common_color_l6ab0lf3xztbyxsu3p']
    """
    # The dataset key stays positional; the remaining options are passed by
    # keyword, in the same order as the original call.
    query_options = {
        "selector": selector,
        "config": config,
        "auto_download": auto_download,
    }
    return _registry.get_subset_task_ids(dataset_key, **query_options)
def available_task_ids(
    dataset_key: str,
    config: Any | None = None,
    auto_download: bool = False,
) -> list[str]:
    """List all available task IDs for a dataset.

    Equivalent to get_subset_task_ids with selector='all': the registry is
    queried with the selector fixed to ``"all"``.

    Args:
        dataset_key: Dataset to query.
        config: Passed through to the registry unchanged.
        auto_download: Passed through to the registry unchanged.

    Returns:
        The list produced by the registry for the ``"all"`` selector.
    """
    all_task_ids = _registry.get_subset_task_ids(
        dataset_key,
        selector="all",
        config=config,
        auto_download=auto_download,
    )
    return all_task_ids
def available_named_subsets(
    dataset_key: str,
    include_builtin: bool = True,
) -> tuple[str, ...]:
    """List available subset names for a dataset.

    Built-in selectors are included by default.

    Args:
        dataset_key: Dataset name (Mini, Concept, AGI1, AGI2)
        include_builtin: Include built-in selectors like 'all', 'train',
            'eval' (default: True)

    Returns:
        Tuple of subset names

    Examples:
        >>> available_named_subsets("Mini")
        ('all', 'easy', 'eval', 'train')

        With ``include_builtin=False`` only custom subsets are listed:

        >>> available_named_subsets("Mini", include_builtin=False)
        ('easy',)
    """
    subset_names = _registry.available_named_subsets(
        dataset_key,
        include_builtin=include_builtin,
    )
    return subset_names
def subset_task_ids(dataset_key: str, name: str) -> tuple[str, ...]:
    """Return the task IDs registered under a named subset.

    This only works for explicitly registered subsets (via register_subset).
    For more flexible queries, use get_subset_task_ids() instead.

    Args:
        dataset_key: Dataset the subset was registered for.
        name: Name of the registered subset.

    Returns:
        The task IDs reported by the registry for that subset.
    """
    registered_ids = _registry.subset_task_ids(dataset_key, name)
    return registered_ids
# Names exported by ``from jaxarc.registration import *``. Order and contents
# are unchanged; the list is only laid out one name per line.
__all__ = [
    "EnvRegistry",
    "EnvSpec",
    "_registry",
    "available_named_subsets",
    "available_task_ids",
    "get_subset_task_ids",
    "load_all_subsets_for_dataset",
    "load_subset",
    "load_subset_if_needed",
    "make",
    "register",
    "register_subset",
    "subset_task_ids",
]