diff --git a/Project.toml b/Project.toml index c120a91..61743e5 100644 --- a/Project.toml +++ b/Project.toml @@ -28,5 +28,5 @@ NNlib = "0.9" SpecialFunctions = "2" SymbolicUtils = "3" Symbolics = "6" -Zygote = "0.6" +Zygote = "0.6, 0.7" julia = "1.10" diff --git a/src/chainrules.jl b/src/chainrules.jl index 5d57291..fc934dd 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,4 +1,4 @@ -import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out +import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out, unthunk using Base.Broadcast: broadcasted function rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T} @@ -26,6 +26,9 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T} function partials_pullback(::ZeroTangent) NoTangent(), TaylorScalar(z, ntuple(j -> zero(T), Val(N))) end + function partials_pullback(v̄::ChainRulesCore.AbstractThunk) + partials_pullback(unthunk(v̄)) + end return partials(t), partials_pullback end