Configuration API

JaxARC uses a comprehensive configuration system based on Hydra and Equinox.

Configuration in JaxARC is handled through the JaxArcConfig class, which provides:

  • Type-safe configuration with Equinox modules

  • Hydra integration for YAML-based configs

  • Presets for common use cases

  • Runtime validation

Module Contents

Modular configuration package for JaxARC.

This package splits the previously monolithic envs.config module into focused configuration modules. Public API remains stable via re-exports.

class jaxarc.configs.ActionConfig(**kwargs)

Bases: Module

Configuration for action space and validation.

This config contains all settings related to action handling, validation, and operation constraints, including dynamic action space control.

allow_invalid_actions: bool = False
allowed_operations: tuple[int, ...] | None = None
context_dependent_operations: bool = False
dynamic_action_filtering: bool = False
classmethod from_hydra(cfg: DictConfig) -> ActionConfig

Create action config from Hydra DictConfig.

invalid_operation_policy: str = 'clip'
max_operations: int = 35
selection_threshold: float = 1.0
validate() -> tuple[str, ...]

Validate action configuration and return tuple of errors.

validate_actions: bool = True
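
The validate() methods on this page return a tuple of error strings rather than raising. A minimal sketch of that pattern follows, using a plain dataclass as a stand-in. The class name ActionConfigSketch and the specific rules checked are assumptions for illustration; the checks JaxARC actually performs are not documented here.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


# Stand-in for illustration only; not JaxARC's actual ActionConfig.
@dataclass(frozen=True)
class ActionConfigSketch:
    max_operations: int = 35
    allowed_operations: Optional[Tuple[int, ...]] = None
    selection_threshold: float = 1.0

    def validate(self) -> Tuple[str, ...]:
        """Collect every problem found instead of raising on the first one."""
        errors = []
        # Assumed rule: the operation count must be positive.
        if self.max_operations <= 0:
            errors.append("max_operations must be positive")
        # Assumed rule: the threshold is a fraction in [0, 1].
        if not 0.0 <= self.selection_threshold <= 1.0:
            errors.append("selection_threshold must be in [0, 1]")
        # Assumed rule: allowed operation ids must be below max_operations.
        if self.allowed_operations is not None:
            bad = [op for op in self.allowed_operations
                   if not 0 <= op < self.max_operations]
            if bad:
                errors.append(f"allowed_operations out of range: {bad}")
        return tuple(errors)


ok_errors = ActionConfigSketch().validate()
bad_errors = ActionConfigSketch(
    allowed_operations=(1, 40), selection_threshold=1.5
).validate()
```

Returning all errors at once lets a caller report every misconfigured field in a single pass; an empty tuple means the configuration passed.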
class jaxarc.configs.DatasetConfig(dataset_name: str = 'arc-agi-1', dataset_path: str = '', dataset_repo: str = '', parser_entry_point: str = 'jaxarc.parsers:ArcAgiParser', expected_subdirs: tuple[str, ...] = ('data',), max_grid_height: int = 30, max_grid_width: int = 30, min_grid_height: int = 3, min_grid_width: int = 3, max_colors: int = 10, background_color: int = -1, max_train_pairs: int = 10, max_test_pairs: int = 3, task_split: str = 'train', shuffle_tasks: bool = True)

Bases: Module

Dataset-specific settings and constraints.

This config contains all dataset-related settings including grid constraints, color limits, task sampling, and dataset identification.

background_color: int = -1
dataset_name: str = 'arc-agi-1'
dataset_path: str = ''
dataset_repo: str = ''
expected_subdirs: tuple[str, ...] = ('data',)
classmethod from_hydra(cfg: DictConfig) -> DatasetConfig

Create dataset config from Hydra DictConfig.

max_colors: int = 10
max_grid_height: int = 30
max_grid_width: int = 30
max_test_pairs: int = 3
max_train_pairs: int = 10
min_grid_height: int = 3
min_grid_width: int = 3
parser_entry_point: str = 'jaxarc.parsers:ArcAgiParser'
shuffle_tasks: bool = True
task_split: str = 'train'
validate() -> tuple[str, ...]

Validate dataset configuration and return tuple of errors.
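
To illustrate what a from_hydra-style constructor has to do, the sketch below builds a frozen config from a plain mapping. A dict stands in for Hydra's DictConfig, and the names DatasetConfigSketch and from_mapping are hypothetical. Ignoring unknown keys and converting lists to tuples are assumptions about reasonable behaviour, not a description of JaxARC's implementation.

```python
from dataclasses import dataclass, fields
from typing import Any, Mapping, Tuple


# Stand-in for illustration only; not JaxARC's actual DatasetConfig.
@dataclass(frozen=True)
class DatasetConfigSketch:
    dataset_name: str = "arc-agi-1"
    expected_subdirs: Tuple[str, ...] = ("data",)
    max_grid_height: int = 30
    max_grid_width: int = 30

    @classmethod
    def from_mapping(cls, cfg: Mapping[str, Any]) -> "DatasetConfigSketch":
        # Keep only keys that match declared fields (assumed policy);
        # anything missing falls back to the field default.
        known = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in cfg.items() if k in known}
        # YAML sequences load as lists; convert to a tuple so the
        # frozen config stays hashable.
        if "expected_subdirs" in kwargs:
            kwargs["expected_subdirs"] = tuple(kwargs["expected_subdirs"])
        return cls(**kwargs)


cfg = DatasetConfigSketch.from_mapping(
    {"dataset_name": "mini_arc", "expected_subdirs": ["data", "extra"], "unknown_key": 1}
)
```

Keeping sequence fields as tuples matters for configs that are hashed or treated as static values, which is a common reason the fields on this page are typed tuple[str, ...] rather than list.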

class jaxarc.configs.EnvironmentConfig(max_episode_steps: int = 100, debug_level: Literal['off', 'minimal', 'verbose'] = 'minimal', render_mode: Literal['rgb_array', 'ansi', 'svg'] = 'rgb_array')

Bases: Module

Core environment behavior and runtime settings.

This config only contains settings that directly affect environment behavior, not dataset constraints, logging, visualization, or storage settings.

debug_level: Literal['off', 'minimal', 'verbose'] = 'minimal'
classmethod from_hydra(cfg: DictConfig) -> EnvironmentConfig

Create environment config from Hydra DictConfig.

max_episode_steps: int = 100
render_mode: Literal['rgb_array', 'ansi', 'svg'] = 'rgb_array'
validate() -> tuple[str, ...]

Validate environment configuration and return tuple of errors.

class jaxarc.configs.GridInitializationConfig(**kwargs)

Bases: Module

Configuration for grid initialization strategies.

This config controls how working grids are initialized in the environment. Supports four modes for research flexibility:

  • Demo mode: Copy from training examples

  • Permutation mode: Apply transformations to demo grids

  • Empty mode: Start with blank grids

  • Random mode: Generate random patterns

demo_weight: float = 0.4
empty_weight: float = 0.2
classmethod from_hydra(cfg: DictConfig) -> GridInitializationConfig

Create grid initialization config from Hydra DictConfig.

permutation_types: tuple[str, ...] = ('rotate', 'reflect', 'color_remap')
permutation_weight: float = 0.3
random_density: float = 0.3
random_pattern_type: str = 'sparse'
random_weight: float = 0.1
validate() -> tuple[str, ...]

Validate grid initialization configuration.
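
The four default mode weights (0.4, 0.3, 0.2, 0.1) sum to 1.0, which suggests they act as sampling probabilities. The sketch below shows one way such weights could be checked and used to pick a mode. The helper names check_weights and sample_mode are hypothetical, and both the sum-to-one rule and the sampling scheme are assumptions, not JaxARC's documented behaviour.

```python
import random

# Default weights as listed above, keyed by mode name.
weights = {"demo": 0.4, "permutation": 0.3, "empty": 0.2, "random": 0.1}


def check_weights(w, tol=1e-6):
    """Return a tuple of error strings; empty means the weights look valid."""
    errors = []
    if any(v < 0 for v in w.values()):
        errors.append("weights must be non-negative")
    # Compare with a tolerance because float sums are rarely exact.
    if abs(sum(w.values()) - 1.0) > tol:
        errors.append("weights must sum to 1.0")
    return tuple(errors)


def sample_mode(w, rng):
    """Draw one mode name with probability proportional to its weight."""
    modes = list(w)
    return rng.choices(modes, weights=[w[m] for m in modes], k=1)[0]


rng = random.Random(0)  # fixed seed so the run is repeatable
counts = {m: 0 for m in weights}
for _ in range(10_000):
    counts[sample_mode(weights, rng)] += 1
```

Inside a jitted JAX environment the draw would more likely use jax.random with an explicit key; the standard-library version here only illustrates the weighting.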

class jaxarc.configs.JaxArcConfig(environment: EnvironmentConfig | None = None, dataset: DatasetConfig | None = None, action: ActionConfig | None = None, reward: RewardConfig | None = None, grid_initialization: GridInitializationConfig | None = None, visualization: VisualizationConfig | None = None, storage: StorageConfig | None = None, logging: LoggingConfig | None = None, wandb: WandbConfig | None = None)

Bases: Module

Unified configuration for JaxARC using Equinox.

Main container that unifies all configuration aspects.

action: ActionConfig
dataset: DatasetConfig
environment: EnvironmentConfig
classmethod from_hydra(hydra_config: DictConfig) -> JaxArcConfig
grid_initialization: GridInitializationConfig
logging: LoggingConfig
reward: RewardConfig
storage: StorageConfig
to_yaml() -> str
to_yaml_file(yaml_path: str | Path) -> None
validate() -> tuple[str, ...]

Validate all components and cross-config consistency.

visualization: VisualizationConfig
wandb: WandbConfig
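
JaxArcConfig.validate() is described as validating all components plus cross-config consistency. A minimal sketch of how a container might aggregate component errors follows. All class names are stand-ins, and the cross-config rule shown (log frequency must not exceed episode length) is a hypothetical example, not a rule taken from JaxARC.

```python
from dataclasses import dataclass, field
from typing import Tuple


# Stand-ins for illustration only; not JaxARC's actual classes.
@dataclass(frozen=True)
class EnvSketch:
    max_episode_steps: int = 100

    def validate(self) -> Tuple[str, ...]:
        if self.max_episode_steps > 0:
            return ()
        return ("max_episode_steps must be positive",)


@dataclass(frozen=True)
class LoggingSketch:
    log_frequency: int = 10

    def validate(self) -> Tuple[str, ...]:
        if self.log_frequency > 0:
            return ()
        return ("log_frequency must be positive",)


@dataclass(frozen=True)
class UnifiedSketch:
    environment: EnvSketch = field(default_factory=EnvSketch)
    logging: LoggingSketch = field(default_factory=LoggingSketch)

    def validate(self) -> Tuple[str, ...]:
        errors = []
        # Collect component errors, prefixed with the component name.
        for name in ("environment", "logging"):
            component = getattr(self, name)
            errors.extend(f"{name}: {msg}" for msg in component.validate())
        # Hypothetical cross-config rule spanning two components.
        if self.logging.log_frequency > self.environment.max_episode_steps:
            errors.append(
                "logging.log_frequency exceeds environment.max_episode_steps"
            )
        return tuple(errors)


default_errors = UnifiedSketch().validate()
short_errors = UnifiedSketch(environment=EnvSketch(max_episode_steps=5)).validate()
```

Prefixing each message with its component name keeps the combined error tuple readable when several sub-configs fail at once.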
class jaxarc.configs.LoggingConfig(log_operations: bool = False, log_rewards: bool = False, log_frequency: int = 10, batched_logging_enabled: bool = False, log_format: str = 'text', log_level: str = 'INFO', structured_logging: bool = False)

Bases: Module

All logging behavior and formats.

This config contains everything related to logging: what to log, how to format it, where to write it, and performance settings.

batched_logging_enabled: bool = False
classmethod from_hydra(cfg: DictConfig) -> LoggingConfig

Create logging config from Hydra DictConfig.

log_format: str = 'text'
log_frequency: int = 10
log_level: str = 'INFO'
log_operations: bool = False
log_rewards: bool = False
structured_logging: bool = False
validate() -> tuple[str, ...]

Validate logging configuration and return tuple of errors.

class jaxarc.configs.RewardConfig(step_penalty: float = -0.01, success_bonus: float = 10.0, similarity_weight: float = 1.0, unsolved_submission_penalty: float = 0.0)

Bases: Module

Configuration for reward calculation.

This config contains all settings related to reward computation, penalties, bonuses, and reward shaping with mode-aware enhancements.

classmethod from_hydra(cfg: DictConfig) -> RewardConfig

Create reward config from Hydra DictConfig.

similarity_weight: float = 1.0
step_penalty: float = -0.01
success_bonus: float = 10.0
unsolved_submission_penalty: float = 0.0
validate() -> tuple[str, ...]

Validate reward configuration and return tuple of errors.
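
This page does not state how the four reward fields combine. As one plausible reading, the sketch below adds a per-step penalty, a weighted similarity improvement, and a bonus or penalty on submission. The function step_reward, its arguments, and the formula itself are assumptions for illustration and may differ from JaxARC's reward computation.

```python
from dataclasses import dataclass


# Stand-in for illustration only; defaults mirror the fields listed above.
@dataclass(frozen=True)
class RewardConfigSketch:
    step_penalty: float = -0.01
    success_bonus: float = 10.0
    similarity_weight: float = 1.0
    unsolved_submission_penalty: float = 0.0


def step_reward(cfg, sim_before, sim_after, submitted):
    """Hypothetical per-step reward; similarities are fractions in [0, 1]."""
    solved = sim_after >= 1.0
    # Every step pays the penalty and earns the weighted similarity gain.
    reward = cfg.step_penalty + cfg.similarity_weight * (sim_after - sim_before)
    if submitted and solved:
        reward += cfg.success_bonus
    elif submitted:
        reward += cfg.unsolved_submission_penalty
    return reward


cfg = RewardConfigSketch()
progress = step_reward(cfg, 0.5, 0.75, submitted=False)  # partial improvement
solved = step_reward(cfg, 0.75, 1.0, submitted=True)     # correct submission
```

Under this reading, a negative step_penalty discourages long episodes while similarity_weight scales the shaping signal relative to success_bonus.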

class jaxarc.configs.StorageConfig(base_output_dir: str = 'outputs', run_name: str | None = None, episodes_dir: str = 'episodes', debug_dir: str = 'debug', visualization_dir: str = 'visualizations', logs_dir: str = 'logs', max_episodes_per_run: int = 100, max_storage_gb: float = 5.0, cleanup_policy: str = 'size_based', create_run_subdirs: bool = True, clear_output_on_start: bool = True)

Bases: Module

All storage, output, and file management settings.

This config contains everything related to file storage, output directories, cleanup policies, and file organization. All output paths are managed here.

base_output_dir: str = 'outputs'
cleanup_policy: str = 'size_based'
clear_output_on_start: bool = True
create_run_subdirs: bool = True
debug_dir: str = 'debug'
episodes_dir: str = 'episodes'
classmethod from_hydra(cfg: DictConfig) -> StorageConfig

Create storage config from Hydra DictConfig.

logs_dir: str = 'logs'
max_episodes_per_run: int = 100
max_storage_gb: float = 5.0
run_name: str | None = None
validate() -> tuple[str, ...]

Validate storage configuration and return tuple of errors.

visualization_dir: str = 'visualizations'
class jaxarc.configs.VisualizationConfig(**kwargs)

Bases: Module

All visualization and rendering settings.

This config contains everything related to visual output, rendering, and visualization behavior. No logging or storage settings here.

enabled: bool = True
episode_summaries: bool = True
classmethod from_hydra(cfg: DictConfig) -> VisualizationConfig

Create visualization config from Hydra DictConfig.

step_visualizations: bool = True
validate() -> tuple[str, ...]

Validate visualization configuration and return tuple of errors.

class jaxarc.configs.WandbConfig(**kwargs)

Bases: Module

Weights & Biases integration settings.

This config contains everything related to W&B logging and tracking. No local logging or storage settings here.

enabled: bool = False
entity: str | None = None
classmethod from_hydra(cfg: DictConfig) -> WandbConfig

Create wandb config from Hydra DictConfig.

group: str | None = None
job_type: str = 'training'
notes: str = 'JaxARC experiment'
offline_mode: bool = False
project_name: str = 'jaxarc-experiments'
save_code: bool = True
tags: tuple[str, ...] = ('jaxarc',)
validate() -> tuple[str, ...]

Validate wandb configuration and return tuple of errors.

Usage Examples

From Python

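from jaxarc import JaxArcConfig, make
from jaxarc.configs import DatasetConfig, EnvironmentConfig

# Default configuration
config = JaxArcConfig()
env, env_params = make("Mini", config=config)

# Custom configuration: JaxArcConfig takes sub-configs, so settings are
# passed through EnvironmentConfig and DatasetConfig rather than as
# top-level keyword arguments. Grid limits are DatasetConfig fields.
config = JaxArcConfig(
    environment=EnvironmentConfig(max_episode_steps=1000),
    dataset=DatasetConfig(max_grid_height=32, max_grid_width=32),
)
env, env_params = make("Mini", config=config)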

From Hydra Config

from jaxarc import JaxArcConfig
from jaxarc.utils.core import get_config

# Load from YAML with overrides
hydra_config = get_config(
    overrides=["dataset=mini_arc", "action=point", "grid_size=32"]
)

config = JaxArcConfig.from_hydra(hydra_config)

From YAML File

# config.yaml
dataset: mini_arc
action: point
grid_size: 32
max_episode_steps: 1000

from hydra import compose, initialize
from jaxarc import JaxArcConfig

with initialize(config_path=".", version_base=None):
    cfg = compose(config_name="config")
    config = JaxArcConfig.from_hydra(cfg)