from __future__ import annotations
import equinox as eqx
from loguru import logger
from omegaconf import DictConfig
# Import canonical constants to avoid magic numbers
from jaxarc.constants import MAX_GRID_SIZE, NUM_COLORS
from .validation import (
ConfigValidationError,
check_hashable,
validate_path_string,
validate_positive_int,
validate_string_choice,
)
# [docs]  (Sphinx "view source" link residue; not part of the code)
class DatasetConfig(eqx.Module):
    """Dataset-specific settings and constraints.

    This config contains all dataset-related settings including grid
    constraints, color limits, task sampling, and dataset identification.

    ``__check_init__`` passes every new instance to ``check_hashable``.
    Field values are only checked when :meth:`validate` is called.
    """

    # Dataset identification
    dataset_name: str = "arc-agi-1"
    dataset_path: str = ""
    dataset_repo: str = ""
    parser_entry_point: str = "jaxarc.parsers:ArcAgiParser"
    expected_subdirs: tuple[str, ...] = ("data",)

    # Dataset-specific grid constraints
    max_grid_height: int = MAX_GRID_SIZE
    max_grid_width: int = MAX_GRID_SIZE
    min_grid_height: int = 3
    min_grid_width: int = 3

    # Color constraints
    max_colors: int = NUM_COLORS
    # -1 is accepted as a padding value; values >= 0 must be < max_colors.
    background_color: int = -1

    # Task Configuration
    max_train_pairs: int = 10
    max_test_pairs: int = 3

    # Task sampling parameters
    task_split: str = "train"
    shuffle_tasks: bool = True

    def validate(self) -> tuple[str, ...]:
        """Validate dataset configuration and return tuple of errors.

        Problems detected inline are appended to the result. Any
        ``ConfigValidationError`` raised by one of the ``validate_*`` helpers
        is caught, recorded as an error message, and ends validation early
        (checks placed after the failing helper are skipped). Values that are
        legal but unusually large only produce a ``logger.warning``.

        Returns:
            A tuple of error messages; empty when no problem was found.
        """
        errors: list[str] = []
        # Remember the outcome of the type check so the cross-field comparison
        # further down is never attempted on a value that may not support
        # ordering against an int (e.g. a str or None), which would otherwise
        # escape as an uncaught TypeError.
        background_is_int = isinstance(self.background_color, int)

        try:
            # Validate dataset name
            if not self.dataset_name.strip():
                errors.append("dataset_name cannot be empty")

            # Validate dataset path
            validate_path_string(self.dataset_path, "dataset_path")

            # Validate grid dimensions
            validate_positive_int(self.max_grid_height, "max_grid_height")
            validate_positive_int(self.max_grid_width, "max_grid_width")
            validate_positive_int(self.min_grid_height, "min_grid_height")
            validate_positive_int(self.min_grid_width, "min_grid_width")

            # Validate task pair counts
            validate_positive_int(self.max_train_pairs, "max_train_pairs")
            validate_positive_int(self.max_test_pairs, "max_test_pairs")
            if self.max_train_pairs > 20:
                logger.warning(f"max_train_pairs is very large: {self.max_train_pairs}")
            if self.max_test_pairs > 5:
                logger.warning(f"max_test_pairs is very large: {self.max_test_pairs}")

            # Validate reasonable bounds
            if self.max_grid_height > 200:
                logger.warning(f"max_grid_height is very large: {self.max_grid_height}")
            if self.max_grid_width > 200:
                logger.warning(f"max_grid_width is very large: {self.max_grid_width}")

            # Validate color constraints
            validate_positive_int(self.max_colors, "max_colors")

            # Validate background_color: -1 is valid for padding, 0-9 are valid ARC colors
            if not background_is_int:
                errors.append(
                    f"background_color must be an integer, got {type(self.background_color).__name__}"
                )
            elif self.background_color < -1:
                errors.append(
                    f"background_color must be >= -1 (for padding) or a valid color index, got {self.background_color}"
                )

            if self.max_colors < 2:
                errors.append("max_colors must be at least 2")
            if self.max_colors > 50:
                logger.warning(f"max_colors is very large: {self.max_colors}")

            # Validate task split
            valid_splits = (
                "train",
                "eval",
                "test",
                "all",
                "training",
                "evaluation",
                "corpus",
            )
            validate_string_choice(self.task_split, "task_split", valid_splits)

            # Cross-field validation
            if self.max_grid_height < self.min_grid_height:
                errors.append(
                    f"max_grid_height ({self.max_grid_height}) < min_grid_height ({self.min_grid_height})"
                )
            if self.max_grid_width < self.min_grid_width:
                errors.append(
                    f"max_grid_width ({self.max_grid_width}) < min_grid_width ({self.min_grid_width})"
                )

            # Validate background_color against max_colors (but allow -1 for padding).
            # Skipped for non-integer values: the type error is already recorded above.
            if (
                background_is_int
                and self.background_color >= 0
                and self.background_color >= self.max_colors
            ):
                errors.append(
                    f"background_color ({self.background_color}) must be < max_colors ({self.max_colors}) when >= 0"
                )
        except ConfigValidationError as e:
            errors.append(str(e))

        return tuple(errors)

    def __check_init__(self) -> None:
        """Run ``check_hashable`` on the freshly constructed instance."""
        check_hashable(self, "DatasetConfig")

    @classmethod
    def from_hydra(cls, cfg: DictConfig) -> DatasetConfig:
        """Create dataset config from Hydra DictConfig.

        Every field is read with ``cfg.get`` and falls back to the same
        default as the corresponding class attribute when the key is absent.

        Args:
            cfg: Config node holding the dataset settings.

        Returns:
            A new ``DatasetConfig`` populated from ``cfg``.
        """
        subdirs = cfg.get("expected_subdirs", ["data"])
        if isinstance(subdirs, str):
            # A bare string would be split into single characters by tuple();
            # treat it as one sub-directory name instead.
            subdirs = (subdirs,)

        return cls(
            dataset_name=cfg.get("dataset_name", "arc-agi-1"),
            dataset_path=cfg.get("dataset_path", ""),
            dataset_repo=cfg.get("dataset_repo", ""),
            parser_entry_point=cfg.get(
                "parser_entry_point", "jaxarc.parsers:ArcAgiParser"
            ),
            expected_subdirs=tuple(subdirs),
            max_grid_height=cfg.get("max_grid_height", MAX_GRID_SIZE),
            max_grid_width=cfg.get("max_grid_width", MAX_GRID_SIZE),
            min_grid_height=cfg.get("min_grid_height", 3),
            min_grid_width=cfg.get("min_grid_width", 3),
            max_colors=cfg.get("max_colors", NUM_COLORS),
            background_color=cfg.get("background_color", -1),
            task_split=cfg.get("task_split", "train"),
            max_train_pairs=cfg.get("max_train_pairs", 10),
            max_test_pairs=cfg.get("max_test_pairs", 3),
            shuffle_tasks=cfg.get("shuffle_tasks", True),
        )