Source code for jaxarc.envs.actions

"""
Action system for JaxARC environments.

This module provides the complete action system following KISS principle:
- Action class for representing actions
- Action creation and processing utilities
- Operation validation and filtering utilities

Combined from simplified actions.py and action_filtering.py for better organization.
"""

from __future__ import annotations

import equinox as eqx
import jax
import jax.numpy as jnp

from jaxarc.configs.action_config import ActionConfig
from jaxarc.state import State
from jaxarc.types import (
    NUM_OPERATIONS,
    MaskArray,
    OperationMask,
    SelectionArray,
)


[docs] class Action(eqx.Module): """Simple action representation. Attributes: operation: ARC operation ID (0-34) selection: Boolean mask indicating selected cells """ operation: jnp.int32 selection: SelectionArray
[docs] def validate(self, grid_shape: tuple[int, int], max_operations: int = 35) -> Action: """Validate action parameters. Args: grid_shape: Shape of the grid (height, width) max_operations: Maximum number of operations Returns: Validated action with clipped operation """ # Clip operation to valid range valid_operation = jnp.clip(self.operation, 0, max_operations - 1) # Return with validated operation (assume selection is already correct shape) return Action( operation=valid_operation, selection=self.selection, )
def create_action(operation, selection: SelectionArray) -> Action: """Create an action. Args: operation: ARC operation ID (0-34) selection: Boolean mask indicating selected cells Returns: Action instance """ return Action( operation=jnp.array(operation, dtype=jnp.int32), selection=selection, ) @jax.jit def action_handler(action: Action, working_grid_mask: MaskArray) -> SelectionArray: """Process action to create selection mask. Args: action: Action with operation and selection working_grid_mask: Boolean mask defining valid grid area Returns: Boolean selection mask constrained to working grid area """ # Ensure selection is boolean and constrain to working grid selection = action.selection.astype(jnp.bool_) return selection & working_grid_mask def get_allowed_operations(state: State, config: ActionConfig) -> OperationMask: """Get mask of allowed operations based on config and state. Args: state: Current environment state config: Action configuration Returns: Boolean mask indicating which operations are allowed """ allowed_ops = getattr(config, "allowed_operations", None) if isinstance(allowed_ops, tuple) and len(allowed_ops) > 0: idx = jnp.asarray(allowed_ops, dtype=jnp.int32) idx = jnp.clip(idx, 0, NUM_OPERATIONS - 1) base = jnp.zeros((NUM_OPERATIONS,), dtype=jnp.bool_).at[idx].set(True) else: base = jnp.ones((NUM_OPERATIONS,), dtype=jnp.bool_) if hasattr(state, "allowed_operations_mask"): base = jnp.logical_and(base, state.allowed_operations_mask) return base def validate_operation( operation_id: jnp.ndarray, state: State, config: ActionConfig ) -> jnp.ndarray: """Validate if operation ID is in range and allowed by current mask. Args: operation_id: Operation ID to validate state: Current environment state config: Action configuration Returns: Boolean indicating if operation is valid """ mask = get_allowed_operations(state, config) in_range = (operation_id >= 0) & (operation_id < NUM_OPERATIONS) safe = jnp.clip(operation_id, 0, NUM_OPERATIONS - 1) return in_range & mask[safe] def _find_nearest_valid_operation( op_id: jnp.ndarray, mask: OperationMask ) -> jnp.ndarray: """Find nearest valid operation ID based on allowed operations mask. Args: op_id: Target operation ID mask: Boolean mask of allowed operations Returns: Nearest valid operation ID, or 0 if none available """ ids = jnp.arange(NUM_OPERATIONS) dists = jnp.abs(ids - op_id) dists = jnp.where(mask, dists, jnp.inf) idx = jnp.argmin(dists) return jnp.where(jnp.any(mask), idx, jnp.array(0, dtype=jnp.int32)) def filter_invalid_operation( operation_id: jnp.ndarray, state: State, config: ActionConfig ) -> jnp.ndarray: """Filter invalid operation according to configured policy. Args: operation_id: Operation ID to filter state: Current environment state config: Action configuration with invalid_operation_policy Returns: Filtered operation ID based on policy """ arr = operation_id.astype(jnp.int32) mask = get_allowed_operations(state, config) valid = validate_operation(arr, state, config) policy = getattr(config, "invalid_operation_policy", "clip") clipped = jnp.clip(arr, 0, NUM_OPERATIONS - 1) if policy in ("clip", "penalize"): repl = _find_nearest_valid_operation(clipped, mask) out = jnp.where(valid, arr, repl) elif policy == "reject": out = jnp.where(valid, arr, jnp.array(-1, dtype=jnp.int32)) elif policy == "passthrough": out = arr else: # Default to clip behavior for unknown policies repl = _find_nearest_valid_operation(clipped, mask) out = jnp.where(valid, arr, repl) return out