Optax

Gradient processing and optimization library for JAX: composable optimizer building blocks for training neural networks. Features include standard optimizers (optax.adam, optax.sgd, optax.adamw, optax.rmsprop, optax.adagrad, optax.lion), learning rate schedules (optax.cosine_decay_schedule, optax.warmup_cosine_decay_schedule, optax.linear_schedule), gradient clipping (optax.clip, optax.clip_by_global_norm), optimizer chaining (optax.chain), masked updates (optax.masked), exponential moving averages (optax.ema), gradient accumulation (optax.MultiSteps), and custom optimizer composition. The JAX ecosystem's standard optimizer library; pairs with Flax for neural network training. Maintained by Google DeepMind.

Evaluated Mar 06, 2026 — v0.2.x
Homepage · Repo · AI & Machine Learning · Tags: python, optax, jax, optimization, gradient-descent, adam, sgd, lr-schedule
⚙ Agent Friendliness: 64/100 — Can an agent use this?
🔒 Security: 91/100 — Is it safe for agents?
⚡ Reliability: 77/100 — Does it work consistently?

Score Breakdown

⚙ Agent Friendliness

MCP Quality: --
Documentation: 80
Error Messages: 75
Auth Simplicity: 98
Rate Limits: 98

🔒 Security

TLS Enforcement: 92
Auth Strength: 92
Scope Granularity: 88
Dep. Hygiene: 88
Secret Handling: 92

Local optimization library — no network access, no data exfiltration. Pure JAX pytree transformations. No security concerns beyond standard Python dependency hygiene.

⚡ Reliability

Uptime/SLA: 80
Version Stability: 78
Breaking Changes: 72
Error Recovery: 78

Best When

Training JAX/Flax neural networks where you need composable optimizer building blocks — optax.chain and optax.masked let you build complex optimizer pipelines (gradient clipping + warmup schedule + AdamW + layer-wise learning rates) out of small, composable transforms.

Avoid When

You're not using JAX, need non-gradient optimization, or prefer all-in-one frameworks.

Use Cases

  • Agent AdamW optimizer — optimizer = optax.adamw(learning_rate=1e-4, weight_decay=0.01); opt_state = optimizer.init(params); updates, opt_state = optimizer.update(grads, opt_state, params) — standard optimizer for agent transformer fine-tuning; AdamW's decoupled weight decay acts on the parameters (which is why params must be passed to update), not on the gradients
  • Agent cosine LR schedule — schedule = optax.warmup_cosine_decay_schedule(init_value=0, peak_value=1e-3, warmup_steps=1000, decay_steps=10000); optimizer = optax.adam(schedule) — agent training with warmup + cosine decay; learning rate decays from peak to near-zero; standard schedule for transformer agent training
  • Agent gradient clipping — optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) — clip gradients before Adam update; prevents agent training instability from exploding gradients; chain() composes gradient clipping with optimizer in correct order
  • Agent layer-wise learning rates — tx = optax.chain(optax.masked(optax.adam(1e-4), {'backbone': True, 'head': False}), optax.masked(optax.adam(1e-3), {'backbone': False, 'head': True})) — different LR for pretrained backbone vs task head; optax.multi_transform expresses the same per-group assignment more directly, and a fully frozen backbone can be masked with optax.set_to_zero
  • Agent gradient accumulation — optimizer = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4) — accumulate gradients over 4 steps before applying update; agent training with large effective batch size on memory-constrained GPU; MultiSteps correctly handles gradient accumulation with stateful optimizers

Not For

  • Non-JAX training — Optax is JAX-only; for PyTorch optimizers use torch.optim; for TensorFlow use tf.keras.optimizers
  • Non-gradient optimization — Optax is for gradient-based optimization; for evolutionary algorithms or black-box optimization use scipy.optimize or evosax
  • Production hyperparameter tuning — Optax defines optimizers; for automated HPO use Optuna or Ray Tune

Interface

REST API: No
GraphQL: No
gRPC: No
MCP Server: No
SDK: Yes
Webhooks: No

Authentication

Methods: none
OAuth: No · Scopes: No

No auth — local optimization library.

Pricing

Model: open_source
Free tier: Yes
Requires CC: No

Optax is Apache 2.0 licensed by Google DeepMind. Free for all use.

Agent Metadata

Pagination: none
Idempotent: Full
Retry Guidance: Not documented

Known Gotchas

  • optimizer.init() must be called with parameters pytree — optimizer = optax.adam(1e-3); opt_state = optimizer.init(params) initializes Adam moment vectors matching params structure; calling optimizer.update(grads, opt_state) with different pytree structure than params raises ValueError; agent code must init optimizer after defining model parameters
  • Gradient pytree must match parameter pytree exactly — optimizer.update(grads, opt_state) requires grads and params to have identical pytree structure (same keys, same shapes); agent code where the model architecture changes between init and update gets inscrutable JAX errors; compare structures with jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
  • optax.chain order matters — optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) clips THEN updates; optax.chain(optax.adam(1e-3), optax.clip_by_global_norm(1.0)) updates THEN clips parameter changes (wrong); agent gradient clipping must come before optimizer transform in chain
  • Learning rate schedules track steps in opt_state — optax.cosine_decay_schedule returns a callable mapping step → learning rate; when passed as the learning_rate of an optimizer, Optax counts steps inside opt_state automatically, so re-initializing opt_state mid-training silently restarts the schedule from step 0; agent code that evaluates the schedule directly (e.g. for logging) must pass the current step itself
  • AdamW weight decay hits bias and LayerNorm params by default — optax.adamw applies weight decay to ALL parameters, including biases and LayerNorm scales; for correct transformer training pass a mask (optax.adamw's mask argument, or wrap with optax.masked) excluding those leaves from decay; indiscriminate weight decay degrades agent model quality
  • MultiSteps drops a trailing partial accumulation — optax.MultiSteps averages gradients over K mini-steps and applies the inner optimizer on the K-th; if the training loop ends mid-accumulation the buffered gradients are never applied, and NaN gradients from degenerate batches (e.g. empty end-of-dataset batches) poison the running average; make dataset length divisible by K or handle remainder batches explicitly

Scores are editorial opinions as of 2026-03-06.
