Optax
Gradient processing and optimization library for JAX — composable optimizer building blocks for training neural networks. Optax features: standard optimizers (optax.adam, optax.sgd, optax.adamw, optax.rmsprop, optax.adagrad, optax.lion), learning rate schedules (optax.cosine_decay_schedule, optax.warmup_cosine_decay_schedule, optax.linear_schedule), gradient clipping (optax.clip, optax.clip_by_global_norm), optimizer chaining (optax.chain), masked updates (optax.masked), exponential moving average (optax.ema), gradient accumulation (optax.MultiSteps), and custom optimizer composition. JAX ecosystem optimizer library — pairs with Flax for neural network training. DeepMind-maintained.
Score Breakdown
⚙ Agent Friendliness
🔒 Security
Local optimization library — no network access, no data exfiltration. Pure JAX pytree transformations. No security concerns beyond standard Python dependency hygiene.
⚡ Reliability
Best When
Training JAX/Flax neural networks where you need composable optimizer building blocks — optax.chain and optax.masked allow building complex optimizer pipelines (gradient clipping + warmup + AdamW + layer-wise LR) as composable transforms.
Avoid When
You're not using JAX, need non-gradient optimization, or prefer all-in-one frameworks.
Use Cases
- • Agent AdamW optimizer — optimizer = optax.adamw(learning_rate=1e-4, weight_decay=0.01); opt_state = optimizer.init(params); updates, opt_state = optimizer.update(grads, opt_state, params) — standard optimizer for agent transformer fine-tuning; AdamW's decoupled weight decay is applied directly to params rather than folded into the gradient, which is why update requires the params argument
- • Agent cosine LR schedule — schedule = optax.warmup_cosine_decay_schedule(init_value=0, peak_value=1e-3, warmup_steps=1000, decay_steps=10000); optimizer = optax.adam(schedule) — agent training with warmup + cosine decay; learning rate decays from peak to near-zero; standard schedule for transformer agent training
- • Agent gradient clipping — optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) — clip gradients before Adam update; prevents agent training instability from exploding gradients; chain() composes gradient clipping with optimizer in correct order
- • Agent layer-wise learning rates — tx = optax.multi_transform({'backbone': optax.adam(1e-4), 'head': optax.adam(1e-3)}, param_labels={'backbone': 'backbone', 'head': 'head'}) — different LR for pretrained backbone vs task head; for a fully frozen backbone, map its label to optax.set_to_zero() instead of a second optimizer
- • Agent gradient accumulation — optimizer = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4) — accumulate gradients over 4 steps before applying update; agent training with large effective batch size on memory-constrained GPU; MultiSteps correctly handles gradient accumulation with stateful optimizers
Not For
- • Non-JAX training — Optax is JAX-only; for PyTorch optimizers use torch.optim; for TensorFlow use tf.keras.optimizers
- • Non-gradient optimization — Optax is for gradient-based optimization; for evolutionary algorithms or black-box optimization use scipy.optimize or evosax
- • Production hyperparameter tuning — Optax defines optimizers; for automated HPO use Optuna or Ray Tune
Interface
Authentication
No auth — local optimization library.
Pricing
Optax is Apache 2.0 licensed by Google DeepMind. Free for all use.
Agent Metadata
Known Gotchas
- ⚠ optimizer.init() must be called with parameters pytree — optimizer = optax.adam(1e-3); opt_state = optimizer.init(params) initializes Adam moment vectors matching params structure; calling optimizer.update(grads, opt_state) with different pytree structure than params raises ValueError; agent code must init optimizer after defining model parameters
- ⚠ Gradient pytree must match parameter pytree exactly — optimizer.update(grads, opt_state) requires grads to have the same pytree structure (same keys, shapes) as the params passed to init; agent code where model architecture changes between init and update gets inscrutable JAX errors; verify with jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
- ⚠ optax.chain order matters — optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) clips THEN updates; optax.chain(optax.adam(1e-3), optax.clip_by_global_norm(1.0)) updates THEN clips parameter changes (wrong); agent gradient clipping must come before optimizer transform in chain
- ⚠ Learning rate schedules track steps inside optimizer state — optax.cosine_decay_schedule produces a callable mapping step → LR; passing it as learning_rate wraps it in scale_by_schedule, which keeps its own step counter in opt_state; re-initializing opt_state mid-training (e.g., restoring params from a checkpoint without the optimizer state) silently restarts the schedule from step 0; to log the current LR, call schedule(step) directly
- ⚠ AdamW weight decay does NOT exclude bias and LayerNorm — optax.adamw applies weight decay to ALL parameters including biases and LayerNorm weights; for correct transformer training pass a mask to optax.adamw (or wrap with optax.masked) to exclude bias terms from weight decay; incorrect weight decay degrades agent model quality
- ⚠ MultiSteps applies real updates only every K steps — optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4) returns all-zero updates on intermediate micro-steps and averages gradients across the window; if the dataset length is not divisible by the accumulation factor, the final partial window is never applied — pad or drop remainder batches explicitly; a single NaN gradient poisons the running average, so guard with should_skip_update_fn=optax.skip_not_finite
Alternatives
Full Evaluation Report
Detailed scoring breakdown, competitive positioning, security analysis, and improvement recommendations for Optax.
Scores are editorial opinions as of 2026-03-06.