|
1 | 1 | function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val}) |
2 | | - return getproperty(x, s), Δ -> getproperty_adjoint(Δ, x, s) |
| 2 | + return getproperty(x, s), Δ -> getproperty_adjoint(ChainRulesCore.unthunk(Δ), x, s) |
3 | 3 | end |
4 | 4 |
|
5 | 5 | function getproperty_adjoint(Δ, x, s) |
@@ -28,9 +28,9 @@ function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.Ha |
28 | 28 | return y_, pb_f |
29 | 29 | end |
30 | 30 |
|
31 | | -ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x))) |
| 31 | +ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x))) |
32 | 32 |
|
33 | | -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent()) |
| 33 | +ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), ChainRulesCore.NoTangent()) |
34 | 34 |
|
35 | 35 | function ChainRulesCore.ProjectTo(ca::ComponentArray) |
36 | 36 | return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca)) |
|
49 | 49 | function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray} |
50 | 50 | y = CA(nt) |
51 | 51 |
|
| 52 | + ∇NamedTupleToComponentArray(Δ) = ∇NamedTupleToComponentArray(ChainRulesCore.unthunk(Δ)) |
| 53 | + |
52 | 54 | function ∇NamedTupleToComponentArray(Δ::AbstractArray) |
53 | 55 | if length(Δ) == length(y) |
54 | 56 | return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y))) |
|
0 commit comments