Stoix Integration

Train RL agents on ARC puzzles using Stoix, a JAX-based RL framework.

Factory Function

JaxARC provides make_jaxarc_env(), which builds a training environment and an evaluation environment, both Stoix-compatible and wrapped in the correct wrapper stack:

from omegaconf import OmegaConf
from jaxarc.stoix_adapter import make_jaxarc_env

config = OmegaConf.create({
    "env": {
        "scenario": {"name": "Mini"},
        "action": {"mode": "point"},
        "observation_wrappers": {
            "answer_grid": True,
            "input_grid": True,
            "contextual": True,
        },
    }
})

train_env, eval_env = make_jaxarc_env(config)

The factory:

  1. Creates JaxARC environments via jaxarc.make()

  2. Applies action wrappers (PointActionWrapper / BboxActionWrapper + FlattenActionWrapper)

  3. Applies observation wrappers (answer grid, input grid, contextual)

ExtendedMetrics

JaxARC’s ExtendedMetrics wrapper tracks domain-specific episode statistics:

Metric            Description
best_similarity   Highest grid similarity achieved during the episode
solved            Whether the puzzle was solved (similarity = 1.0)
steps_to_solve    Number of steps taken to solve (0 if unsolved)
final_similarity  Grid similarity at episode end
was_truncated     Whether the episode was truncated (hit max steps)
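To make the field definitions concrete, here is a minimal pure-Python sketch that derives the same five fields from a list of per-step similarity scores. It is not JaxARC's implementation: the function name summarize_episode, the input format (a non-empty list of similarities), and the rule that a solved episode is never counted as truncated are all assumptions made for illustration.

```python
def summarize_episode(similarities, max_steps):
    """Summarize one episode from its per-step similarity scores.

    Hypothetical helper: similarities[i] is assumed to be the grid
    similarity after step i + 1, and the list is assumed non-empty.
    """
    # First step (1-indexed) at which the puzzle is solved, or 0 if never.
    solved_at = next(
        (i + 1 for i, s in enumerate(similarities) if s >= 1.0), 0
    )
    solved = solved_at > 0
    return {
        "best_similarity": max(similarities),
        "solved": solved,
        "steps_to_solve": solved_at,
        "final_similarity": similarities[-1],
        # Assumption: truncation means the step budget ran out unsolved.
        "was_truncated": (not solved) and len(similarities) >= max_steps,
    }


stats = summarize_episode([0.25, 0.5, 1.0], max_steps=10)
unsolved = summarize_episode([0.25, 0.75, 0.5], max_steps=3)
```

Note that best_similarity and final_similarity differ in the second episode: the agent reached 0.75 and then made the grid worse before running out of steps.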

When used with Stoix, ExtendedMetrics should be applied after RecordEpisodeMetrics so its fields merge into the episode metrics dict:

from stoa.core_wrappers.wrapper import AddRNGKey
from stoa.core_wrappers.episode_metrics import RecordEpisodeMetrics
from stoa.core_wrappers.auto_reset import AutoResetWrapper
from jaxarc.wrappers import ExtendedMetrics

env = AddRNGKey(env)
env = RecordEpisodeMetrics(env)
env = ExtendedMetrics(env)        # Must come after RecordEpisodeMetrics, not before
env = AutoResetWrapper(env, next_obs_in_extras=True)
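The ordering requirement can be illustrated with toy stand-ins. The classes below are not the stoa or JaxARC wrappers; they only model the assumption that one wrapper creates the episode metrics dict and the other merges fields into it. The metric names and values are illustrative, and the real wrappers may fail differently (or silently) when misordered.

```python
class ToyEnv:
    """Stand-in environment whose step() returns an empty extras dict."""

    def step(self):
        return {}


class ToyRecordEpisodeMetrics:
    """Stand-in for RecordEpisodeMetrics: creates the metrics dict."""

    def __init__(self, env):
        self.env = env

    def step(self):
        extras = self.env.step()
        extras["episode_metrics"] = {"episode_return": 1.0, "episode_length": 5}
        return extras


class ToyExtendedMetrics:
    """Stand-in for ExtendedMetrics: merges into an existing metrics dict."""

    def __init__(self, env):
        self.env = env

    def step(self):
        extras = self.env.step()
        extras["episode_metrics"].update({"solved": True, "best_similarity": 1.0})
        return extras


# Correct order: the dict exists by the time the outer wrapper merges into it.
right = ToyExtendedMetrics(ToyRecordEpisodeMetrics(ToyEnv())).step()

# Reversed order: in this toy model there is nothing to merge into yet.
try:
    ToyRecordEpisodeMetrics(ToyExtendedMetrics(ToyEnv())).step()
    wrong_order_failed = False
except KeyError:
    wrong_order_failed = True
```

In the correct order, right["episode_metrics"] holds both the generic and the ARC-specific fields in one dict, which is the merged result the ordering is meant to produce.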

Custom Metrics Processing

jaxarc_custom_metrics() computes derived statistics from raw episode metrics:

from jaxarc.stoix_adapter import jaxarc_custom_metrics

# raw_metrics comes from timestep.extras["episode_metrics"]
processed = jaxarc_custom_metrics(raw_metrics)
# Adds: success_rate, avg_steps_to_solve, truncation_rate,
#        best_similarity_mean, final_similarity_mean
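As a rough guide to what the derived statistics mean, the sketch below computes them from a list of per-episode dicts. It is an illustration under stated assumptions, not the body of jaxarc_custom_metrics(): the function name derive_metrics, the list-of-dicts input layout, and the choice to average steps_to_solve over solved episodes only are all hypothetical.

```python
def derive_metrics(episodes):
    """Aggregate per-episode fields into summary statistics.

    Hypothetical helper: `episodes` is assumed to be a non-empty list of
    dicts carrying the ExtendedMetrics fields described earlier.
    """
    n = len(episodes)
    solved = [e for e in episodes if e["solved"]]
    return {
        "success_rate": len(solved) / n,
        # Assumption: unsolved episodes report 0 steps, so they are excluded.
        "avg_steps_to_solve": (
            sum(e["steps_to_solve"] for e in solved) / len(solved)
            if solved
            else 0.0
        ),
        "truncation_rate": sum(e["was_truncated"] for e in episodes) / n,
        "best_similarity_mean": sum(e["best_similarity"] for e in episodes) / n,
        "final_similarity_mean": sum(e["final_similarity"] for e in episodes) / n,
    }


episodes = [
    {"solved": True, "steps_to_solve": 4, "best_similarity": 1.0,
     "final_similarity": 1.0, "was_truncated": False},
    {"solved": False, "steps_to_solve": 0, "best_similarity": 0.5,
     "final_similarity": 0.25, "was_truncated": True},
]
summary = derive_metrics(episodes)
```

Averaging steps_to_solve over all episodes instead would pull the mean toward 0 whenever the success rate is low, which is why this sketch restricts it to solved episodes.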

Using with jaxarc-baselines

The jaxarc-baselines repository provides ready-to-use experiment configs:

git clone --recurse-submodules https://github.com/aadimator/jaxarc-baselines
cd jaxarc-baselines
pixi install
pixi run python run_experiment.py

Override configuration from the command line:

pixi run python run_experiment.py \
    env.scenario.name=Mini-easy \
    env.action.mode=point \
    arch.total_num_envs=256
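The dotted key=value arguments above address nested config fields. The sketch below shows the general idea of how such overrides map onto a nested dict. It assumes run_experiment.py accepts Hydra-style overrides, as the command suggests; apply_overrides is a hypothetical helper written for illustration and is not the parser the repository actually uses.

```python
import copy


def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' strings to a nested dict and return a new dict.

    Hypothetical helper: only integers are type-converted; every other
    value is kept as a string.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Create intermediate levels that do not exist yet.
            node = node.setdefault(name, {})
        node[leaf] = int(raw_value) if raw_value.isdigit() else raw_value
    return result


base = {"env": {"scenario": {"name": "Mini"}, "action": {"mode": "point"}}}
cfg = apply_overrides(
    base,
    ["env.scenario.name=Mini-easy", "env.action.mode=point", "arch.total_num_envs=256"],
)
```

Each override replaces exactly one leaf, so fields that are not mentioned on the command line keep the values from the experiment config.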