"""Grid initialization configuration (``jaxarc.configs.grid_initialization_config``)."""

from __future__ import annotations

import equinox as eqx
from omegaconf import DictConfig

from .validation import (
    ConfigValidationError,
    check_hashable,
    ensure_tuple,
    validate_float_range,
    validate_tuple_elements,
)


class GridInitializationConfig(eqx.Module):
    """Configuration for grid initialization strategies.

    This config controls how working grids are initialized in the environment.
    Supports four modes for research flexibility:

    - Demo mode: Copy from training examples
    - Permutation mode: Apply transformations to demo grids
    - Empty mode: Start with blank grids
    - Random mode: Generate random patterns
    """

    # Mode weights (normalized automatically, don't need to sum to 1.0).
    # ``validate`` requires each to be non-negative and their sum to be positive.
    demo_weight: float = 0.4
    permutation_weight: float = 0.3
    empty_weight: float = 0.2
    random_weight: float = 0.1

    # Permutation configuration (simplified). ``validate`` accepts only
    # "rotate", "reflect" and "color_remap" as entries.
    permutation_types: tuple[str, ...] = ("rotate", "reflect", "color_remap")

    # Random pattern configuration (simplified)
    random_density: float = 0.3  # ``validate`` requires a value in [0.0, 1.0]
    random_pattern_type: str = "sparse"  # "sparse" or "dense"

    def __init__(self, **kwargs):
        """Build the config from keyword arguments, falling back to defaults.

        Args:
            **kwargs: Any subset of the field names declared above. Missing
                keys take the same defaults as the field declarations.
                ``permutation_types`` is passed through ``ensure_tuple``
                (presumably normalizing list-like input such as a Hydra
                ListConfig to a tuple).

        Note:
            Unrecognized keyword arguments are silently ignored.
            NOTE(review): consider rejecting them so that typos surface.
        """
        self.demo_weight = kwargs.get("demo_weight", 0.4)
        self.permutation_weight = kwargs.get("permutation_weight", 0.3)
        self.empty_weight = kwargs.get("empty_weight", 0.2)
        self.random_weight = kwargs.get("random_weight", 0.1)
        self.permutation_types = ensure_tuple(
            kwargs.get("permutation_types", ("rotate", "reflect", "color_remap")),
            default=("rotate", "reflect", "color_remap"),
        )
        self.random_density = kwargs.get("random_density", 0.3)
        self.random_pattern_type = kwargs.get("random_pattern_type", "sparse")

    def validate(self) -> tuple[str, ...]:
        """Validate grid initialization configuration.

        Every check runs independently, so one invalid field does not hide
        problems in the fields checked after it; all messages are collected.

        Returns:
            Tuple of error messages; empty when the configuration is valid.
        """
        errors: list[str] = []

        def record(check, *args, **kwargs) -> bool:
            # Run one validator that signals failure by raising
            # ConfigValidationError; store its message instead of aborting.
            try:
                check(*args, **kwargs)
            except ConfigValidationError as e:
                errors.append(str(e))
                return False
            return True

        # Validate weights (they will be normalized, so just need to be non-negative)
        unbounded = float("inf")
        weight_ok: dict[str, bool] = {}
        for name in (
            "demo_weight",
            "permutation_weight",
            "empty_weight",
            "random_weight",
        ):
            weight_ok[name] = record(
                validate_float_range, getattr(self, name), name, 0.0, unbounded
            )

        # At least one weight must be positive. Only sum once every weight
        # passed its own range check, so a rejected (possibly non-numeric)
        # value cannot make the arithmetic itself raise.
        if all(weight_ok.values()):
            total_weight = (
                self.demo_weight
                + self.permutation_weight
                + self.empty_weight
                + self.random_weight
            )
            if total_weight <= 0:
                errors.append("At least one initialization weight must be positive")

        # Validate random configuration
        record(validate_float_range, self.random_density, "random_density", 0.0, 1.0)
        if self.random_pattern_type not in ("sparse", "dense"):
            errors.append(
                f"Invalid random_pattern_type: {self.random_pattern_type}. Must be 'sparse' or 'dense'"
            )

        # Validate permutation types; validate_tuple_elements returns its
        # messages rather than raising, but guard it the same way regardless.
        valid_perm_types = {"rotate", "reflect", "color_remap"}
        if hasattr(self.permutation_types, "__iter__"):
            try:
                errors.extend(
                    validate_tuple_elements(
                        self.permutation_types,
                        "permutation_types",
                        element_type=str,
                        allowed=valid_perm_types,
                    )
                )
            except ConfigValidationError as e:
                errors.append(str(e))

        # If permutation weight is positive, require non-empty permutation_types
        # (only meaningful once permutation_weight itself passed validation).
        if (
            weight_ok["permutation_weight"]
            and self.permutation_weight > 0.0
            and not self.permutation_types
        ):
            errors.append(
                "permutation_types cannot be empty when permutation_weight > 0"
            )

        return tuple(errors)

    def __check_init__(self):
        """Equinox post-initialization hook.

        Delegates to ``check_hashable`` (presumably raising if the instance
        cannot be hashed).
        """
        check_hashable(self, "GridInitializationConfig")

    @classmethod
    def from_hydra(cls, cfg: DictConfig) -> GridInitializationConfig:
        """Create grid initialization config from Hydra DictConfig.

        Only keys present in ``cfg`` are forwarded; absent keys fall back to
        the defaults applied by ``__init__``, so defaults live in one place.

        Args:
            cfg: Hydra/OmegaConf mapping holding any subset of this config's
                field names.

        Returns:
            A new ``GridInitializationConfig``.
        """
        # Sentinel distinguishes "key absent" from an explicit value (even None).
        absent = object()
        overrides = {}
        for name in (
            "demo_weight",
            "permutation_weight",
            "empty_weight",
            "random_weight",
            "permutation_types",
            "random_density",
            "random_pattern_type",
        ):
            value = cfg.get(name, absent)
            if value is not absent:
                overrides[name] = value
        return cls(**overrides)