Source code for jaxarc.configs.main_config

from __future__ import annotations

from pathlib import Path
from typing import Any

import equinox as eqx
import yaml
from loguru import logger
from omegaconf import DictConfig, OmegaConf

# Import canonical constants to avoid magic numbers
from jaxarc.constants import NUM_COLORS

from .action_config import ActionConfig
from .dataset_config import DatasetConfig
from .environment_config import EnvironmentConfig
from .grid_initialization_config import GridInitializationConfig
from .logging_config import LoggingConfig
from .reward_config import RewardConfig
from .storage_config import StorageConfig
from .validation import ConfigValidationError, check_hashable
from .visualization_config import VisualizationConfig
from .wandb_config import WandbConfig


class JaxArcConfig(eqx.Module):
    """Unified configuration for JaxARC using Equinox.

    Main container that unifies all configuration aspects. Each field holds
    one sub-configuration object; any sub-configuration that is not supplied
    to the constructor is replaced by a default instance of its class.
    """

    environment: EnvironmentConfig
    dataset: DatasetConfig
    action: ActionConfig
    reward: RewardConfig
    grid_initialization: GridInitializationConfig
    visualization: VisualizationConfig
    storage: StorageConfig
    logging: LoggingConfig
    wandb: WandbConfig

    def __init__(
        self,
        environment: EnvironmentConfig | None = None,
        dataset: DatasetConfig | None = None,
        action: ActionConfig | None = None,
        reward: RewardConfig | None = None,
        grid_initialization: GridInitializationConfig | None = None,
        visualization: VisualizationConfig | None = None,
        storage: StorageConfig | None = None,
        logging: LoggingConfig | None = None,
        wandb: WandbConfig | None = None,
    ):
        """Build the configuration, defaulting every omitted sub-config.

        Each argument is stored as-is when provided. When an argument is
        ``None`` a default is created: most classes are instantiated with no
        arguments, while the visualization and wandb defaults are built via
        ``from_hydra`` on an empty ``DictConfig``.
        """
        # Explicit ``is None`` checks (rather than ``x or Default()``) so that
        # a caller-supplied config is never discarded just because it happens
        # to evaluate as falsy.
        self.environment = (
            EnvironmentConfig() if environment is None else environment
        )
        self.dataset = DatasetConfig() if dataset is None else dataset
        self.action = ActionConfig() if action is None else action
        self.reward = RewardConfig() if reward is None else reward
        self.grid_initialization = (
            GridInitializationConfig()
            if grid_initialization is None
            else grid_initialization
        )
        self.visualization = (
            VisualizationConfig.from_hydra(DictConfig({}))
            if visualization is None
            else visualization
        )
        self.storage = StorageConfig() if storage is None else storage
        self.logging = LoggingConfig() if logging is None else logging
        self.wandb = (
            WandbConfig.from_hydra(DictConfig({})) if wandb is None else wandb
        )

    def __check_init__(self):
        """Run ``check_hashable`` on the freshly initialised instance."""
        # NOTE(review): presumably raises when a field is not hashable -
        # confirm against ``validation.check_hashable``.
        check_hashable(self, "JaxArcConfig")

    # Sub-config field names for iteration
    _SUB_CONFIGS = (
        "environment",
        "dataset",
        "action",
        "reward",
        "grid_initialization",
        "visualization",
        "storage",
        "logging",
        "wandb",
    )

    def validate(self) -> tuple[str, ...]:
        """Validate all components and cross-config consistency.

        Returns:
            A tuple with the error messages reported by every sub-config's
            ``validate()`` (in ``_SUB_CONFIGS`` order) followed by the
            cross-config errors. Empty when nothing was reported.
        """
        all_errors: list[str] = []
        for name in self._SUB_CONFIGS:
            all_errors.extend(getattr(self, name).validate())
        all_errors.extend(self._validate_cross_config_consistency())
        return tuple(all_errors)

    def _validate_cross_config_consistency(self) -> tuple[str, ...]:
        """Run the checks that span several sub-configs.

        Warnings are emitted through the logger; only errors are returned.
        A ``ValueError``, ``TypeError`` or ``ConfigValidationError`` raised
        by a check is converted into an error message instead of propagating.

        Returns:
            A tuple of error messages (possibly empty).
        """
        errors: list[str] = []
        warnings: list[str] = []
        try:
            self._validate_debug_level_consistency(warnings)
            self._validate_wandb_consistency(errors, warnings)
            self._validate_action_environment_consistency(warnings)
            self._validate_reward_consistency(warnings)
            self._validate_dataset_consistency(warnings)
            self._validate_logging_consistency(warnings)
        except (ValueError, TypeError, ConfigValidationError) as e:
            errors.append(f"Cross-configuration validation error: {e}")
        finally:
            # Log in ``finally`` so warnings gathered before a failing check
            # are still reported instead of being silently dropped.
            for warning in warnings:
                logger.warning(warning)
        return tuple(errors)

    def _validate_debug_level_consistency(self, warnings: list[str]) -> None:
        """Warn when debug level is ``"off"`` but costly features are on.

        Appends to ``warnings`` in place.
        """
        if self.environment.debug_level != "off":
            return
        if self.visualization.enabled:
            warnings.append(
                "Debug level is 'off' but visualization is enabled - consider disabling visualization for better performance"
            )
        if self.logging.log_operations or self.logging.log_rewards:
            warnings.append(
                "Debug level is 'off' but detailed logging is enabled - consider reducing log level"
            )

    def _validate_wandb_consistency(
        self, errors: list[str], warnings: list[str]
    ) -> None:
        """Check the wandb settings when wandb is enabled.

        Appends an error for a blank project name and a warning when the
        logging level is ``"ERROR"``. Both lists are mutated in place.
        """
        if not self.wandb.enabled:
            return
        # Treat a missing (None) project name like an empty one so it is
        # reported as a validation error rather than raising AttributeError.
        if not (self.wandb.project_name or "").strip():
            errors.append("WandB enabled but project_name is empty")
        if getattr(self.logging, "log_level", "INFO") == "ERROR":
            warnings.append(
                "WandB enabled but log level is ERROR - may miss important metrics"
            )

    def _validate_action_environment_consistency(self, warnings: list[str]) -> None:
        """Warn when more than 50 operations meet fewer than 20 episode steps."""
        if self.action.max_operations > 50 and self.environment.max_episode_steps < 20:
            warnings.append(
                "Many operations available but few episode steps - may not explore action space effectively"
            )

    def _validate_reward_consistency(self, warnings: list[str]) -> None:
        """Warn when the worst-case total step penalty outweighs the bonus.

        Compares ``|step_penalty| * max_episode_steps`` with
        ``|success_bonus|``.
        """
        max_total_penalty = (
            abs(self.reward.step_penalty) * self.environment.max_episode_steps
        )
        if max_total_penalty > abs(self.reward.success_bonus):
            warnings.append(
                "Cumulative step penalties may exceed success bonus - consider adjusting reward balance"
            )

    def _validate_dataset_consistency(self, warnings: list[str]) -> None:
        """Warn about dataset limits that clash with episode or action limits."""
        max_grid_area = self.dataset.max_grid_height * self.dataset.max_grid_width
        if max_grid_area > 400 and self.environment.max_episode_steps < 100:
            warnings.append(
                "Large grids with short episodes may not provide enough time for complex tasks"
            )

        if self.dataset.max_colors > NUM_COLORS and self.action.allowed_operations:
            # Operation ids 0..9 are counted as fill operations here.
            # NOTE(review): presumably one fill operation per colour - confirm
            # against the action definitions.
            fill_ops = [op for op in self.action.allowed_operations if 0 <= op <= 9]
            if len(fill_ops) < self.dataset.max_colors:
                warnings.append(
                    f"Dataset allows {self.dataset.max_colors} colors but only {len(fill_ops)} fill operations available"
                )

    def _validate_logging_consistency(self, warnings: list[str]) -> None:
        """Warn about logging options that work against each other."""
        # ``structured_logging`` is read defensively: it counts as disabled
        # when the logging config does not define it.
        structured = getattr(self.logging, "structured_logging", False)
        if structured and self.logging.log_format not in ["json", "structured"]:
            warnings.append(
                f"Structured logging enabled but format is '{self.logging.log_format}' - consider using 'json' or 'structured'"
            )

        detailed_logging = self.logging.log_operations or self.logging.log_rewards
        if detailed_logging and self.logging.log_level in ["ERROR", "WARNING"]:
            warnings.append(
                "Detailed content logging enabled but log level may suppress the logs"
            )

    def to_yaml(self) -> str:
        """Serialise the whole configuration to a YAML string.

        The top-level keys are the sub-config names in ``_SUB_CONFIGS`` order
        (key sorting is disabled).

        Raises:
            ConfigValidationError: If anything fails during export; the
                original exception is chained as the cause.
        """
        try:
            config_dict = {
                name: self._config_to_dict(getattr(self, name))
                for name in self._SUB_CONFIGS
            }
            return yaml.dump(
                config_dict,
                default_flow_style=False,
                sort_keys=False,
                allow_unicode=True,
                encoding=None,
            )
        except Exception as e:
            msg = f"Failed to export configuration to YAML: {e}"
            raise ConfigValidationError(msg) from e

    def to_yaml_file(self, yaml_path: str | Path) -> None:
        """Write the YAML export to ``yaml_path`` as UTF-8.

        Missing parent directories are created first.

        Raises:
            ConfigValidationError: If exporting or writing fails; the
                original exception is chained as the cause.
        """
        try:
            yaml_path = Path(yaml_path)
            yaml_path.parent.mkdir(parents=True, exist_ok=True)
            yaml_content = self.to_yaml()
            with yaml_path.open("w", encoding="utf-8") as f:
                f.write(yaml_content)
        except Exception as e:
            msg = f"Failed to save configuration to YAML file: {e}"
            raise ConfigValidationError(msg) from e

    def _config_to_dict(self, config: eqx.Module) -> dict[str, Any]:
        """Convert one sub-config into a plain dict of serialised fields.

        Field names come from the config's ``__annotations__``; names that
        are not present as attributes on the instance are skipped.
        """
        return {
            name: self._serialize_value(getattr(config, name))
            for name in getattr(config, "__annotations__", {})
            if hasattr(config, name)
        }

    def _serialize_value(self, value: Any) -> Any:
        """Recursively convert ``value`` into YAML-friendly data.

        ``None`` and ``str``/``int``/``float``/``bool`` pass through, lists
        and tuples become lists, dict values are converted recursively (keys
        are kept unchanged), and everything else falls back to ``str(value)``.
        """
        if value is None:
            return None
        if isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if hasattr(value, "_content"):
            # Objects exposing ``_content`` are handed to OmegaConf for
            # conversion; if that fails, the string fallback below applies.
            try:
                return OmegaConf.to_container(value, resolve=True)
            except (AttributeError, ValueError, TypeError):
                pass
        return str(value)

    @classmethod
    def from_hydra(cls, hydra_config: DictConfig) -> JaxArcConfig:
        """Create a configuration from a Hydra/OmegaConf ``DictConfig``.

        Each sub-config is built by its class's ``from_hydra`` from the
        section of the same name; a missing section is replaced by an empty
        ``DictConfig``.

        Raises:
            ConfigValidationError: Re-raised unchanged when a sub-config
                raises it; any other exception is wrapped in one and chained.
        """
        try:
            cfg_classes = {
                "environment": EnvironmentConfig,
                "dataset": DatasetConfig,
                "action": ActionConfig,
                "reward": RewardConfig,
                "grid_initialization": GridInitializationConfig,
                "visualization": VisualizationConfig,
                "storage": StorageConfig,
                "logging": LoggingConfig,
                "wandb": WandbConfig,
            }
            sub_cfgs = {
                name: cfg_cls.from_hydra(hydra_config.get(name, DictConfig({})))
                for name, cfg_cls in cfg_classes.items()
            }
            return cls(**sub_cfgs)
        except ConfigValidationError:
            raise
        except Exception as e:
            msg = f"Failed to create configuration from Hydra: {e}"
            raise ConfigValidationError(msg) from e