BordAX is a research-focused framework for Programmatic Reinforcement Learning (PRL) that combines the speed of JAX with support for multiple policy classes, from standard neural networks to structured, interpretable policies such as boolean functions and decision trees.
- High Performance — Fully JIT-compiled training pipelines leveraging JAX's XLA compilation
- Modular Architecture — Clean separation between agents, algorithms, environments, and training
- Multiple Policy Types — MLPs, boolean functions (HyperBool), and decision trees (DTSemNet)
- Flexible Algorithms — Built-in PPO (on-policy) and DQN (off-policy) with easy extensibility
- Environment Agnostic — Supports both Gymnax (JIT-compiled) and Gymnasium environments
- Production Ready — Checkpointing, logging, WandB integration, and comprehensive tests
BordAX achieves high performance through:
- Full JIT compilation for jittable environments (Gymnax)
- Vectorized environments via jax.vmap
- Efficient loops using jax.lax.scan
- Pure functional design compatible with XLA optimization
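The sketch below shows this pattern in plain JAX rather than BordAX's internal code: a hypothetical pure environment step is vectorized with jax.vmap and rolled out with jax.lax.scan inside a single jit-compiled function.

```python
# Illustrative pattern only: a toy vmapped env step unrolled with lax.scan,
# standing in for the kind of rollout loop compiled for Gymnax environments.
import jax
import jax.numpy as jnp

def env_step(state, action):
    # Hypothetical pure transition function (a Gymnax env.step plays this role).
    next_state = state + action
    reward = -jnp.abs(next_state).sum()
    return next_state, reward

batched_step = jax.vmap(env_step)  # one step across all parallel environments

@jax.jit
def rollout(init_states, actions):
    # actions: [num_steps, num_envs, act_dim]; lax.scan keeps the loop inside XLA.
    def body(states, acts):
        next_states, rewards = batched_step(states, acts)
        return next_states, rewards
    final_states, rewards = jax.lax.scan(body, init_states, actions)
    return final_states, rewards

states = jnp.zeros((4, 3))            # 4 parallel environments
actions = 0.1 * jnp.ones((10, 4, 3))  # 10 rollout steps
_, rewards = rollout(states, actions)
print(rewards.shape)  # (10, 4)
```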
| Environment | Algorithm | JIT Scope |
|---|---|---|
| Gymnax | On-policy | Entire train_step |
| Gymnax | Off-policy | update only |
| Gymnasium | Any | update only |
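The two JIT scopes above can be illustrated with a schematic (toy stand-in functions, not BordAX's actual code): a pure JAX environment lets collection and update compile into one program, whereas a Python environment forces collection to stay outside the JIT boundary.

```python
# Schematic of the two JIT scopes; toy stand-ins, not BordAX functions.
import jax
import jax.numpy as jnp

def toy_update(params, batch):
    # Stand-in for a pure gradient update.
    return params - 0.01 * batch.mean()

def toy_jax_collect(params, key):
    # Pure JAX "collection": traceable, so it fuses with the update.
    return params * jax.random.normal(key, (32,))

# Gymnax / on-policy: the entire train step is one compiled XLA program.
@jax.jit
def full_train_step(params, key):
    batch = toy_jax_collect(params, key)
    return toy_update(params, batch)

# Gymnasium (or replay data gathered by a Python loop): only the update is jitted.
jitted_update = jax.jit(toy_update)

def python_train_step(params, batch_from_python_env):
    return jitted_update(params, jnp.asarray(batch_from_python_env))

params = jnp.array(1.0)
params = full_train_step(params, jax.random.PRNGKey(0))
params = python_train_step(params, [0.1, 0.2, 0.3])
```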
PPO on CartPole-v1 with identical hyperparameters (5 seeds, 51k timesteps):
| Framework | Training Time | Throughput | Speedup |
|---|---|---|---|
| BordAX + Gymnax (Full JIT) | 4.26s ± 0.12s | 12,027 steps/s | 3.2x |
| BordAX + Gymnasium | 6.38s ± 0.27s | 8,021 steps/s | 2.2x |
| Stable-Baselines3 | 13.79s ± 0.55s | 3,714 steps/s | 1.0x |
With Gymnax (fully JIT-compiled), BordAX is 3.2x faster than Stable-Baselines3. Even with Gymnasium (Python environment), BordAX is 2.2x faster.
Run the benchmark yourself:
pip install stable-baselines3
python compare_sb3.py

# Clone the repository
git clone https://github.com/SynthesisLab/bordax.git
cd bordax
# Install with uv (recommended)
uv sync
# Or with pip
pip install -e .
# With optional dependencies (WandB, visualization)
pip install -e ".[all]"

Verify the installation:

python -c "from bordax.training.trainer import Trainer; print('BordAX installed successfully')"

PPO training example:

python train_ppo.py

- Solves CartPole-v1 (reward = 500) in ~400k steps
- Training time: ~18 seconds on CPU
- Throughput: ~23,000 steps/s
DQN training example:

python train_dqn.py

- Solves CartPole-v1 in ~50k steps
- Training time: ~36 seconds on CPU
- Includes 1,000 step warmup phase
import jax
from bordax.training.trainer import Trainer, TrainerConfig
from bordax.algorithms.utils import make_algo
from bordax.environments.utils import make_env
from bordax.agents.utils import make_agent
# Setup environments
env = make_env("gymnax/CartPole-v1", {"init_config": {}, "reset_config": {}}, num_envs=4)
eval_env = make_env("gymnax/CartPole-v1", {"init_config": {}, "reset_config": {}}, num_envs=1)
# Create agent with MLP policy and value networks
agent = make_agent("mlp/mlp", env, {
"policy_layers": [64, 64],
"value_layers": [64, 64],
})
# Configure PPO algorithm
algorithm = make_algo("ppo", {
"lr": 3e-4,
"rollout_length": 2048,
"gamma": 0.99,
"_lambda": 0.95,
"clip_schedule": lambda _: 0.2,
"vf_schedule": lambda _: 0.5,
"ent_schedule": lambda _: 0.01,
"num_minibatches": 16,
"num_sgd_steps": 10,
})
# Setup trainer
config = TrainerConfig(
num_checkpoints=100,
epochs_per_checkpoint=1,
evaluation_episodes=32,
debug=True,
)
trainer = Trainer(env, eval_env, agent, algorithm, config)
# Train
key = jax.random.PRNGKey(0)
init_key, train_key = jax.random.split(key)
trainer.init(init_key)
eval_data = trainer.run(train_key)

BordAX uses a modular pipeline architecture that cleanly separates concerns:
Trainer
└─> Algorithm (Collector + BatchBuilder + Updater)
├─> Collector: Generates environment transitions
├─> BatchBuilder: Constructs training batches
└─> Updater: Computes gradients and updates parameters
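Schematically, an Algorithm bundles these three components into a single train step. The sketch below mirrors that composition; the names echo BordAX's components, but the signatures are illustrative assumptions rather than the library's actual API.

```python
# Schematic composition only: signatures here are assumptions, not BordAX's API.
from dataclasses import dataclass
from typing import Any, Callable, Iterable

@dataclass
class AlgorithmSketch:
    collector: Callable[..., tuple]           # generates environment transitions
    batch_builder: Callable[[Any], Iterable]  # turns collected data into batches
    updater: Callable[..., Any]               # computes gradients, updates params

    def train_step(self, params, env_state, key):
        data, env_state = self.collector(params, env_state, key)
        for batch in self.batch_builder(data):
            params = self.updater(params, batch)
        return params, env_state
```

Swapping the collector and batch builder is what distinguishes the built-in algorithms: PPO pairs on-policy collection with full-buffer minibatching, while DQN pairs epsilon-greedy collection with uniform replay sampling.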
| Component | Purpose | Examples |
|---|---|---|
| Agent | Defines policy and value networks | MLPPolicyValue, BooleanPolicyValue, DTPolicy, DQNAgent |
| Algorithm | Bundles training pipeline components | ppo_algo(), dqn_algo() |
| Collector | Generates transitions via env interaction | OnPolicyCollector, EpsGreedyCollector |
| BatchBuilder | Transforms data into training batches | FullBufferBatch, MiniBatch, UniformReplayBatch |
| Updater | Updates parameters using gradients | SGDUpdate, DQNUpdater |
| Trainer | Orchestrates full training loop | Trainer |
| Algorithm | Type | Collector | Batch Strategy |
|---|---|---|---|
| PPO | On-policy | OnPolicyCollector | FullBufferBatch → MiniBatch |
| DQN | Off-policy | EpsGreedyCollector | UniformReplayBatch |
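A DQN pipeline is built through the same make_algo factory as the PPO example in the quick start. The configuration keys below are hypothetical placeholders for illustration; consult dqn_algo() for the actual parameter names.

```python
# Hypothetical DQN configuration: the keys are illustrative assumptions,
# not the exact names expected by dqn_algo().
from bordax.algorithms.utils import make_algo

algorithm = make_algo("dqn", {
    "lr": 1e-3,             # assumed: optimizer learning rate
    "gamma": 0.99,          # assumed: discount factor
    "buffer_size": 50_000,  # assumed: replay buffer capacity
    "batch_size": 64,       # assumed: uniform replay sample size
})
```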
MLP Policy-Value (mlp/mlp):
agent = make_agent("mlp/mlp", env, {
"policy_layers": [128, 128, 64],
"value_layers": [128, 128, 64],
})

HyperBool — Boolean function-based policies (boolean/mlp):
agent = make_agent("boolean/mlp", env, {
"n": 4, # Number of boolean variables
"value_layers": [128, 64, 32],
})

DTSemNet — Decision tree policies (dt/mlp):
agent = make_agent("dt/mlp", env, {
"tree_depth": 4,
"value_layers": [64, 64],
})

Q-Network (dqn):
agent = make_agent("dqn", env, {
"layers": [64, 64],
})

Project structure:

bordax/
├── bordax/
│ ├── agents/ # Agent implementations
│ │ ├── base.py # MLPPolicyValue, BooleanPolicyValue, DTPolicy, DQNAgent
│ │ ├── components.py # Neural modules (MLP, DTSemNet, BooleanFunction)
│ │ └── utils.py # make_agent() factory
│ ├── algorithms/ # RL algorithms
│ │ ├── base.py # Algorithm class, ppo_algo(), dqn_algo()
│ │ ├── losses.py # PPOLoss, DQNLoss
│ │ └── utils.py # make_algo() factory
│ ├── data/ # Data collection and batching
│ │ ├── collectors.py # OnPolicyCollector, EpsGreedyCollector
│ │ ├── batchbuilders.py # Batch transformations
│ │ └── buffer.py # ReplayBuffer
│ ├── environments/ # Environment adapters
│ │ └── utils.py # EnvAdapter, make_env()
│ ├── training/ # Training infrastructure
│ │ ├── trainer.py # Main Trainer class
│ │ ├── evaluation.py # Evaluator
│ │ ├── logging.py # Logger with WandB support
│ │ ├── checkpointing.py # Model checkpointing (Orbax)
│ │ └── updaters.py # SGDUpdate, DQNUpdater
│ └── types.py # Core type definitions
├── tests/ # Test suite (48 tests, 77% coverage)
│ ├── unit/ # Fast component tests
│ ├── integration/ # Pipeline tests
│ └── slow/ # Learning verification tests
├── train_ppo.py # PPO training example
├── train_dqn.py # DQN training example
└── compare_sb3.py # Stable-Baselines3 benchmark
BordAX has a comprehensive test suite with 48 tests achieving 77% code coverage.
# Run all tests (excluding slow)
uv run pytest tests/ -m "not slow" -v
# Run slow learning tests
uv run pytest tests/ -m slow -v
# Run with coverage
uv run pytest tests/ --cov=bordax --cov-report=term-missing

| Category | Tests | Purpose |
|---|---|---|
| Unit | 44 | Fast component tests |
| Integration | 2 | Full pipeline verification |
| Slow | 2 | Learning verification |
| Package | Version | Purpose |
|---|---|---|
| JAX | >=0.8.0 | Core computation |
| Flax | >=0.12.0 | Neural networks |
| Optax | >=0.2.6 | Optimizers |
| Gymnax | >=0.0.9 | JAX environments |
| Gymnasium | >=1.2.0 | Standard environments |
| Distrax | >=0.1.7 | Distributions |
| Orbax | >=0.11.32 | Checkpointing |
Optional: WandB (experiment tracking), Matplotlib/Seaborn (visualization)
# Restore last checkpoint and continue training
python train_ppo.py --restore-last

BordAX is released under the MIT License.
BordAX builds on excellent work from the JAX ecosystem:
- JAX — High-performance numerical computing
- Flax — Neural network library
- Gymnax — JAX-compatible RL environments
- Optax — Gradient processing and optimization
- Distrax — Probability distributions


