From 03f3d8eaf68d9477677d0f737adb1ce11a20802c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 19 Jun 2025 11:17:12 +0000 Subject: [PATCH 1/2] Allow Zygote v0.7 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c120a91..61743e5 100644 --- a/Project.toml +++ b/Project.toml @@ -28,5 +28,5 @@ NNlib = "0.9" SpecialFunctions = "2" SymbolicUtils = "3" Symbolics = "6" -Zygote = "0.6" +Zygote = "0.6, 0.7" julia = "1.10" From b81292ae884427030a1206f368e3a1b9bb48d080 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 19 Jul 2025 10:25:56 -0400 Subject: [PATCH 2/2] auto-unthunk partials --- src/chainrules.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 5d57291..fc934dd 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,4 +1,4 @@ -import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out +import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out, unthunk using Base.Broadcast: broadcasted function rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T} @@ -26,6 +26,9 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T} function partials_pullback(::ZeroTangent) NoTangent(), TaylorScalar(z, ntuple(j -> zero(T), Val(N))) end + function partials_pullback(v̄::ChainRulesCore.AbstractThunk) + partials_pullback(unthunk(v̄)) + end return partials(t), partials_pullback end