@@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040backing(project:: ProjectTo ) = getfield(project, :info)
4141
4242project_type(p:: ProjectTo{T} ) where {T} = T
43+ project_type(:: Type{<:ProjectTo{T}} ) where {T} = T
44+ project_type(_) = Any
4345
4446function Base. show(io:: IO , project:: ProjectTo{T} ) where {T}
4547 print(io, " ProjectTo{" )
@@ -142,42 +144,16 @@ end
142144# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
143145(:: ProjectTo{T} )(dx:: Tangent{<:T} ) where {T} = dx
144146
145- # ####
146- # #### A related utility which wants to live nearby
147- # ####
148-
149- """
150- is_non_differentiable(x) == is_non_differentiable(typeof(x))
151-
152- Returns `true` if `x` is known from its type not to have derivatives, else `false`.
153-
154- Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
155- which is what the fallback method checks. The exception is that it will not look
156- inside abstractly typed containers like `x = Any[true, false]`.
157- """
158- is_non_differentiable(x) = is_non_differentiable(typeof(x))
159-
160- is_non_differentiable(:: Type{<:Number} ) = false
161- is_non_differentiable(:: Type{<:NTuple{N,T}} ) where {N,T} = is_non_differentiable(T)
162- is_non_differentiable(:: Type{<:AbstractArray{T}} ) where {T} = is_non_differentiable(T)
163-
164- function is_non_differentiable(:: Type{T} ) where {T} # fallback
165- PT = Base. _return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
166- return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero}
167- end
168-
169147# ####
170148# #### `Base`
171149# ####
172150
173151# Bool
174152ProjectTo(:: Bool ) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
175- is_non_differentiable(:: Type{Bool} ) = true
176153
177154# Other never-differentiable types
178- for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
155+ for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle, :Nothing )
179156 @eval ProjectTo(:: $T ) = ProjectTo{NoTangent}()
180- @eval is_non_differentiable(:: Type{<:$T} ) = true
181157end
182158
183159# Numbers
@@ -627,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
627603 invoke(project, Tuple{AbstractArray}, dx)
628604 end
629605end
606+
607+ # ####
608+ # #### A related utility which wants to live nearby
609+ # ####
610+
611+ """
612+ differential_type(x)
613+ differential_type(typeof(x))
614+
615+ Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is
616+ known to be non-differentiable.
617+
618+ This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference.
619+ Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`.
620+
621+ ```jldoctest
622+ julia> differential_type(true)
623+ NoTangent
624+
625+ julia> differential_type(Int)
626+ Float64
627+
628+ julia> x = Any[true, false];
629+
630+ julia> differential_type(x)
631+ NoTangent
632+
633+ julia> differential_type(typeof(x))
634+ Any
635+ ```
636+ """
637+ differential_type(x) = project_type(ProjectTo(x))
638+
639+ function differential_type(:: Type{T} ) where {T}
640+ PT = Base. _return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
641+ return isconcretetype(PT) ? project_type(PT) : Any
642+ end
0 commit comments