-
-
Notifications
You must be signed in to change notification settings - Fork 241
Open
Labels
bug — Something isn't working
Description
I want to use ModelingToolkit (MTK) together with SciMLSensitivity as shown below, but I cannot get it to work: differentiating with Zygote fails inside the `remake(prob; ...)` call:
# Minimal reproducer: differentiate (with Zygote) through `remake` of an MTK-generated ODEProblem.
using ModelingToolkit, OrdinaryDiffEq, Zygote, SciMLSensitivity

# Independent variable, one state `a(t)`, and three parameters.
@independent_variables t
vars = @variables a(t)
pars = @parameters Ωr0 Ωm0 ΩΛ0

# Single ODE: da/dt = sqrt(Ωr0 + Ωm0*a + ΩΛ0*a^4).
eqs = [Differential(t)(a) ~ √(Ωr0 + Ωm0*a + ΩΛ0*a^4)]
@mtkcompile M = System(eqs, t, vars, pars)

# Template problem: a(0) = 1e-5 on the time span (0.0, 2.0). All parameters are NaN
# placeholders; concrete values are supplied per evaluation via `remake` in `a_final`.
prob = ODEProblem(M, [M.a => 1e-5, M.Ωr0 => NaN, M.Ωm0 => NaN, M.ΩΛ0 => NaN], (0.0, 2.0))
"""
    a_final(x; sensealg = SciMLSensitivity.ForwardSensitivity())

Solve the global template `prob` with `Ωr0 = x[1]`, `Ωm0 = x[2]` and
`ΩΛ0 = 1 - Ωr0 - Ωm0`, and return the last saved value of the state `a`.
Entries of `x` beyond the second are not read. `sensealg` is passed through to
`solve`; it is the SciMLSensitivity algorithm requested for differentiating the solve.
"""
function a_final(x; sensealg = SciMLSensitivity.ForwardSensitivity())
    Ωr0, Ωm0 = x[1], x[2]
    # ΩΛ0 is not a free input: it is derived so that Ωr0 + Ωm0 + ΩΛ0 == 1.
    ΩΛ0 = 1 - Ωr0 - Ωm0
    # Symbolic parameter map. Under Zygote this `remake` reaches the `remake_buffer`
    # rrule in MTKChainRulesCoreExt, which is where the reported
    # UndefVarError (`parameter_index`) is thrown.
    newprob = remake(prob; p = [M.Ωr0 => Ωr0, M.Ωm0 => Ωm0, M.ΩΛ0 => ΩΛ0])
    # Only the final value is used below, so intermediate steps are not saved.
    sol = solve(newprob, Tsit5(); save_everystep = false, sensealg)
    return sol[M.a][end]
end
# Parameter point. Only x0[1] (Ωr0) and x0[2] (Ωm0) are read by `a_final`;
# NOTE(review): x0[3] does not affect the result (ΩΛ0 is recomputed from the first two).
x0 = [1e-5, 0.3, 0.7]
# Plain forward evaluation (no differentiation).
a_final(x0)
Zygote.gradient(a_final, x0)
ERROR: UndefVarError: `parameter_index` not defined in `MTKChainRulesCoreExt`
Suggestion: check for spelling errors or missing imports.
Hint: a global variable of this name also exists in SymbolicIndexingInterface.
Stacktrace:
[1] (::MTKChainRulesCoreExt.var"#12#13"{ODEProblem{…}})(i::SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymReal})
@ MTKChainRulesCoreExt ~/.julia/dev/ModelingToolkit/lib/ModelingToolkitBase/ext/MTKChainRulesCoreExt.jl:91
[2] iterate
@ ./generator.jl:48 [inlined]
[3] _collect(c::Vector{Any}, itr::Base.Generator{Vector{…}, MTKChainRulesCoreExt.var"#12#13"{…}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:810
[4] collect_similar
@ ./array.jl:732 [inlined]
[5] map
@ ./abstractarray.jl:3372 [inlined]
[6] rrule(::typeof(SymbolicIndexingInterface.remake_buffer), indp::ODEProblem{…}, oldbuf::MTKParameters{…}, idxs::Base.KeySet{…}, vals::Base.ValueIterator{…})
@ MTKChainRulesCoreExt ~/.julia/dev/ModelingToolkit/lib/ModelingToolkitBase/ext/MTKChainRulesCoreExt.jl:90
[7] rrule
@ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:138 [inlined]
[8] chain_rrule
@ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:234 [inlined]
[9] macro expansion
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
[10] _pullback(::Zygote.Context{…}, ::typeof(SymbolicIndexingInterface.remake_buffer), ::ODEProblem{…}, ::MTKParameters{…}, ::Base.KeySet{…}, ::Base.ValueIterator{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
[11] _updated_u0_p_symmap
@ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1325 [inlined]
[12] _pullback(::Zygote.Context{…}, ::typeof(SciMLBase._updated_u0_p_symmap), ::ODEProblem{…}, ::Vector{…}, ::Val{…}, ::Dict{…}, ::Val{…}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[13] #_updated_u0_p_internal#887
@ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1176 [inlined]
[14] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##_updated_u0_p_internal#887", ::Bool, ::Bool, ::typeof(SciMLBase._updated_u0_p_internal), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[15] _updated_u0_p_internal
@ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1170 [inlined]
[16] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase._updated_u0_p_internal), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[17] #updated_u0_p#905
@ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1394 [inlined]
[18] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##updated_u0_p#905", ::Bool, ::Bool, ::typeof(SciMLBase.updated_u0_p), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[19] updated_u0_p
@ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1387 [inlined]
[20] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.updated_u0_p), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[21] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##remake#869", ::Missing, ::Missing, ::Missing, ::Vector{…}, ::Missing, ::Bool, ::Type{…}, ::Bool, ::Nothing, ::@Kwargs{}, ::typeof(remake), ::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
[22] remake
@ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:222 [inlined]
[23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(remake), ::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[24] #a_final#1
@ ./REPL[1]:13 [inlined]
[25] _pullback(::Zygote.Context{false}, ::var"##a_final#1", ::ForwardSensitivity{0, true, Val{:central}}, ::typeof(a_final), ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[26] a_final
@ ./REPL[1]:10 [inlined]
[27] _pullback(ctx::Zygote.Context{false}, f::typeof(a_final), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[28] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
[29] pullback
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
[30] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
[31] top-level scope
@ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.
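The error points to an obvious missing import of `parameter_index` (as the hint notes, `SymbolicIndexingInterface` defines it).
However, `dev`-ing ModelingToolkit and adding that import to the ChainRulesCore extension (`MTKChainRulesCoreExt`) only leads to different errors.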
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug — Something isn't working