Source code for jaxarc.configs.environment_config

from __future__ import annotations

from typing import Literal

import equinox as eqx
from loguru import logger
from omegaconf import DictConfig

from .validation import (
    ConfigValidationError,
    check_hashable,
    validate_positive_int,
    validate_string_choice,
)


[docs] class EnvironmentConfig(eqx.Module): """Core environment behavior and runtime settings. This config only contains settings that directly affect environment behavior, not dataset constraints, logging, visualization, or storage settings. """ # Episode settings max_episode_steps: int = 100 # Debug level (simplified: off|minimal|verbose) debug_level: Literal["off", "minimal", "verbose"] = "minimal" # Render mode (rgb_array|ansi|svg) render_mode: Literal["rgb_array", "ansi", "svg"] = "rgb_array"
[docs] def validate(self) -> tuple[str, ...]: """Validate environment configuration and return tuple of errors.""" errors: list[str] = [] try: # Validate episode settings validate_positive_int(self.max_episode_steps, "max_episode_steps") if self.max_episode_steps > 10000: logger.warning( f"max_episode_steps is very large: {self.max_episode_steps}" ) # Validate debug level valid_levels = ("off", "minimal", "verbose") validate_string_choice(self.debug_level, "debug_level", valid_levels) # Validate render mode valid_render_modes = ("rgb_array", "ansi", "svg") validate_string_choice(self.render_mode, "render_mode", valid_render_modes) except ConfigValidationError as e: errors.append(str(e)) return tuple(errors)
def __check_init__(self): check_hashable(self, "EnvironmentConfig")
[docs] @classmethod def from_hydra(cls, cfg: DictConfig) -> EnvironmentConfig: """Create environment config from Hydra DictConfig.""" return cls( max_episode_steps=cfg.get("max_episode_steps", 100), debug_level=cfg.get("debug_level", "minimal"), render_mode=cfg.get("render_mode", "rgb_array"), )