Quick Start
Learn the JaxARC basics in 5 minutes. This guide covers the essential concepts you need to start using JaxARC.
Your First Environment
The simplest way to start with JaxARC is to create an environment and interact with it:
```python
import jax
import jaxarc

# Create an environment (returns env and env_params)
# auto_download=True will download the dataset if it doesn't exist
env, env_params = jaxarc.make(
    "Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", auto_download=True
)

# Reset the environment to get the initial state and timestep
key = jax.random.PRNGKey(0)
state, timestep = env.reset(key, env_params=env_params)

# Print basic information
print(f"Observation shape: {timestep.observation.shape}")
print(f"Step type: {timestep.step_type}")
```
Breaking It Down
- `jaxarc.make("Mini-...")` creates an environment instance together with its environment parameters. The ID selects the dataset ("Mini" for MiniARC) and, optionally, a specific task.
- `env.reset(key, env_params=env_params)` resets the environment and returns both `state` and `timestep`. Unlike Gymnasium, where seeding is optional and the random state lives inside the environment, JAX environments require an explicit PRNG key for reproducibility.
- `state` is an immutable object holding the environment's internal state.
- `timestep` holds the observable information:
  - `observation`: the current grid (a JAX array)
  - `step_type`: whether this is the first, a middle, or the last step of the episode
  - `reward`: the reward for the latest step (0.0 after a reset)
  - `discount`: the discount factor (1.0 after a reset)
  - `extras`: additional information (a dict)
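The following stand-in makes the timestep fields concrete. It is a hypothetical class, `ToyTimeStep`, modeled on the dm_env convention for step types (FIRST = 0, MID = 1, LAST = 2). The class name, the integer codes, and the 5x5 grid are assumptions for illustration. JaxARC's actual timestep type may differ in detail.

```python
from typing import NamedTuple

import jax.numpy as jnp

# Hypothetical step-type codes following the dm_env convention (assumption).
FIRST, MID, LAST = 0, 1, 2


class ToyTimeStep(NamedTuple):
    """Hypothetical stand-in for a timestep; not JaxARC's actual class."""

    step_type: jnp.ndarray
    reward: jnp.ndarray
    discount: jnp.ndarray
    observation: jnp.ndarray
    extras: dict

    def first(self):
        # True only for the timestep returned by reset.
        return self.step_type == FIRST

    def last(self):
        # True once the episode has ended.
        return self.step_type == LAST


# An initial timestep: FIRST step, reward 0.0, discount 1.0.
ts0 = ToyTimeStep(
    step_type=jnp.asarray(FIRST),
    reward=jnp.asarray(0.0),
    discount=jnp.asarray(1.0),
    observation=jnp.zeros((5, 5), dtype=jnp.int32),  # hypothetical 5x5 grid
    extras={},
)

print(bool(ts0.first()), bool(ts0.last()))  # True False
```

A `NamedTuple` is used here because JAX treats it as a pytree automatically, so such a container can be passed through `jax.jit` and `jax.vmap`.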
Understanding State
JaxARC uses immutable state objects. Once created, a state is never modified in place. Operations such as `env.step` return new state objects instead:
```python
# Reset returns the initial state and timestep
# (split the key so reset and sampling get independent subkeys)
key, reset_key, action_key = jax.random.split(key, 3)
state, timestep = env.reset(reset_key, env_params=env_params)
print(f"Initial: step_type={timestep.step_type}, reward={timestep.reward}")

# Stepping returns a NEW state and timestep
action_space = env.action_space(env_params)
action = action_space.sample(action_key)
next_state, next_timestep = env.step(state, action, env_params=env_params)

# The original state and timestep are unchanged
print(f"Original step_type: {timestep.step_type}")
print(f"New step_type: {next_timestep.step_type}, reward={next_timestep.reward}")
```
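Why immutable? JAX transformations such as `jax.jit` (just-in-time compilation) and `jax.vmap` (vectorization) expect pure functions. A pure function's outputs depend only on its inputs, and it has no side effects. Because state is never mutated, `env.reset` and `env.step` can behave as pure functions and fit that model.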
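The idea does not depend on JaxARC. The sketch below uses a hypothetical `ToyState` and a pure `toy_step` function, neither of which is part of JaxARC. It shows the two properties that matter: the input state is never modified, and the same function can be compiled with `jax.jit` and batched with `jax.vmap`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical immutable environment state (a pytree, since it is a NamedTuple)."""

    grid: jnp.ndarray
    step_count: jnp.ndarray


def toy_step(state: ToyState, action: jnp.ndarray) -> ToyState:
    """Hypothetical pure step: paints every cell with `action` and returns a NEW state."""
    new_grid = jnp.full_like(state.grid, action)
    return state._replace(grid=new_grid, step_count=state.step_count + 1)


toy_state = ToyState(grid=jnp.zeros((3, 3), dtype=jnp.int32), step_count=jnp.asarray(0))

# jax.jit compiles the pure function; the input state is left untouched.
jit_step = jax.jit(toy_step)
next_state = jit_step(toy_state, jnp.asarray(7))
print(int(toy_state.grid.sum()), int(next_state.grid.sum()))  # 0 63

# jax.vmap steps a batch of 4 states at once (leading axis = batch).
batched = jax.tree_util.tree_map(lambda x: jnp.stack([x] * 4), toy_state)
actions = jnp.arange(4)
batched_next = jax.vmap(toy_step)(batched, actions)
print(batched_next.grid.shape)  # (4, 3, 3)
```

The function only reads from `state` and builds a new value with `_replace`. This is why both transformations can treat it as an ordinary mathematical function of its inputs.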
Taking Actions
Environments have action and observation spaces that define valid actions and observations:
```python
# Check action space
action_space = env.action_space(env_params)
print(f"Action space: {action_space}")

# Sample a random action (requires a PRNG key)
key, subkey = jax.random.split(key)
action = action_space.sample(subkey)

# Take the action
next_state, next_timestep = env.step(state, action, env_params=env_params)
print(f"Reward: {next_timestep.reward}")
print(f"Step type: {next_timestep.step_type}")
```
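A space bundles two capabilities: drawing a valid value and checking whether a value is valid. The hypothetical `ToyDiscrete` class below sketches that contract for a plain integer action. JaxARC's actual action space may have a different structure, so treat the class, its `n` field, and its `contains` method as illustrative assumptions rather than the library's API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyDiscrete(NamedTuple):
    """Hypothetical discrete space over {0, ..., n - 1}; not JaxARC's space class."""

    n: int

    def sample(self, key: jax.Array) -> jax.Array:
        # Draw one integer uniformly from [0, n) using the supplied key.
        return jax.random.randint(key, shape=(), minval=0, maxval=self.n)

    def contains(self, x) -> bool:
        # Valid actions are integers in [0, n).
        return bool(0 <= int(x) < self.n)


space = ToyDiscrete(n=10)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
action = space.sample(subkey)
print(space.contains(action))  # True
```

As with the real environment, `sample` takes a key argument instead of consulting hidden random state, so the same subkey always yields the same action.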
The Environment Loop
Here’s the standard pattern for interacting with an environment:
```python
import jax
import jaxarc

# Setup
env, env_params = jaxarc.make(
    "Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", auto_download=True
)
key = jax.random.PRNGKey(42)
key, reset_key = jax.random.split(key)
state, timestep = env.reset(reset_key, env_params=env_params)

# Run one episode with a random policy (capped at 100 steps)
total_reward = 0.0
step_count = 0
action_space = env.action_space(env_params)

while not timestep.last() and step_count < 100:
    # Sample a random action with a fresh subkey
    key, subkey = jax.random.split(key)
    action = action_space.sample(subkey)

    # Take a step; state and timestep are replaced, not mutated
    state, timestep = env.step(state, action, env_params=env_params)

    # Accumulate reward
    total_reward += float(timestep.reward)
    step_count += 1

print(f"Episode ran for {step_count} steps")
print(f"Total reward: {total_reward}")
```
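The Python `while` loop above runs step by step on the host. This is convenient for debugging, but the loop itself cannot be compiled with `jax.jit`. A common JAX alternative is a fixed-length rollout with `jax.lax.scan`, which threads a carry (here, the grid and the PRNG key) through every step. The sketch below uses hypothetical `toy_reset` and `toy_step` functions rather than JaxARC, so it runs on its own. Substituting `env.step` would additionally assume that function is jit-compatible.

```python
import jax
import jax.numpy as jnp


def toy_reset(key):
    """Hypothetical reset: a random 3x3 grid with colors 0-9."""
    return jax.random.randint(key, (3, 3), 0, 10)


def toy_step(grid, action):
    """Hypothetical step: reward is the fraction of cells already equal to
    `action`, then the whole grid is painted with `action`."""
    reward = jnp.mean(grid == action)
    new_grid = jnp.full_like(grid, action)
    return new_grid, reward


def rollout(key, num_steps):
    """Run a fixed-length random-policy episode and return per-step rewards."""
    key, reset_key = jax.random.split(key)
    grid = toy_reset(reset_key)

    def body(carry, _):
        grid, key = carry
        key, action_key = jax.random.split(key)  # fresh subkey every step
        action = jax.random.randint(action_key, (), 0, 10)
        grid, reward = toy_step(grid, action)
        return (grid, key), reward

    _, rewards = jax.lax.scan(body, (grid, key), xs=None, length=num_steps)
    return rewards


# num_steps sets the scan length, so it must be static under jit.
rollout_jit = jax.jit(rollout, static_argnums=1)
rewards = rollout_jit(jax.random.PRNGKey(42), 100)
print(rewards.shape, float(rewards.sum()))
```

`scan` always runs `num_steps` iterations. An episode that ends early therefore needs masking (for example, zeroing rewards after the last step), which this sketch omits.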
PRNG Keys
JAX has no global random state. Every random operation takes an explicit PRNG key, and you are responsible for creating and splitting keys:
```python
import jax
import jax.numpy as jnp

# Create the initial key
key = jax.random.PRNGKey(0)

# Split the key before each random operation
key, reset_key = jax.random.split(key)
state, timestep = env.reset(reset_key, env_params=env_params)
action_space = env.action_space(env_params)
key, action_key = jax.random.split(key)
action = action_space.sample(action_key)

# Using the same key twice gives the same result
key1 = jax.random.PRNGKey(0)
key2 = jax.random.PRNGKey(0)
state1, timestep1 = env.reset(key1, env_params=env_params)
state2, timestep2 = env.reset(key2, env_params=env_params)

# These will be identical
assert jnp.array_equal(timestep1.observation, timestep2.observation)
```
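Key Point: Always split your PRNG key before using it, and never pass the same key to more than one random operation. As the assertion above shows, identical keys produce identical results. Reusing a key therefore makes supposedly independent draws identical.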
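The rule is easiest to see with raw random draws. This standalone sketch uses only `jax.random`, with no JaxARC involved. It contrasts reusing a key, which silently repeats the same numbers, with splitting it. It also shows that a fixed root key keeps the whole sequence reproducible.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing one key: both "independent" draws are identical (the pitfall).
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))
print(bool(jnp.array_equal(a, b)))  # True

# Splitting first: each subkey yields a different, independent stream.
key, k1, k2 = jax.random.split(key, 3)
c = jax.random.uniform(k1, (3,))
d = jax.random.uniform(k2, (3,))
print(bool(jnp.array_equal(c, d)))  # False

# Reproducibility: re-creating the same root key replays the same sequence.
_, k1_again, _ = jax.random.split(jax.random.PRNGKey(0), 3)
print(bool(jnp.array_equal(c, jax.random.uniform(k1_again, (3,)))))  # True
```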
Next Steps
Now that you understand the basics, try:
- Complete Example: see a full random agent implementation
- Downloading Datasets: learn how to access ARC datasets
- Creating Agents: build your own agents
Quick Reference
```python
import jax
import jaxarc

# Create environment
env, env_params = jaxarc.make(
    "Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", auto_download=True
)

# Reset (split the key so reset and sampling use different subkeys)
key = jax.random.PRNGKey(0)
key, reset_key, action_key = jax.random.split(key, 3)
state, timestep = env.reset(reset_key, env_params=env_params)

# Step
action_space = env.action_space(env_params)
action = action_space.sample(action_key)
next_state, next_timestep = env.step(state, action, env_params=env_params)

# Access fields of the latest timestep
observation = next_timestep.observation
reward = next_timestep.reward
step_type = next_timestep.step_type
discount = next_timestep.discount

# Check if episode is done
is_done = next_timestep.last()
```