# Source code for jaxarc.configs.reward_config
from __future__ import annotations
import equinox as eqx
from omegaconf import DictConfig
from .validation import (
ConfigValidationError,
check_hashable,
validate_float_range,
)
[docs]
class RewardConfig(eqx.Module):
    """Configuration for reward calculation.

    This config contains all settings related to reward computation,
    penalties, bonuses, and reward shaping with mode-aware enhancements.

    Each field is a plain float; the ranges accepted for them are the
    bounds passed to ``validate_float_range`` in :meth:`validate`.
    """

    # Basic reward settings
    step_penalty: float = -0.01  # validated against bounds -10.0 .. 1.0
    success_bonus: float = 10.0  # validated against bounds -100.0 .. 1000.0
    similarity_weight: float = 1.0  # validated against bounds 0.0 .. 100.0
    # validated against bounds -1000.0 .. 0.0
    unsolved_submission_penalty: float = 0.0

    def validate(self) -> tuple[str, ...]:
        """Validate reward configuration and return tuple of errors.

        Every field is checked independently, so the result contains one
        message per failing field rather than only the first failure.

        Returns:
            A tuple of error messages (``str(ConfigValidationError)``), in
            field declaration order. Empty when every check passes.
        """
        # (field name, current value, lower bound, upper bound)
        range_checks = (
            ("step_penalty", self.step_penalty, -10.0, 1.0),
            ("success_bonus", self.success_bonus, -100.0, 1000.0),
            ("similarity_weight", self.similarity_weight, 0.0, 100.0),
            (
                "unsolved_submission_penalty",
                self.unsolved_submission_penalty,
                -1000.0,
                0.0,
            ),
        )

        errors: list[str] = []
        for field_name, value, lower, upper in range_checks:
            # One try per field: a failure on an earlier field must not
            # hide problems with the later ones.
            try:
                validate_float_range(value, field_name, lower, upper)
            except ConfigValidationError as e:
                errors.append(str(e))
        return tuple(errors)

    def __check_init__(self):
        """Run ``check_hashable`` on this instance under the name "RewardConfig"."""
        # NOTE(review): presumably invoked by Equinox right after the instance
        # is initialised - confirm against the Equinox Module documentation.
        check_hashable(self, "RewardConfig")

    @classmethod
    def from_hydra(cls, cfg: DictConfig) -> RewardConfig:
        """Create reward config from Hydra DictConfig.

        Args:
            cfg: Config node to read the reward settings from. Any key that
                is absent falls back to the same literal default that the
                class declares for that field.

        Returns:
            A new ``RewardConfig`` populated from ``cfg``.
        """
        return cls(
            step_penalty=cfg.get("step_penalty", -0.01),
            success_bonus=cfg.get("success_bonus", 10.0),
            similarity_weight=cfg.get("similarity_weight", 1.0),
            unsolved_submission_penalty=cfg.get("unsolved_submission_penalty", 0.0),
        )