Skip to content

AD through remake fails #4275

@hersle

Description

@hersle

I want to use MTK + SciMLSensitivity as shown below, but cannot get it to work: differentiating through `remake(prob)` fails:

using ModelingToolkit, OrdinaryDiffEq, Zygote, SciMLSensitivity

# Independent variable of the model.
@independent_variables t
# Single unknown of the system: a(t).
vars = @variables a(t)
# Model parameters; their numeric values are supplied later via `remake` in `a_final`.
pars = @parameters Ωr0 Ωm0 ΩΛ0
# Dynamics: da/dt = Ωr0 + Ωm0*a + ΩΛ0*a^4
eqs = [Differential(t)(a) ~ (Ωr0 + Ωm0*a + ΩΛ0*a^4)]
# Compile the symbolic system and bind the result to `M` (also used inside `a_final`).
@mtkcompile M = System(eqs, t, vars, pars)
# Initial condition a = 1e-5 on the time span (0.0, 2.0). The parameters are given NaN
# placeholder values here; `a_final` replaces them via `remake(prob; p = ...)`.
prob = ODEProblem(M, [M.a => 1e-5, M.Ωr0 => NaN, M.Ωm0 => NaN, M.ΩΛ0 => NaN], (0.0, 2.0))

"""
    a_final(x; sensealg = SciMLSensitivity.ForwardSensitivity())

Re-parameterize the global problem `prob` and return the last saved value of `M.a`.

Only `x[1]` and `x[2]` are read; they become `Ωr0` and `Ωm0`. The third parameter is
derived as `ΩΛ0 = 1 - Ωr0 - Ωm0`. The new parameter values are applied through
`remake` with a symbolic parameter map, which is the code path Zygote differentiates in
this reproducer. The problem is solved with `Tsit5()`, `save_everystep = false` and the
given `sensealg`.
"""
function a_final(x; sensealg = SciMLSensitivity.ForwardSensitivity())
    omega_r = x[1]
    omega_m = x[2]
    # The third parameter is not free: it is chosen so that all three sum to 1.
    omega_l = 1 - omega_r - omega_m
    pmap = [M.Ωr0 => omega_r, M.Ωm0 => omega_m, M.ΩΛ0 => omega_l]
    newprob = remake(prob; p = pmap)
    sol = solve(newprob, Tsit5(); save_everystep = false, sensealg = sensealg)
    return sol[M.a][end]
end

# Initial point. `a_final` reads only x0[1] (Ωr0) and x0[2] (Ωm0); the third entry is
# unused because ΩΛ0 is derived inside `a_final` as 1 - Ωr0 - Ωm0.
x0 = [1e-5, 0.3, 0.7]
# Plain forward evaluation (no AD).
a_final(x0)
# Reverse-mode AD with Zygote through `remake` + `solve`; this is the call that errors below.
Zygote.gradient(a_final, x0)
ERROR: UndefVarError: `parameter_index` not defined in `MTKChainRulesCoreExt`
Suggestion: check for spelling errors or missing imports.
Hint: a global variable of this name also exists in SymbolicIndexingInterface.
Stacktrace:
  [1] (::MTKChainRulesCoreExt.var"#12#13"{ODEProblem{…}})(i::SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymReal})
    @ MTKChainRulesCoreExt ~/.julia/dev/ModelingToolkit/lib/ModelingToolkitBase/ext/MTKChainRulesCoreExt.jl:91
  [2] iterate
    @ ./generator.jl:48 [inlined]
  [3] _collect(c::Vector{Any}, itr::Base.Generator{Vector{…}, MTKChainRulesCoreExt.var"#12#13"{…}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:810
  [4] collect_similar
    @ ./array.jl:732 [inlined]
  [5] map
    @ ./abstractarray.jl:3372 [inlined]
  [6] rrule(::typeof(SymbolicIndexingInterface.remake_buffer), indp::ODEProblem{…}, oldbuf::MTKParameters{…}, idxs::Base.KeySet{…}, vals::Base.ValueIterator{…})
    @ MTKChainRulesCoreExt ~/.julia/dev/ModelingToolkit/lib/ModelingToolkitBase/ext/MTKChainRulesCoreExt.jl:90
  [7] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:138 [inlined]
  [8] chain_rrule
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:234 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
 [10] _pullback(::Zygote.Context{…}, ::typeof(SymbolicIndexingInterface.remake_buffer), ::ODEProblem{…}, ::MTKParameters{…}, ::Base.KeySet{…}, ::Base.ValueIterator{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [11] _updated_u0_p_symmap
    @ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1325 [inlined]
 [12] _pullback(::Zygote.Context{…}, ::typeof(SciMLBase._updated_u0_p_symmap), ::ODEProblem{…}, ::Vector{…}, ::Val{…}, ::Dict{…}, ::Val{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [13] #_updated_u0_p_internal#887
    @ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1176 [inlined]
 [14] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##_updated_u0_p_internal#887", ::Bool, ::Bool, ::typeof(SciMLBase._updated_u0_p_internal), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [15] _updated_u0_p_internal
    @ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1170 [inlined]
 [16] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase._updated_u0_p_internal), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [17] #updated_u0_p#905
    @ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1394 [inlined]
 [18] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##updated_u0_p#905", ::Bool, ::Bool, ::typeof(SciMLBase.updated_u0_p), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [19] updated_u0_p
    @ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:1387 [inlined]
 [20] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.updated_u0_p), ::ODEProblem{…}, ::Missing, ::Vector{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [21] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##remake#869", ::Missing, ::Missing, ::Missing, ::Vector{…}, ::Missing, ::Bool, ::Type{…}, ::Bool, ::Nothing, ::@Kwargs{}, ::typeof(remake), ::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [22] remake
    @ ~/.julia/packages/SciMLBase/cgq4R/src/remake.jl:222 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(remake), ::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [24] #a_final#1
    @ ./REPL[1]:13 [inlined]
 [25] _pullback(::Zygote.Context{false}, ::var"##a_final#1", ::ForwardSensitivity{0, true, Val{:central}}, ::typeof(a_final), ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [26] a_final
    @ ./REPL[1]:10 [inlined]
 [27] _pullback(ctx::Zygote.Context{false}, f::typeof(a_final), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [28] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [29] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [30] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
 [31] top-level scope
    @ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.

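The error points to an obvious missing import of `parameter_index`. The hint shows it exists in SymbolicIndexingInterface, but it is not imported into `MTKChainRulesCoreExt`.
However, after `dev`-ing ModelingToolkit and adding that import to the ChainRulesCore extension, I only get different errors.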

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions