Merged
2 changes: 1 addition & 1 deletion .github/workflows/runtests.yml
@@ -7,7 +7,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1.8', '1.9']
julia-version: ['1.8', '1.9', '1.10']
os: [ubuntu-latest]
provider: ['mkl', 'fftw']
fail-fast: false
16 changes: 9 additions & 7 deletions Project.toml
@@ -8,6 +8,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
@@ -22,6 +23,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Healpix = "9f4e344d-96bc-545a-84a3-ae6b9e1b672b"
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
ImageMorphology = "787d08f9-d448-5407-9aad-5290dd7ab264"
@@ -47,6 +49,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PlotUtils = "995b91a9-d308-5afd-9ec6-746e21dbc043"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -58,7 +61,6 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -86,12 +88,12 @@ CMBLensingPythonPlotExt = "PythonPlot"
AbstractFFTs = "0.5, 1"
Adapt = "1.0.1, 2, 3"
Bijections = "0.1"
CUDA = "3.4, 4"
CUDA = "3.4, 4, 5"
ChainRules = "1.5"
CodecZlib = "0.7"
Combinatorics = "1"
CompositeStructs = "0.1.1"
ComponentArrays = "0.13, 0.14, 0.15"
CompositeStructs = "0.1.1"
CoordinateTransformations = "0.6.2"
DataStructures = "0.17.9, 0.18"
Distributions = "0.25"
@@ -110,9 +112,9 @@ JLD2 = "0.4.30"
KahanSummation = "0.1, 0.2"
Lazy = "0.13.2, 0.14, 0.15"
Loess = "0.5"
MCMCDiagnosticTools = "0.3"
MacroTools = "0.5"
Match = "1.1"
MCMCDiagnosticTools = "0.3"
Measurements = "2"
Memoization = "0.2"
MuseInference = "0.2.2"
@@ -121,6 +123,7 @@ NamedTupleTools = "0.13"
Optim = "1"
PDMats = "0.11.5"
PlotUtils = "1.3.2"
PrecompileTools = "1"
Preferences = "1.2"
ProgressMeter = "1.2"
QuadGK = "2.3.1"
@@ -129,14 +132,13 @@ Requires = "0.5, 1"
Roots = "0.8.4, 1, 2"
Rotations = "1.3.4"
Setfield = "0.6, 0.7, 0.8, 1"
PrecompileTools = "1"
StaticArrays = "0.12.1, 1.0"
StatsBase = "0.32, 0.33"
TimerOutputs = "0.5"
Tullio = "0.3"
UnPack = "1"
Zygote = "0.6.21"
julia = "1.8"
Zygote = "0.6.21, 0.7"
julia = "1.8, 1.10"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
73 changes: 71 additions & 2 deletions ext/CMBLensingCUDAExt.jl
@@ -16,10 +16,13 @@ using AbstractFFTs
using EllipsisNotation
using ForwardDiff
using ForwardDiff: Dual, Partials, value, partials
using GPUArrays
using LinearAlgebra
using Markdown
using Memoization
using Random
using SparseArrays
using StaticArrays
using Zygote

const CuBaseField{B,M,T,A<:CuArray} = BaseField{B,M,T,A}
@@ -62,7 +65,7 @@ Adapt.adapt_structure(::Type{<:Array}, L::SparseMatrixCSC) = L

# some Random API which CUDA doesn't implement yet
Random.randn(rng::CUDA.CURAND.RNG, T::Random.BitFloatType) =
cpu(randn!(rng, CuVector{T}(undef,1)))[1]
cpu(randn!(rng, CUDA.CuVector{T}(undef,1)))[1]

# perhaps minor type-piracy, but this lets us simulate into a CuArray using the
# CPU random number generator
@@ -120,6 +123,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::CuArray, dual, i
end

# fix for https://github.com/jonniedie/ComponentArrays.jl/issues/193
# A method ambiguity between Base.reshape and ComponentArrays.reshape for 0-dim arrays
function Base.reshape(a::CuArray{T,M}, dims::Tuple{}) where {T,M}
if prod(dims) != length(a)
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(size(a))"))
@@ -129,8 +133,73 @@ function Base.reshape(a::CuArray{T,M}, dims::Tuple{}) where {T,M}
return a
end

CUDA._derived_array(T, 0, a, dims)
GPUArrays.derive(T, a, dims, 0)
end


function CMBLensing.BilinearLens(ϕ::FlatField{B1,M1,CT,AA}) where {B1,M1,CT,AA<:CuArray}

# if ϕ == 0 then just return identity operator
if norm(ϕ) == 0
return BilinearLens(ϕ,I,I)
end

@unpack Nbatch,Nx,Ny,Δx = ϕ
T = real(ϕ.T)
Nbatch > 1 && error("BilinearLens with batched ϕ not implemented yet.")

# the (i,j)-th pixel is deflected to (ĩs[i],j̃s[j])
j̃s,ĩs = getindex.((∇*ϕ)./Δx, :Ix)
ĩs .= ĩs .+ (1:Ny)
j̃s .= (j̃s' .+ (1:Nx))'

# sub2ind converts a 2D index to 1D index, including wrapping at edges
indexwrap(i,N) = mod(i - 1, N) + 1
sub2ind(i,j) = Base._sub2ind((Ny,Nx),indexwrap(i,Ny),indexwrap(j,Nx))

# compute the 4 non-zero entries in L[I,:] (ie the Ith row of the sparse
# lensing representation, L) and add these to the sparse constructor
# matrices, M, and V, accordingly. this function is split off so it can be
# called directly or used as a CUDA kernel
function compute_row!(I, ĩ, j̃, M, V)

# (i,j) indices of the 4 nearest neighbors
left,right = floor(Int,ĩ) .+ (0, 1)
top,bottom = floor(Int,j̃) .+ (0, 1)

# 1-D indices of the 4 nearest neighbors
M[4I-3:4I] .= @SVector[sub2ind(left,top), sub2ind(right,top), sub2ind(left,bottom), sub2ind(right,bottom)]

# weights of these neighbors in the bilinear interpolation
Δx⁻, Δx⁺ = ((left,right) .- ĩ)
Δy⁻, Δy⁺ = ((top,bottom) .- j̃)
A = @SMatrix[
1 Δx⁻ Δy⁻ Δx⁻*Δy⁻;
1 Δx⁺ Δy⁻ Δx⁺*Δy⁻;
1 Δx⁻ Δy⁺ Δx⁻*Δy⁺;
1 Δx⁺ Δy⁺ Δx⁺*Δy⁺
]
V[4I-3:4I] .= inv(A)[1,:]

end

# a surprisingly large fraction of the computation for large Nside, so memoize it:
@memoize getK(Nx,Ny) = Int32.((4:4*Nx*Ny+3) .÷ 4)

K = CUDA.CuVector{Cint}(getK(Nx,Ny))
M = similar(K)
V = similar(K,T)
CMBLensing.cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V
index = CUDA.threadIdx().x
stride = CUDA.blockDim().x
for I in index:stride:length(ĩs)
compute_row!(I, ĩs[I], j̃s[I], M, V)
end
end
spr = CuSparseMatrixCSR(CuSparseMatrixCOO{T}(K,M,V,(Nx*Ny,Nx*Ny)))
return CMBLensing.BilinearLens(ϕ, spr, nothing)
end



end
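
For reference, here is a minimal standalone sketch (not part of this diff) of the bilinear-interpolation weight computation that `compute_row!` performs above; the helper name `bilinear_weights` is hypothetical, and the logic simply mirrors the 4×4 solve in the kernel:

```julia
using StaticArrays

# given a fractional deflected position (ĩ, j̃), return the bilinear weights of
# the four surrounding grid points (left/top, right/top, left/bottom, right/bottom)
function bilinear_weights(ĩ, j̃)
    left, right  = floor(Int, ĩ) .+ (0, 1)
    top,  bottom = floor(Int, j̃) .+ (0, 1)
    Δx⁻, Δx⁺ = (left, right) .- ĩ
    Δy⁻, Δy⁺ = (top, bottom) .- j̃
    A = @SMatrix [
        1 Δx⁻ Δy⁻ Δx⁻*Δy⁻;
        1 Δx⁺ Δy⁻ Δx⁺*Δy⁻;
        1 Δx⁻ Δy⁺ Δx⁻*Δy⁺;
        1 Δx⁺ Δy⁺ Δx⁺*Δy⁺
    ]
    # the first row of A⁻¹ gives the weights whose dot product with the four
    # neighboring pixel values reproduces the value at (ĩ, j̃)
    inv(A)[1, :]
end

bilinear_weights(3.0, 7.0)   # ≈ [1, 0, 0, 0]        (exactly on a grid node)
bilinear_weights(3.5, 7.5)   # ≈ [0.25, 0.25, 0.25, 0.25]  (cell center)
```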
2 changes: 1 addition & 1 deletion ext/CMBLensingMuseInferenceExt.jl
@@ -26,7 +26,7 @@ using Setfield
θ_fixed = (;)
x = ds.d
latent_vars = nothing
autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend()))
autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend()))
transform_θ = identity
inv_transform_θ = identity
end
6 changes: 4 additions & 2 deletions src/CMBLensing.jl
@@ -9,6 +9,7 @@ using Base: @kwdef, @propagate_inbounds, Bottom, OneTo, showarg, show_datatype,
show_default, show_vector, typed_vcat, typename, Callable
using Bijections
using ChainRules
using ChainRulesCore
using ChainRules: @opt_out, rrule, unthunk
using CodecZlib
using Combinatorics
@@ -26,13 +27,14 @@ using FileIO
using FFTW
using ForwardDiff
using ForwardDiff: Dual, Partials, value, partials
import GPUArrays
using Healpix
using InteractiveUtils
using IterTools: flagfirst
using JLD2
using JLD2: jldopen, JLDWriteSession
using KahanSummation
using Loess
import Loess
using LinearAlgebra
using LinearAlgebra: diagzero, matprod, promote_op
using MacroTools: @capture, combinedef, isdef, isexpr, postwalk, prewalk, rmlines, splitdef
@@ -66,7 +68,7 @@ using TimerOutputs: @timeit, get_defaulttimer, reset_timer!
using Tullio
using UnPack
using Zygote
using Zygote: unbroadcast, Numeric, @adjoint, @nograd
using Zygote: unbroadcast, Numeric, @adjoint
using Zygote.ChainRules: @thunk, NoTangent


33 changes: 27 additions & 6 deletions src/autodiff.jl
@@ -1,4 +1,3 @@

# accum is basically supposed to do addition, but Zygote's default for
# Arrays does a broadcast which doesn't do a potentially needed basis
# conversion.
@@ -15,11 +14,11 @@ Zygote.accum(x::FieldOp, y::FieldOp, zs::FieldOp...) = _plus_accum(x, y, zs...)


# constant functions, as far as AD is concerned
@nograd ProjLambert
@nograd fieldinfo
@nograd hasfield
@nograd basetype
@nograd get_storage
ChainRulesCore.@non_differentiable ProjLambert(::Any...)
ChainRulesCore.@non_differentiable fieldinfo(::Any...)
ChainRulesCore.@non_differentiable hasfield(::Any...)
ChainRulesCore.@non_differentiable basetype(::Any...)
ChainRulesCore.@non_differentiable get_storage(::Any...)


# AD for Fourier Fields can be really subtle because such objects are
@@ -241,6 +240,28 @@ ProjectTo(::L) where {L<:FieldOp} = ProjectTo{L}()

Zygote.wrap_chainrules_output(dxs::LazyBinaryOp) = dxs

# ProjectTo and _eltype_projectto support for ChainRules compatibility
# This is needed for StaticArrays and other packages that rely on ProjectTo
ChainRulesCore._eltype_projectto(::Type{F}) where {F<:BaseField} = ProjectTo{F}()
ChainRulesCore._eltype_projectto(::Type{SA}) where {SA<:SArray} = ProjectTo{SA}()

# Allow multiplication of SVector{2} (like from gradients) with Fields
*(v::SVector{2, <:Real}, f::Field) = @SVector[v[1]*f, v[2]*f]

# Custom adjoint for SVector * Field to ensure proper gradient flow
@adjoint function *(v::SVector{2, <:Real}, f::Field)
result = v * f
function svector_field_pullback(Δ)
# Δ should be SVector{2, <:Field}
# Gradient w.r.t. v should be SVector{2, Real}
# Gradient w.r.t. f should be Field
v_grad = @SVector [sum(Δ[1]), sum(Δ[2])]
f_grad = v[1] * Δ[1] + v[2] * Δ[2]
return (v_grad, f_grad)
end
return result, svector_field_pullback
end

Comment from the Collaborator Author:
I confess this function was heavily AI-assisted...

# needed to allow AD through field broadcasts
Zygote.unbroadcast(x::BaseField{B}, x̄::BaseField) where {B} =
BaseField{B}(Zygote.unbroadcast(x.arr, x̄.arr), x.metadata)
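
For context, a minimal usage sketch (not part of this diff) of the new `*(v::SVector{2,<:Real}, f::Field)` method added above; the 16×16 map size is arbitrary and the example assumes a CPU `FlatMap` constructed from a plain matrix, as in the package docs:

```julia
using CMBLensing, StaticArrays

f = FlatMap(rand(16, 16))
v = @SVector [2.0, 3.0]
vf = v * f          # SVector{2} whose components are 2f and 3f,
                    # mirroring how ∇*ϕ already yields a 2-vector of fields
vf[1] ≈ 2f          # true
```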
39 changes: 9 additions & 30 deletions src/bilinearlens.jl
@@ -27,7 +27,8 @@ mutable struct BilinearLens{T,Φ<:Field{<:Any,T},S} <: ImplicitOp{T}
anti_lensing_sparse_repr :: Union{S, Nothing}
end

function BilinearLens(ϕ::FlatField)

function BilinearLens(ϕ::FlatField{B1,M1,CT,AA}) where {B1,M1,CT,AA<:AbstractArray}

# if ϕ == 0 then just return identity operator
if norm(ϕ) == 0
@@ -70,41 +71,19 @@ function BilinearLens(ϕ::FlatField)
1 Δx⁺ Δy⁺ Δx⁺*Δy⁺
]
V[4I-3:4I] .= inv(A)[1,:]

end

# a surprisingly large fraction of the computation for large Nside, so memoize it:
@memoize getK(Nx,Ny) = Int32.((4:4*Nx*Ny+3) .÷ 4)

# CPU
function compute_sparse_repr(is_gpu_backed::Val{false})
K = Vector{Int32}(getK(Nx,Ny))
M = similar(K)
V = similar(K,T)
for I in 1:length(ĩs)
compute_row!(I, ĩs[I], j̃s[I], M, V)
end
sparse(K,M,V,Nx*Ny,Nx*Ny)
end

# GPU
function compute_sparse_repr(is_gpu_backed::Val{true})
K = CuVector{Cint}(getK(Nx,Ny))
M = similar(K)
V = similar(K,T)
cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V
index = threadIdx().x
stride = blockDim().x
for I in index:stride:length(ĩs)
compute_row!(I, ĩs[I], j̃s[I], M, V)
end
end
CuSparseMatrixCSR(CuSparseMatrixCOO{T}(K,M,V,(Nx*Ny,Nx*Ny)))
K = Vector{Int32}(getK(Nx,Ny))
M = similar(K)
V = similar(K,T)
for I in 1:length(ĩs)
compute_row!(I, ĩs[I], j̃s[I], M, V)
end


BilinearLens(ϕ, compute_sparse_repr(Val(is_gpu_backed(ϕ))), nothing)

spr = sparse(K,M,V,Nx*Ny,Nx*Ny)
return BilinearLens(ϕ, spr, nothing)
end


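As a small clarifying aside (not part of this diff), here is what the memoized `getK` helper used above produces: the COO row index of each of the four nonzero entries per row of the sparse lensing operator.

```julia
# redefined here only for illustration; in the package it is memoized
getK(Nx, Ny) = Int32.((4:4*Nx*Ny+3) .÷ 4)

getK(2, 2) == Int32[1,1,1,1, 2,2,2,2, 3,3,3,3, 4,4,4,4]   # true: each row index repeated 4×
```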
2 changes: 1 addition & 1 deletion src/cls.jl
@@ -122,7 +122,7 @@ function smooth(Cℓ::Cℓs; newℓs=minimum(Cℓ.ℓ):maximum(Cℓ.ℓ), xscale
_ => throw(ArgumentError("'xscale' should be :log or :linear"))
end

Cℓs(newℓs, fy⁻¹.(Loess.predict(loess(fx.(Cℓ.ℓ),fy.(Cℓ.Cℓ),span=smoothing),fx.(newℓs))), concrete=Cℓ.concrete)
Cℓs(newℓs, fy⁻¹.(Loess.predict(cmblensing_loess(fx.(Cℓ.ℓ),fy.(Cℓ.Cℓ); span=smoothing),fx.(newℓs))), concrete=Cℓ.concrete)
end


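The `cmblensing_loess` helper that replaces the direct `loess` call here is not defined anywhere in this diff. Purely for illustration, a plausible minimal wrapper (hypothetical; it assumes the helper just promotes inputs to `Float64` and forwards keyword arguments to Loess.jl) might look like:

```julia
import Loess

# hypothetical stand-in for cmblensing_loess (its real definition is not shown in this diff)
cmblensing_loess(x, y; kwargs...) =
    Loess.loess(collect(float.(x)), collect(float.(y)); kwargs...)

# usage mirroring the smooth() call above
# model = cmblensing_loess(log.(ℓs), log.(Cℓs); span=0.75)
# Loess.predict(model, log.(newℓs))
```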
4 changes: 2 additions & 2 deletions src/sampling.jl
@@ -103,8 +103,8 @@ function grid_and_sample(logpdfs::Vector, xs::AbstractVector; progress=false, ns
# interpolate PDF
xmin, xmax = first(xs), last(xs)
logpdfs = logpdfs .- maximum(logpdfs)
interp_logpdfs = loess(xs, logpdfs, span=span)

interp_logpdfs = cmblensing_loess(xs, logpdfs; span=span)
#
# normalize the PDF. note the smoothing is done of the log PDF.
cdf(x) = quadgk(nan2zero∘exp∘interp_logpdfs,xmin,x,rtol=1e-4)[1]
logA = nan2zero(log(cdf(xmax)))
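As a side note, a self-contained toy illustration (not part of this diff) of the normalization step performed by `cdf`/`logA` above, using a unit Gaussian as the unnormalized log-PDF:

```julia
using QuadGK

logpdf(x) = -0.5 * (x - 1)^2                    # toy unnormalized log-pdf
xmin, xmax = -5.0, 5.0
cdf(x) = quadgk(exp ∘ logpdf, xmin, x, rtol=1e-4)[1]
logA = log(cdf(xmax))                           # ≈ log(√(2π)) for this toy case
normalized_logpdf(x) = logpdf(x) - logA
quadgk(exp ∘ normalized_logpdf, xmin, xmax, rtol=1e-4)[1]   # ≈ 1
```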