feat: AutoSpecialize norecompile infrastructure for NonlinearSolveBase #838
ChrisRackauckas-Claude wants to merge 16 commits into SciML:master
CI Fix: ImmutableNonlinearProblem wrapping
The adjoint test in SimpleNonlinearSolve failed because
Fix: Skip FunctionWrapper wrapping for
The other CI failures (runic, alloc_check, wrappers) are pre-existing/infrastructure issues:
Reverted automatic wrapping - infrastructure only for now
The
These bypass
Current state: The AutoSpecialize infrastructure is fully implemented and ready:
What's needed for automatic wrapping: All ForwardDiff call sites need to be coordinated to use chunksize=1 when the function is wrapped. This is analogous to how DiffEqBase/OrdinaryDiffEq controls all internal ForwardDiff calls.
Update: AutoSpecialize wrapping now fully activated
This commit activates the wrapping infrastructure end-to-end. Key design decisions:
Architecture
What gets wrapped
Tag standardization
Known trade-off
Local test results (all clean)
Adjoint/Reverse-Mode AD Fix (commit 7abe63b)
The
Root Cause
When
Fix
Two-pronged approach:
Test Results (local, Julia 1.10)
Also updated the IIP
…Base
Port the FunctionWrappersWrappers-based norecompile pattern from DiffEqBase
to NonlinearSolveBase. For standard problem types (Vector{Float64} state,
Vector{Float64} or NullParameters parameters), the problem function is
wrapped in a FunctionWrappersWrapper with precompiled type signatures for
both Float64 and ForwardDiff.Dual arguments, avoiding recompilation for
each unique user function type.
Key components:
- src/autospecialize.jl: NonlinearSolveTag, wrapfun_iip/oop base methods,
maybe_wrap_nonlinear_f, standardize_forwarddiff_tag fallback
- ForwardDiff extension: dual-aware wrapfun dispatches with 6 type
combinations (Float64, Dual, NullParameters), tag standardization that
stamps NonlinearSolveTag on AutoForwardDiff and forces chunksize=1 when
the function is wrapped
- solve.jl: maybe_wrap_f wired into get_concrete_problem for all problem
types (NonlinearProblem, NonlinearLeastSquaresProblem,
ImmutableNonlinearProblem), using EvalFunc wrapper for invokelatest
- jacobian.jl: standardize_forwarddiff_tag called in
construct_jacobian_cache so DI produces correctly-tagged duals
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
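As a rough illustration of why a fixed-signature wrapper avoids per-function recompilation, the sketch below uses `FunctionWrapper` from FunctionWrappers.jl directly (assumed to be installed). The names `f1!`, `f2!`, and `IIPSig` are hypothetical, and this is not the `wrapfun_iip` code from the PR, which builds a `FunctionWrappersWrapper` over several signatures.

```julia
using FunctionWrappers: FunctionWrapper

# Two distinct user functions with the same in-place signature f!(du, u, p).
f1!(du, u, p) = (du .= u .^ 2 .- p; nothing)
f2!(du, u, p) = (du .= sin.(u) .- p; nothing)

# Every Julia function has its own type, so code specialized on typeof(f)
# is compiled again for each new user function.
different_types = typeof(f1!) != typeof(f2!)

# A wrapper with a fixed signature has one concrete type no matter which
# function it holds (hypothetical alias, for illustration only).
const IIPSig = FunctionWrapper{
    Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}},
}
w1 = IIPSig(f1!)
w2 = IIPSig(f2!)
same_wrapper_type = typeof(w1) == typeof(w2)

# Calling through the wrapper still runs the user function.
du = zeros(2)
w1(du, [1.0, 2.0], [1.0, 1.0])
```

Because `w1` and `w2` share a type, solver code that receives either one is compiled only once for that signature.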
ImmutableNonlinearProblem (used by SimpleNonlinearSolve) doesn't support Setfield reconstruction with wrapped function types. Skip wrapping since SimpleNonlinearSolve's lighter solvers don't benefit from the norecompile pathway.
Fixes CI adjoint test failure in SimpleNonlinearSolve.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The FunctionWrapper wrapping cannot be automatically applied at solve time because multiple code paths (∂f/∂p, ∂f/∂u, bounds transform) call ForwardDiff directly with default chunk sizes, bypassing the standardized chunksize=1 path. This caused "No matching function wrapper found!" errors whenever ForwardDiff used chunksize > 1.
The infrastructure (autospecialize.jl, extension wrappers, tag standardization) remains available for targeted use. Automatic wrapping requires coordinating ALL ForwardDiff call sites to use chunksize=1.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
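The chunk-size failure can be pictured with a small mock in plain Julia. The chunk size is part of a dual number's type, so a wrapper that only registered the one-partial signature has no entry for a two-partial dual. `MockDual` and `SignatureTable` below are hypothetical stand-ins, not ForwardDiff or FunctionWrappersWrappers types; only the error message is taken from the commit text.

```julia
# Hypothetical stand-in for a dual number carrying N partials.
struct MockDual{N}
    value::Float64
    partials::NTuple{N, Float64}
end

# Hypothetical stand-in for a wrapper with one entry per registered argument type.
struct SignatureTable
    entries::Dict{DataType, Function}
end

function (t::SignatureTable)(x)
    entry = get(t.entries, typeof(x), nothing)
    entry === nothing && error("No matching function wrapper found!")
    return entry(x)
end

square(x::Float64) = x^2
function square(d::MockDual{N}) where {N}
    # Chain rule: d(x^2) = 2x dx for every partial.
    return MockDual{N}(d.value^2, ntuple(i -> 2 * d.value * d.partials[i], N))
end

# Only Float64 and the chunksize-1 dual are registered.
table = SignatureTable(Dict{DataType, Function}(
    Float64 => square,
    MockDual{1} => square,
))

plain = table(3.0)
dual1 = table(MockDual{1}(3.0, (1.0,)))
# table(MockDual{2}(3.0, (1.0, 0.0))) would throw: MockDual{2} is a different type.
```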
The standardize_forwarddiff_tag calls in autodiff.jl and jacobian.jl cause dual tag ordering errors when nested ForwardDiff is used (e.g., NLLS sensitivity + inner VJP). Remove these call sites and the unused maybe_wrap_f function since automatic wrapping is not yet active.
The autospecialize infrastructure (NonlinearSolveTag, wrapfun_iip/oop, ForwardDiff extension wrappers) remains available for future activation when all direct ForwardDiff call sites are standardized.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Wire `maybe_wrap_f` into `get_concrete_problem` for NonlinearProblem and
NonlinearLeastSquaresProblem (IIP). Functions are wrapped in
`AutoSpecializeCallable{FW}` which holds a `FunctionWrappersWrapper` for
precompiled dispatch and the original function (type-erased as `Any`) for
try-catch fallback when dual tags mismatch (JVP paths, external packages).
Key changes:
- AutoSpecializeCallable uses `orig::Any` for type erasure (no EvalFunc)
- Skip OOP NLLS wrapping (return type may differ from u0)
- Standardize JVP/VJP autodiff tags in construct_jacobian_cache
- Replace AutoPolyesterForwardDiff with AutoForwardDiff{1,tag} when wrapped
- Use get_raw_f for nested ForwardDiff in NLLS VJP generation
- ForwardDiff sensitivity functions use chunksize=1 + tag when wrapped
Tests: core 727/0/0, wrapper 195/0/0, ForwardDiff 135636/0/0
OOP @inferred regresses (expected, same trade-off as DiffEqBase)
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
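A minimal sketch of the try-catch fallback idea, in plain Julia. `TryCatchCallable`, `fast!`, and `user!` are hypothetical names; a `MethodError` from a type-restricted method stands in for whatever mismatch error the real wrapper raises, and the actual `AutoSpecializeCallable` may differ in detail.

```julia
struct TryCatchCallable{FW}
    fw::FW      # fast path that accepts only a fixed set of argument types
    orig::Any   # original user function, stored type-erased
end

function (f::TryCatchCallable)(args...)
    try
        return f.fw(args...)
    catch err
        # Assumption: a MethodError marks "argument types not covered".
        # Note that this would also catch a MethodError raised inside the user code.
        err isa MethodError || rethrow()
        return f.orig(args...)
    end
end

user!(du, u, p) = (du .= u .- p; nothing)
# The fast path only exists for Float64 vectors.
fast!(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64}) = user!(du, u, p)

g = TryCatchCallable(fast!, user!)

du64 = zeros(2)
g(du64, [3.0, 4.0], [1.0, 1.0])                 # handled by fast!

du32 = zeros(Float32, 2)
g(du32, Float32[3, 4], Float32[1, 1])           # falls back to user!
```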
Reverse-mode AD backends (Zygote, Mooncake, Enzyme) cannot differentiate through FunctionWrapper internals (llvmcall). This adds:
- ChainRulesCore rrule for AutoSpecializeCallable that redirects reverse-mode AD through the original unwrapped callable
- _DISABLE_AUTOSPECIALIZE flag set in the solve_up rrule to prevent wrapping entirely during the adjoint code path
- @test_broken for IIP @inferred (same wrapping-induced regression as OOP case)
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
…type
- Remove task-local `_DISABLE_AUTOSPECIALIZE` flag entirely
- Replace with `@set prob.f.f = get_raw_f(prob.f.f)` unwrapping in rrule
- Remove parameter type restriction (any p works, mismatches fall back)
- Add idempotency check to prevent double-wrapping
- Remove `_DISABLE_AUTOSPECIALIZE` from public API
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
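The idempotency check and the unwrapping step can be sketched with two pairs of dispatch methods. `SketchWrapped`, `sketch_wrap`, and `sketch_raw` are hypothetical names chosen for illustration; the PR's own `get_raw_f` and wrapping function may have different signatures.

```julia
# Hypothetical wrapper that keeps the original function type-erased.
struct SketchWrapped
    orig::Any
end
(w::SketchWrapped)(args...) = w.orig(args...)

# Wrapping is idempotent: an already-wrapped function is returned unchanged.
sketch_wrap(f) = SketchWrapped(f)
sketch_wrap(f::SketchWrapped) = f

# Unwrapping returns the original function, and is a no-op for raw functions.
sketch_raw(f) = f
sketch_raw(f::SketchWrapped) = f.orig

resid!(du, u, p) = (du .= u .- p; nothing)
once = sketch_wrap(resid!)
twice = sketch_wrap(once)
```

An adjoint rule can then call the unwrapping function on the stored callable before differentiating, which is the role the commit assigns to `get_raw_f`.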
These are internal implementation details, not public API.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
OOP wrapping requires guessing return types which doesn't always work. Only wrap IIP functions where the return type is always Nothing.
IIP TTFX improvement (2nd/3rd function, same types):
- NewtonRaphson: 2.2-2.5x faster
- TrustRegion: 18x faster
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
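The return-type problem is easy to see with toy functions in plain Julia. An in-place function can always be declared as returning `Nothing`, while an out-of-place function's return type depends on what the user computes. The function names below are hypothetical examples, not code from the PR.

```julia
# In-place (IIP): writes into du and returns nothing, whatever the residual is.
iip!(du, u, p) = (du .= u .- p; nothing)

# Out-of-place (OOP): the return type is whatever the body produces.
oop_same(u, p) = u .- p                  # same type as u
oop_scalar(u, p) = sum(abs2, u .- p)     # a scalar, unlike u

u0 = [1.0, 2.0]
p = [0.5, 0.5]

iip_ret = iip!(similar(u0), u0, p)
same_ret_type = typeof(oop_same(u0, p))
scalar_ret_type = typeof(oop_scalar(u0, p))
```

A wrapper whose return type was guessed as `typeof(u0)` would fit `oop_same` but not `oop_scalar`, whereas the `Nothing` return type of `iip!` needs no guess.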
The existing workload used scalar p=2.0, which produces different
FunctionWrapper types than the common user case of Vector{Float64}
parameters. This caused the precompiled wrappers to miss the user path.
TTFX for IIP Vector{Float64} first solve: 2.7s → 1.0s
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
The try-catch in AutoSpecializeCallable prevented inlining and added
~32 bytes per call, exceeding the 64-byte @ballocated budget in
NonlinearSolveFirstOrder, QuasiNewton, and SpectralMethods tests.
Replace with explicit dispatch methods for known argument types
(Vector{Float64}, Float64, NullParameters, and ForwardDiff duals),
routing to f.fw for zero-allocation calls. Unsupported types fall
back to f.orig via vararg dispatch. Also fix @test_broken -> @test
for @inferred solve(prob) which now passes.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
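A sketch of the dispatch-based routing described in this commit, in plain Julia. `DispatchCallable` and the counters are hypothetical, and the real `AutoSpecializeCallable` defines more methods (scalar and dual argument types) than the single specialized method shown here.

```julia
struct DispatchCallable{FW}
    fw::FW      # fast path for the known signature
    orig::Any   # original function for everything else
end

# Known argument types are routed to the fast path by ordinary dispatch ...
function (f::DispatchCallable)(
        du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64},
    )
    return f.fw(du, u, p)
end
# ... and any other combination falls back to the original function.
(f::DispatchCallable)(args...) = f.orig(args...)

# Counters record which path was taken (illustration only).
fast_calls = Ref(0)
slow_calls = Ref(0)
user!(du, u, p) = (du .= u .- p; nothing)

h = DispatchCallable(
    (du, u, p) -> (fast_calls[] += 1; user!(du, u, p)),
    (du, u, p) -> (slow_calls[] += 1; user!(du, u, p)),
)

h(zeros(2), [3.0, 4.0], [1.0, 1.0])                      # fast path
h(zeros(Float32, 2), Float32[3, 4], Float32[1, 1])       # fallback
```

Because the choice is made by method dispatch, there is no `try` block on the hot path.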
Enzyme cannot differentiate through FunctionWrappers' llvmcall, causing
EnzymeMutabilityException in all IIP Vector{Float64} tests with AutoEnzyme.
Unwrap the function in construct_jacobian_cache when the AD backend is
Enzyme-based (including AutoSparse(AutoEnzyme(...))), so DI sees the raw
user function. Also apply Runic formatting to SCCNonlinearSolve files.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
5f8115f to b6917fb
The previous commit only unwrapped for the concrete Jacobian path (DI.prepare_jacobian/DI.jacobian). This extends the fix to the JacobianOperator path used by Krylov solvers (GMRES, etc.) and backslash with concrete_jac=false.
When Enzyme is used for JVP/VJP autodiff, create a modified problem with the raw user function so SciMLJacobianOperators' DI.pushforward!/DI.pullback! calls don't go through FunctionWrappers' llvmcall.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fix: Extend Enzyme unwrap to JacobianOperator path
The previous push fixed Enzyme compatibility for the concrete Jacobian path (
Root cause analysis from CI logs: When
Fix: In
Verified locally:
… operators
The TrustRegion scheme creates VecJacOperator and JacVecOperator directly from the problem, bypassing construct_jacobian_cache. When Enzyme is the AD backend, these operators need the unwrapped function (without FunctionWrappers) to avoid EnzymeMutabilityException.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Additional Enzyme fix: TrustRegion VecJac/JacVec operators
The previous commit fixed the
Root cause from CI log:
Fix: Added Enzyme unwrap check before creating VecJac/JacVec operators in
Instead of fixing individual call sites (trust_region.jl VecJac/JacVec, jacobian.jl construct_jacobian_cache), create _ad_prob with unwrapped function early in __init for both FirstOrder and QuasiNewton. This ensures ALL downstream AD consumers (Jacobian cache, trust region, linesearch, forcing) receive the unwrapped problem when Enzyme is used.
- Add maybe_unwrap_prob_for_enzyme helper in NonlinearSolveBase
- FirstOrder: create _ad_prob from alg.autodiff/jvp_autodiff/vjp_autodiff
- QuasiNewton: detect Enzyme from kwargs and alg.linesearch/trustregion
- Revert trust_region.jl inline fix (now handled upstream in solve.jl)
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
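The "unwrap once, early" approach might look roughly like the following plain-Julia mock. Every type and function here (`MockEnzyme`, `MockSparse`, `MockProblem`, `unwrap_prob_if_enzyme`, and so on) is a hypothetical stand-in. The real `maybe_unwrap_prob_for_enzyme` works on ADTypes backends and SciML problem types, and its signature is not shown in this thread.

```julia
# Hypothetical stand-ins for AD backend selectors.
abstract type MockBackend end
struct MockForwardDiff <: MockBackend end
struct MockEnzyme <: MockBackend end
struct MockSparse{B <: MockBackend} <: MockBackend
    dense::B   # wrapped dense backend, as in AutoSparse(AutoEnzyme(...))
end

uses_enzyme(::Nothing) = false            # backend not specified
uses_enzyme(::MockBackend) = false
uses_enzyme(::MockEnzyme) = true
uses_enzyme(b::MockSparse) = uses_enzyme(b.dense)

# Hypothetical wrapped function and problem types.
struct MockWrapped
    orig::Any
end
struct MockProblem{F}
    f::F
    u0::Vector{Float64}
end

unwrap_f(f) = f
unwrap_f(f::MockWrapped) = f.orig

# Build the AD-facing problem once, so every downstream consumer sees the raw function.
function unwrap_prob_if_enzyme(prob::MockProblem, backends...)
    any(uses_enzyme, backends) || return prob
    return MockProblem(unwrap_f(prob.f), prob.u0)
end

resid!(du, u, p) = (du .= u .- p; nothing)
prob = MockProblem(MockWrapped(resid!), [1.0, 2.0])

kept = unwrap_prob_if_enzyme(prob, MockForwardDiff(), nothing)
unwrapped = unwrap_prob_if_enzyme(prob, MockSparse(MockEnzyme()))
```

Doing the check in one place means later call sites such as operator construction do not each need their own Enzyme test.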
Note: Enzyme workaround simplification
The Enzyme-specific workarounds in this PR (
A companion PR has been opened at EnzymeAD/Enzyme.jl#2980 that adds an
Once that Enzyme PR is merged and released, the following code can be removed from this PR (~73 lines):
All references to
Waiting on the Enzyme PR
Session State Summary (for continuation)
What was done
Enzyme.jl PR #2980: EnzymeAD/Enzyme.jl#2980
NonlinearSolve.jl PR #838 — #838
Simplification attempt (reverted)
Attempted to remove ~73 lines of Enzyme workaround code from 4 files:
Result: 116
Dependency chain: Enzyme PR #2980 must be merged and released BEFORE these workarounds can be removed.
Key technical details
Remotes
NonlinearSolve.jl:
Enzyme.jl:
Next steps
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Summary
Ports the FunctionWrappersWrappers-based norecompile infrastructure from DiffEqBase to NonlinearSolveBase, following the approach described in https://sciml.ai/news/2022/09/21/compile_time/.
This PR adds the infrastructure only: the wrappers, tag types, and extension methods needed for the norecompile/AutoSpecialize pattern. The automatic wrapping at solve-time is not yet activated because NonlinearSolve has several direct `ForwardDiff.jacobian` call sites (sensitivity analysis df/dp, df/du, bounds transforms) that bypass the DI-based Jacobian path and would receive duals with mismatched chunk sizes.
What's included
- `src/autospecialize.jl` (new): `NonlinearSolveTag`, `wrapfun_iip`/`wrapfun_oop` stub methods, `maybe_wrap_nonlinear_f`, `standardize_forwarddiff_tag` fallback
- ForwardDiff extension (`NonlinearSolveBaseForwardDiffExt.jl`): Dual-aware `wrapfun_iip`/`wrapfun_oop` dispatches with 6 type combinations each (`Float64`, `Dual{NonlinearSolveTag}`, `NullParameters`). Tag standardization that stamps `NonlinearSolveTag` on `AutoForwardDiff` and forces `chunksize=1` when the function is wrapped via `EvalFunc`.
- `NonlinearSolveBase.jl`: FunctionWrappers/FunctionWrappersWrappers imports, exports for the new public API
- `Project.toml`: FunctionWrappers and FunctionWrappersWrappers dependencies
What's NOT included (deferred)
- Automatic wrapping in the `get_concrete_problem` / solve path: requires standardizing all direct ForwardDiff call sites first (`nonlinearsolve_∂f_∂p`, `nonlinearsolve_∂f_∂u`, bounds transform code)
Design notes
The infrastructure mirrors DiffEqBase's pattern:
- For standard problem types (`Vector{Float64}` state, `Vector{Float64}` or `NullParameters` parameters), `maybe_wrap_nonlinear_f` wraps the function in a `FunctionWrappersWrapper` with precompiled dual type signatures
- `standardize_forwarddiff_tag` coordinates the ForwardDiff tag (`NonlinearSolveTag`) and chunk size (N=1) so duals match the wrapper signatures
Next steps to activate
- Route direct `ForwardDiff.jacobian`/`derivative`/`gradient` calls through DI or standardize their tag/chunksize
- Wire `maybe_wrap_f` into `get_concrete_problem`
- Call `standardize_forwarddiff_tag` before `construct_concrete_adtype` in jacobian.jl