"""Weights & Biases configuration for JaxARC (``jaxarc.configs.wandb_config``)."""

from __future__ import annotations

import equinox as eqx
from omegaconf import DictConfig

from .validation import ConfigValidationError, check_hashable, ensure_tuple


class WandbConfig(eqx.Module):
    """Weights & Biases integration settings.

    This config contains everything related to W&B logging and tracking.
    No local logging or storage settings here.

    Build it either directly from keyword arguments (list-valued ``tags`` are
    converted to a tuple) or from a Hydra ``DictConfig`` via ``from_hydra``.
    """

    # Core wandb settings
    enabled: bool = False
    project_name: str = "jaxarc-experiments"
    entity: str | None = None
    tags: tuple[str, ...] = ("jaxarc",)
    notes: str = "JaxARC experiment"
    group: str | None = None
    job_type: str = "training"

    # Error handling
    offline_mode: bool = False

    # Storage
    save_code: bool = True

    def __init__(self, **kwargs):
        """Initialize with automatic list-to-tuple conversion.

        Each field is read from ``kwargs``, falling back to the same default
        as its class-level annotation. ``tags`` is passed through
        ``ensure_tuple`` (with ``("jaxarc",)`` as the fallback) so that
        list-valued input becomes a tuple.

        NOTE(review): keyword arguments other than the fields listed below are
        silently ignored, so a misspelled field name quietly falls back to its
        default.
        """
        self.enabled = kwargs.get("enabled", False)
        self.project_name = kwargs.get("project_name", "jaxarc-experiments")
        self.entity = kwargs.get("entity")
        self.tags = ensure_tuple(kwargs.get("tags", ("jaxarc",)), default=("jaxarc",))
        self.notes = kwargs.get("notes", "JaxARC experiment")
        self.group = kwargs.get("group")
        self.job_type = kwargs.get("job_type", "training")
        self.offline_mode = kwargs.get("offline_mode", False)
        self.save_code = kwargs.get("save_code", True)

    def validate(self) -> tuple[str, ...]:
        """Validate wandb configuration and return tuple of errors.

        Returns:
            A tuple of error messages; empty when the configuration is valid.
            A missing (``None``) or non-string ``project_name`` is reported in
            the returned tuple rather than raising.
        """
        errors: list[str] = []

        # ``project_name`` is not type-checked at construction, so it may be
        # ``None`` (e.g. ``project_name: null`` in a Hydra config) or another
        # non-string. Check the type before calling ``.strip()`` so the problem
        # is reported instead of crashing with AttributeError.
        name = self.project_name
        if name is None or (isinstance(name, str) and not name.strip()):
            errors.append("project_name cannot be empty")
        elif not isinstance(name, str):
            errors.append(f"project_name must be a string, got {type(name).__name__}")

        return tuple(errors)

    def __check_init__(self):
        """Post-construction check, invoked by equinox after ``__init__``.

        Delegates to ``check_hashable``; presumably this verifies that every
        field value is hashable (hence the tuple conversion of ``tags``).
        TODO(review): confirm against ``validation.check_hashable``.
        """
        check_hashable(self, "WandbConfig")

    @classmethod
    def from_hydra(cls, cfg: DictConfig) -> WandbConfig:
        """Create wandb config from Hydra DictConfig.

        Missing keys fall back to the same defaults used by ``__init__``.

        Args:
            cfg: Hydra/OmegaConf config node holding the wandb settings.

        Returns:
            A new ``WandbConfig`` built from ``cfg``.
        """
        return cls(
            # A list default is fine here: ``__init__`` converts it to a tuple.
            tags=cfg.get("tags", ["jaxarc"]),
            enabled=cfg.get("enabled", False),
            project_name=cfg.get("project_name", "jaxarc-experiments"),
            entity=cfg.get("entity"),
            notes=cfg.get("notes", "JaxARC experiment"),
            group=cfg.get("group"),
            job_type=cfg.get("job_type", "training"),
            offline_mode=cfg.get("offline_mode", False),
            save_code=cfg.get("save_code", True),
        )