Skip to content

Add multivariate optimization#331

Open
nyo16 wants to merge 1 commit intoelixir-nx:mainfrom
nyo16:add-multivariate-optimization
Open

Add multivariate optimization#331
nyo16 wants to merge 1 commit intoelixir-nx:mainfrom
nyo16:add-multivariate-optimization

Conversation

@nyo16
Copy link
Contributor

@nyo16 nyo16 commented Jan 30, 2026

This PR adds two multivariate optimization algorithms to Scholar, completing the optimization module:

  • BFGS - Quasi-Newton method with automatic differentiation for smooth, differentiable functions
  • Nelder-Mead - Derivative-free simplex method for functions where gradients are unavailable or expensive

Both implementations follow the patterns established in PRs #327 (Golden Section) and #328 (Brent):

  • defn entry point with deftransformp for option validation
  • Initial point as explicit function argument
  • Module constants used directly (not wrapped in Nx.tensor)
  • Each module has its own NimbleOptions schema
  • JIT/GPU compatible

Changes

File Description
lib/scholar/optimize/bfgs.ex BFGS implementation with backtracking line search
lib/scholar/optimize/nelder_mead.ex Nelder-Mead simplex implementation
test/scholar/optimize/bfgs_test.exs Tests for BFGS (sphere, Rosenbrock, Booth, Beale functions)
test/scholar/optimize/nelder_mead_test.exs Tests for Nelder-Mead
notebooks/optimize.livemd Updated with multivariate optimization section
notebooks/efficient_frontier.livemd Portfolio optimization example using both algorithms
compare_*.py SciPy validation scripts

Test plan

  • All existing tests pass (mix test)
  • New algorithm tests pass (mix test test/scholar/optimize/)
  • Doctests pass
  • Code formatted (mix format)
  • Test livebooks manually in Livebook"

- BFGS: Quasi-Newton method with automatic differentiation
- Nelder-Mead: Derivative-free simplex method
- Update optimize.livemd with multivariate section
- Add efficient_frontier.livemd portfolio optimization example
@nyo16 nyo16 force-pushed the add-multivariate-optimization branch from f9d143a to f2404ad Compare January 30, 2026 23:33
@nyo16
Copy link
Contributor Author

nyo16 commented Jan 30, 2026

And comparing them with python

BFGS Results:
  ┌───────────────────┬──────────────┬────────────────────┬────────┐
  │     Function      │ Expected Min │    SciPy Found     │ Status │
  ├───────────────────┼──────────────┼────────────────────┼────────┤
  │ sphere_2d         │ [0, 0]       │ [~0, ~0]           │ ✓      │
  ├───────────────────┼──────────────┼────────────────────┼────────┤
  │ sphere_3d         │ [0, 0, 0]    │ [~0, ~0, ~0]       │ ✓      │
  ├───────────────────┼──────────────┼────────────────────┼────────┤
  │ rosenbrock        │ [1, 1]       │ [0.99999, 0.99999] │ ✓      │
  ├───────────────────┼──────────────┼────────────────────┼────────┤
  │ booth             │ [1, 3]       │ [1, 3]             │ ✓      │
  ├───────────────────┼──────────────┼────────────────────┼────────┤
  │ beale             │ [3, 0.5]     │ [3, 0.5]           │ ✓      │
  ├───────────────────┼──────────────┼────────────────────┼────────┤
  │ shifted_quadratic │ [2, -3, 1]   │ [2, -3, 1]         │ ✓      │
  └───────────────────┴──────────────┴────────────────────┴────────┘
  Nelder-Mead Results:
  ┌───────────────────┬──────────────┬──────────────┬────────┐
  │     Function      │ Expected Min │ SciPy Found  │ Status │
  ├───────────────────┼──────────────┼──────────────┼────────┤
  │ sphere_2d         │ [0, 0]       │ [~0, ~0]     │ ✓      │
  ├───────────────────┼──────────────┼──────────────┼────────┤
  │ sphere_3d         │ [0, 0, 0]    │ [~0, ~0, ~0] │ ✓      │
  ├───────────────────┼──────────────┼──────────────┼────────┤
  │ rosenbrock        │ [1, 1]       │ [1, 1]       │ ✓      │
  ├───────────────────┼──────────────┼──────────────┼────────┤
  │ booth             │ [1, 3]       │ [1, 3]       │ ✓      │
  ├───────────────────┼──────────────┼──────────────┼────────┤
  │ beale             │ [3, 0.5]     │ [3, 0.5]     │ ✓      │
  ├───────────────────┼──────────────┼──────────────┼────────┤
  │ shifted_quadratic │ [2, -3, 1]   │ [2, -3, 1]   │ ✓      │
  └───────────────────┴──────────────┴──────────────┴────────┘

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant