diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b86cccf..a3e09da 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -40,6 +40,16 @@ jobs: arch: aarch64 - os: macos-latest arch: x64 + include: + - os: ubuntu-24.04-arm + version: 'lts' + arch: aarch64 + - os: ubuntu-24.04-arm + version: '1.11' + arch: aarch64 + - os: ubuntu-24.04-arm + version: '1.12' + arch: aarch64 steps: - uses: actions/checkout@v6 - uses: julia-actions/setup-julia@v2 diff --git a/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 0000000..363fcab --- /dev/null +++ b/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Project.toml b/Project.toml index 2832778..ee70594 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "BeamTracking" uuid = "8ef5c10a-4ca3-437f-8af5-b84d8af36df0" authors = ["mattsignorelli and contributors"] -version = "0.5.5" +version = "0.5.6" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" AtomicAndPhysicalConstants = "5c0d271c-5419-4163-b387-496237733d8b" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" HostCPUFeatures = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -21,15 +22,18 @@ Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" [weakdeps] Beamlines = "5bb90b03-0719-46b8-8ce4-1ef3afd3cd4b" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [extensions] BeamTrackingBeamlinesExt = "Beamlines" +BeamTrackingCUDAExt = "CUDA" [compat] Accessors = "0.1.42" Adapt = "4.3.0" AtomicAndPhysicalConstants = "0.8.0" Beamlines = "0.8.0" +CUDA = "5.9" GTPSA = "1.5.3" HostCPUFeatures = "0.1" KernelAbstractions = "0.9.35" @@ -48,8 +52,9 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Distributions", "JET", "GTPSA", "BenchmarkTools", "Beamlines", "StaticArrays"] +test = ["Test", "Distributions", "JET", "GTPSA", "BenchmarkTools", "Beamlines", "StaticArrays", "LinearAlgebra"] diff --git a/ext/BeamTrackingBeamlinesExt/exact.jl b/ext/BeamTrackingBeamlinesExt/exact.jl index 5150dc0..1c09d8e 100644 --- a/ext/BeamTrackingBeamlinesExt/exact.jl +++ b/ext/BeamTrackingBeamlinesExt/exact.jl @@ -76,8 +76,16 @@ end @inline function thick_bend_pure_bdipole(tm::Exact, bunch, bendparams, bm1, L) g = bendparams.g_ref tilt = bendparams.tilt_ref - e1 = bendparams.e1 - e2 = bendparams.e2 + if tm.fringe_at == Fringe.BothEnds || tm.fringe_at == Fringe.EntranceEnd + e1 = bendparams.e1 + else + e1 = 0 + end + if tm.fringe_at == Fringe.BothEnds || tm.fringe_at == Fringe.ExitEnd + e2 = bendparams.e2 + else + e2 = 0 + end w = rot_quaternion(0,0,-tilt) w_inv = inv_rot_quaternion(0,0,-tilt) theta = g * L diff --git a/ext/BeamTrackingBeamlinesExt/utils.jl b/ext/BeamTrackingBeamlinesExt/utils.jl index 0072bd6..0482138 100644 --- a/ext/BeamTrackingBeamlinesExt/utils.jl +++ b/ext/BeamTrackingBeamlinesExt/utils.jl @@ -77,31 +77,34 @@ This works for both BMultipole and BMultipoleParams. Branchless bc SIMD -> basic no loss in computing both but benefit of branchless. """ @inline function get_strengths(bm, L, p_over_q_ref) - if isconcretetype(eltype(bm.n)) - T = promote_type(eltype(bm.n), + bmn = getfield(bm, :n) + bms = getfield(bm, :s) + bmtilt = getfield(bm, :tilt) + if isconcretetype(eltype(bmn)) + T = promote_type(eltype(bmn), typeof(L), typeof(p_over_q_ref) ) else - if bm.n isa AbstractArray - T = promote_type(reduce(promote_type, typeof.(bm.n)), - reduce(promote_type, typeof.(bm.s)), - reduce(promote_type, typeof.(bm.tilt)), + if bmn isa AbstractArray + T = promote_type(reduce(promote_type, typeof.(bmn)), + reduce(promote_type, typeof.(bms)), + reduce(promote_type, typeof.(bmtilt)), typeof(L), typeof(p_over_q_ref) ) else - T = promote_type(typeof(bm.n), - typeof(bm.s), - typeof(bm.tilt), + T = promote_type(typeof(bmn), + typeof(bms), + typeof(bmtilt), typeof(L), typeof(p_over_q_ref) ) end end - n = T.(make_static(bm.n)) - s = T.(make_static(bm.s)) - tilt = T.(make_static(bm.tilt)) - order = bm.order - normalized = bm.normalized - integrated = bm.integrated + n = T.(make_static(bmn)) + s = T.(make_static(bms)) + tilt = T.(make_static(bmtilt)) + order = getfield(bm, :order) + normalized = getfield(bm, :normalized) + integrated = getfield(bm, :integrated) np = @. n*cos(order*tilt) + s*sin(order*tilt) sp = @. -n*sin(order*tilt) + s*cos(order*tilt) np = @. ifelse(!normalized, np/p_over_q_ref, np) @@ -112,31 +115,34 @@ no loss in computing both but benefit of branchless. end @inline function get_integrated_strengths(bm, L, p_over_q_ref) - if isconcretetype(eltype(bm.n)) - T = promote_type(eltype(bm.n), + bmn = getfield(bm, :n) + bms = getfield(bm, :s) + bmtilt = getfield(bm, :tilt) + if isconcretetype(eltype(bmn)) + T = promote_type(eltype(bmn), typeof(L), typeof(p_over_q_ref) ) else - if bm.n isa AbstractArray - T = promote_type(reduce(promote_type, typeof.(bm.n)), - reduce(promote_type, typeof.(bm.s)), - reduce(promote_type, typeof.(bm.tilt)), + if bmn isa AbstractArray + T = promote_type(reduce(promote_type, typeof.(bmn)), + reduce(promote_type, typeof.(bms)), + reduce(promote_type, typeof.(bmtilt)), typeof(L), typeof(p_over_q_ref) ) else - T = promote_type(typeof(bm.n), - typeof(bm.s), - typeof(bm.tilt), + T = promote_type(typeof(bmn), + typeof(bms), + typeof(bmtilt), typeof(L), typeof(p_over_q_ref) ) end end - n = T.(make_static(bm.n)) - s = T.(make_static(bm.s)) - tilt = T.(make_static(bm.tilt)) - order = bm.order - normalized = bm.normalized - integrated = bm.integrated + n = T.(make_static(bmn)) + s = T.(make_static(bms)) + tilt = T.(make_static(bmtilt)) + order = getfield(bm, :order) + normalized = getfield(bm, :normalized) + integrated = getfield(bm, :integrated) np = @. n*cos(order*tilt) + s*sin(order*tilt) sp = @. -n*sin(order*tilt) + s*cos(order*tilt) np = @. ifelse(!normalized, np/p_over_q_ref, np) @@ -169,4 +175,22 @@ function rf_phi0(rfparams) else error("RF parameter zero_phase value not set correctly.") end +end + +#--------------------------------------------------------------------------------------------------- + +function fringe_in(f::Fringe.T) + if f == Fringe.BothEnds || f == Fringe.EntranceEnd + return Val{true}() + else + return Val{false}() + end +end + +function fringe_out(f::Fringe.T) + if f == Fringe.BothEnds || f == Fringe.ExitEnd + return Val{true}() + else + return Val{false}() + end end \ No newline at end of file diff --git a/ext/BeamTrackingBeamlinesExt/yoshida.jl b/ext/BeamTrackingBeamlinesExt/yoshida.jl index 416006f..76e4efc 100644 --- a/ext/BeamTrackingBeamlinesExt/yoshida.jl +++ b/ext/BeamTrackingBeamlinesExt/yoshida.jl @@ -9,14 +9,16 @@ num_steps = Int(ceil(L / ds_step)) ds_step = L / num_steps end + fin = fringe_in(tm.fringe_at) + fout = fringe_out(tm.fringe_at) if order == 2 - return KernelCall(BeamTracking.order_two_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, L)) + return KernelCall(BeamTracking.order_two_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, fin, fout, L)) elseif order == 4 - return KernelCall(BeamTracking.order_four_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, L)) + return KernelCall(BeamTracking.order_four_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, fin, fout, L)) elseif order == 6 - return KernelCall(BeamTracking.order_six_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, L)) + return KernelCall(BeamTracking.order_six_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, fin, fout, L)) elseif order == 8 - return KernelCall(BeamTracking.order_eight_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, L)) + return KernelCall(BeamTracking.order_eight_integrator!, (ker, params, photon_params, ds_step, num_steps, edge_params, fin, fout, L)) end end @@ -81,7 +83,7 @@ end q = chargeof(bunch.species) mc2 = massof(bunch.species) a = gyromagnetic_anomaly(bunch.species) - edge_params = ifelse(tm.fringe_on, (a, tilde_m, Ksol, 0, 0, 0), nothing) + edge_params = (a, tilde_m, Ksol, 0, 0, 0) E0 = mc2/tilde_m/beta_0 params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, a, Ksol, mm, kn, ks) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, mm, kn, ks), nothing) @@ -97,39 +99,48 @@ end Ksol = kn[1] q = chargeof(bunch.species) mc2 = massof(bunch.species) + a = gyromagnetic_anomaly(bunch.species) + Kn0 = ifelse(mm[2] == 1, kn[2], 0) + edge_params = (a, tilde_m, Ksol, Kn0, 0, 0) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, gyromagnetic_anomaly(bunch.species), Ksol, mm, kn, ks) + params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, a, Ksol, mm, kn, ks) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, mm, kn, ks), nothing) - return integration_launcher(BeamTracking.sks_multipole!, params, photon_params, tm, nothing, L) + return integration_launcher(BeamTracking.sks_multipole!, params, photon_params, tm, edge_params, L) end -@inline function thick_pure_bdipole(tm::DriftKick, bunch, bm, L) +@inline function thick_pure_bdipole(tm::Union{Yoshida,DriftKick}, bunch, bm, L) p_over_q_ref = bunch.p_over_q_ref tilde_m, gamsqr_0, beta_0 = BeamTracking.drift_params(bunch.species, p_over_q_ref) mm = bm.order kn, ks = get_strengths(bm, L, p_over_q_ref) q = chargeof(bunch.species) mc2 = massof(bunch.species) + a = gyromagnetic_anomaly(bunch.species) + Kn0 = ifelse(mm == 1, kn, 0) + edge_params = (a, tilde_m, 0, Kn0, 0, 0) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, gyromagnetic_anomaly(bunch.species), SA[mm], SA[kn], SA[ks]) + params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, a, SA[mm], SA[kn], SA[ks]) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, SA[mm], SA[kn], SA[ks]), nothing) - return integration_launcher(BeamTracking.dkd_multipole!, params, photon_params, tm, nothing, L) + return integration_launcher(BeamTracking.dkd_multipole!, params, photon_params, tm, edge_params, L) end -@inline function thick_bdipole(tm::DriftKick, bunch, bm, L) +@inline function thick_bdipole(tm::Union{Yoshida,DriftKick}, bunch, bm, L) p_over_q_ref = bunch.p_over_q_ref tilde_m, gamsqr_0, beta_0 = BeamTracking.drift_params(bunch.species, p_over_q_ref) mm = bm.order kn, ks = get_strengths(bm, L, p_over_q_ref) q = chargeof(bunch.species) mc2 = massof(bunch.species) + a = gyromagnetic_anomaly(bunch.species) + Kn0 = ifelse(mm[1] == 1, kn[1], 0) + edge_params = (a, tilde_m, 0, Kn0, 0, 0) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, gyromagnetic_anomaly(bunch.species), mm, kn, ks) + params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, a, mm, kn, ks) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, mm, kn, ks), nothing) - return integration_launcher(BeamTracking.dkd_multipole!, params, photon_params, tm, nothing, L) + return integration_launcher(BeamTracking.dkd_multipole!, params, photon_params, tm, edge_params, L) end -@inline function thick_pure_bdipole(tm::Union{Yoshida,BendKick}, bunch, bm1, L) +@inline function thick_pure_bdipole(tm::BendKick, bunch, bm1, L) if isnothing(bunch.coords.q) && !(tm.radiation_damping_on || tm.radiation_fluctuations_on) return thick_pure_bdipole(Exact(), bunch, bm1, L) else @@ -144,9 +155,11 @@ end q = chargeof(bunch.species) mc2 = massof(bunch.species) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, tilde_m, beta_0, gyromagnetic_anomaly(bunch.species), 0, w, w_inv, k0, SA[mm], SA[kn], SA[ks]) + a = gyromagnetic_anomaly(bunch.species) + edge_params = (a, tilde_m, 0, k0, 0, 0) + params = (q, mc2, tm.radiation_damping_on, tilde_m, beta_0, a, 0, w, w_inv, k0, SA[mm], SA[kn], SA[ks]) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, SA[mm], SA[kn], SA[ks]), nothing) - return integration_launcher(BeamTracking.bkb_multipole!, params, photon_params, tm, nothing, L) + return integration_launcher(BeamTracking.bkb_multipole!, params, photon_params, tm, edge_params, L) end end @@ -161,10 +174,12 @@ end w_inv = inv_rot_quaternion(0,0,tilt) q = chargeof(bunch.species) mc2 = massof(bunch.species) + a = gyromagnetic_anomaly(bunch.species) + edge_params = (a, tilde_m, 0, k0, 0, 0) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, tilde_m, beta_0, gyromagnetic_anomaly(bunch.species), 0, w, w_inv, k0, mm, kn, ks) + params = (q, mc2, tm.radiation_damping_on, tilde_m, beta_0, a, 0, w, w_inv, k0, mm, kn, ks) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, mm, kn, ks), nothing) - return integration_launcher(BeamTracking.bkb_multipole!, params, photon_params, tm, nothing, L) + return integration_launcher(BeamTracking.bkb_multipole!, params, photon_params, tm, edge_params, L) end @inline function thick_bdipole(tm::MatrixKick, bunch, bm, L) @@ -185,18 +200,12 @@ end w_inv = inv_rot_quaternion(0,0,tilt) q = chargeof(bunch.species) mc2 = massof(bunch.species) + a = gyromagnetic_anomaly(bunch.species) + edge_params = (a, tilde_m, 0, kn[1], 0, 0) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, gyromagnetic_anomaly(bunch.species), w, w_inv, k1, mm, kn, ks) + params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, a, w, w_inv, k1, mm, kn, ks) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, mm, kn, ks), nothing) - return integration_launcher(BeamTracking.mkm_quadrupole!, params, photon_params, tm, nothing, L) -end - -@inline function thick_bdipole(tm::Yoshida, bunch, bm, L) - if bm.order[2] == 2 - return thick_bdipole(MatrixKick(order=tm.order, num_steps=tm.num_steps, ds_step=tm.ds_step, radiation_damping_on=tm.radiation_damping_on, radiation_fluctuations_on=tm.radiation_fluctuations_on), bunch, bm, L) - else - return thick_bdipole(BendKick(order=tm.order, num_steps=tm.num_steps, ds_step=tm.ds_step, radiation_damping_on=tm.radiation_damping_on, radiation_fluctuations_on=tm.radiation_fluctuations_on), bunch, bm, L) - end + return integration_launcher(BeamTracking.mkm_quadrupole!, params, photon_params, tm, edge_params, L) end @inline function thick_pure_bquadrupole(tm::Union{Yoshida,MatrixKick}, bunch, bm, L) @@ -272,7 +281,7 @@ end w = rot_quaternion(0,0,tilt) w_inv = inv_rot_quaternion(0,0,tilt) a = gyromagnetic_anomaly(bunch.species) - edge_params = ifelse(tm.fringe_on, (a, tilde_m, 0, Kn0, e1, e2), nothing) + edge_params = (a, tilde_m, 0, Kn0, e1, e2) q = chargeof(bunch.species) mc2 = massof(bunch.species) E0 = mc2/tilde_m/beta_0 @@ -313,8 +322,24 @@ end p0c = BeamTracking.R_to_pc(bunch.species, p_over_q_ref) q = chargeof(bunch.species) mc2 = massof(bunch.species) + if mm[1] == 0 + Ksol = kn[1] + if length(mm) > 1 && mm[2] == 1 + Kn0 = kn[2] + else + Kn0 = 0 + end + elseif mm[1] == 1 + Ksol = 0 + Kn0 = kn[1] + else + Ksol = 0 + Kn0 = 0 + end + a = gyromagnetic_anomaly(bunch.species) + edge_params = (a, tilde_m, Ksol, Kn0, 0, 0) E0 = mc2/tilde_m/beta_0 - params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, E_ref, p0c, gyromagnetic_anomaly(bunch.species), omega, t0, E0_over_Rref, mm, kn, ks) + params = (q, mc2, tm.radiation_damping_on, beta_0, gamsqr_0, tilde_m, E_ref, p0c, a, omega, t0, E0_over_Rref, mm, kn, ks) photon_params = ifelse(tm.radiation_fluctuations_on, (q, mc2, E0, 0, 0, mm, kn, ks), nothing) - return integration_launcher(BeamTracking.cavity!, params, photon_params, tm, nothing, L) + return integration_launcher(BeamTracking.cavity!, params, photon_params, tm, edge_params, L) end diff --git a/ext/BeamTrackingCUDAExt/BeamTrackingCUDAExt.jl b/ext/BeamTrackingCUDAExt/BeamTrackingCUDAExt.jl new file mode 100644 index 0000000..fd14bba --- /dev/null +++ b/ext/BeamTrackingCUDAExt/BeamTrackingCUDAExt.jl @@ -0,0 +1,21 @@ +module BeamTrackingCUDAExt +using CUDA: CuDeviceArray +import BeamTracking: gaussian_random + +""" +This function returns two Gaussian random numbers with +mean 0 and standard deviations sigma1, sigma2 using a +Box-Muller transform. + +This was implemented because CUDA.randn has some horrible +compiler bug, but CUDA.rand seems to be ok. +""" +function gaussian_random(::CuDeviceArray, sigma1, sigma2) + s, c = sincospi(2 * rand()) + t = sqrt(-2 * log(rand())) + z0 = c*t*sigma1 + z1 = s*t*sigma2 + return z0, z1 +end + +end \ No newline at end of file diff --git a/src/BeamTracking.jl b/src/BeamTracking.jl index cb00b02..ac941f8 100644 --- a/src/BeamTracking.jl +++ b/src/BeamTracking.jl @@ -11,14 +11,16 @@ using GTPSA, Accessors, SpecialFunctions, AtomicAndPhysicalConstants, - Random + Random, + EnumX using KernelAbstractions import GTPSA: sincu, sinhcu -export Bunch, State, ParticleView, Time, TimeDependentParam +export Bunch, State, ParticleView, Time, TimeDependentParam, BatchParam export Yoshida, Yoshida, MatrixKick, BendKick, SolenoidKick, DriftKick, Exact +export Fringe export track! @@ -30,6 +32,7 @@ include("utils/z_to_time.jl") include("types.jl") include("time.jl") +include("batch.jl") include("kernel.jl") include("tracking_methods.jl") diff --git a/src/batch.jl b/src/batch.jl new file mode 100644 index 0000000..c9331b4 --- /dev/null +++ b/src/batch.jl @@ -0,0 +1,348 @@ +#= + +BatchParam is very similar to TimeDependentParam, but instead +of storing a function, it stores an arbitrary array of parameters. + +Given a batch = [k1, k2, k3], batch parameters are seen by the particles as + +Particle 1: k1 +Particle 2: k2 +Particle 3: k3 +Particle 4: k1 +Particle 5: k2 +Particle 6: k3 + +etc. + +Just like time, there are two types for batches - one type unstable +generic wrapper for an AbstractArray/Number which is manipulated at the +highest level, and a lowered representation where the numbers are made +numbers and arrays are kept but the length of the array is stored in +the type. The reason for this is so that scalars are not unnecessarily +represented as large arrays, to save memory both outside and inside the +kernel. The goal is that for large, batched simulations the type instability +of the UNPACKING step (simulation step is always type stable) is outweighted +by the simulation step. + +The lowered type does NOT have any arithmetic operations defined on it, as +it should be untouched after lowering and in the kernel evaluated for each +particle. + +=# +struct BatchParam + batch::Union{AbstractArray,Number} + #= + Currently, batches of TimeDependentParam are not supported + It is essentially impossible on the GPU because accessing an + array of functions is not type stable (every particle has its + own function). The alternative - a time function which returns a + big array - is also not possible for CPU-SIMD and would run into + memory problems on both CPU and GPU depending on how large that + static array is. + + On the CPU, the first option may be doable using FunctionWrappers, + but that would require special handling because current Time + uses bona-fide Julia functions for GPU compatibility. + + If one would like to do this right now, they should just start up + separate processes where each process has its own lattice with its + own TimeDependentParams. + =# + function BatchParam(batch::AbstractArray) + if length(batch) == 1 + error("Cannot make BatchParam with array of length 1") + end + return new(batch) + end + BatchParam(n::Number) = new(n) +end + +struct _LoweredBatchParam{N,V<:AbstractArray} + batch::V + _LoweredBatchParam(batch::AbstractArray) = new{length(batch),typeof(batch)}(batch) + _LoweredBatchParam{N}(batch::AbstractArray) where {N} = new{N,typeof(batch)}(batch) +end + +# Necessary for GPU compatibility if batch is a GPU array +function Adapt.adapt_structure(to, lbp::_LoweredBatchParam{N}) where {N} + batch = Adapt.adapt_structure(to, lbp.batch) + return _LoweredBatchParam{N}(batch) +end + +# BatchParam will act like a number +# Conversion of types to BatchParam +BatchParam(a::BatchParam) = a + +# Make these apply via convert +Base.convert(::Type{BatchParam}, a::Number) = BatchParam(a) # Scalar BatchParam +Base.convert(::Type{BatchParam}, a::BatchParam) = a + +Base.zero(b::BatchParam) = BatchParam(zero(first(b.batch))) +Base.one(b::BatchParam) = BatchParam(one(first(b.batch))) + +# Now define the math operations: +# The operations are individually-specialized for each operator, assuming that the +# most expensive step is creating temporary arrays, not type instability. As such, +# each are defined in a way as to minimize number of temporary arrays during unpacking, +# checking for e.g. 0's and 1's. +function _batch_addsub(batch_a, batch_b, op::T) where {T<:Union{typeof(+),typeof(-)}} + if batch_a isa Number + if batch_b isa Number + return BatchParam(op(batch_a, batch_b)) + else + if batch_a ≈ 0 # add/sub by zero gives identity + return BatchParam(batch_b) + else + let a = batch_a + return BatchParam(map((bi)->op(a, bi), batch_b)) + end + end + end + elseif batch_b isa Number + if batch_b ≈ 0 # add/sub by zero gives identity + return BatchParam(batch_a) + else + let b = batch_b + return BatchParam(map((ai)->op(ai, b), batch_a)) + end + end + elseif length(batch_a) == length(batch_b) + return BatchParam(map((ai,bi)->op(ai, bi), batch_a, batch_b)) + else + error("Cannot perform operation $(op) with two non-scalar BatchParams of differing + lengths (received lengths $(length(batch_a)) and $(length(batch_b))).") + end +end + +Base.:+(ba::BatchParam, n::Number) = _batch_addsub(ba.batch, n, +) +Base.:+(n::Number, bb::BatchParam) = _batch_addsub(n, bb.batch, +) +Base.:+(ba::BatchParam, bb::BatchParam) = _batch_addsub(ba.batch, bb.batch, +) + +Base.:-(ba::BatchParam, n::Number) = _batch_addsub(ba.batch, n, -) +Base.:-(n::Number, bb::BatchParam) = _batch_addsub(n, bb.batch, -) +Base.:-(ba::BatchParam, bb::BatchParam) = _batch_addsub(ba.batch, bb.batch, -) + +function _batch_mul(batch_a, batch_b) + if batch_a isa Number + if batch_b isa Number + return BatchParam(*(batch_a, batch_b)) + else + if batch_a ≈ 0 # mul by 0 gives 0 -> make scalar + return BatchParam(0f0) + elseif batch_a ≈ 1 # mul by 1 gives identity + return BatchParam(batch_b) + else + let a = batch_a + return BatchParam(map((bi)->*(a, bi), batch_b)) + end + end + end + elseif batch_b isa Number + if batch_b ≈ 0 # mul by 0 gives 0 -> make scalar + return BatchParam(0f0) + elseif batch_b ≈ 1 # mul by 1 gives identity + return BatchParam(batch_a) + else + let b = batch_b + return BatchParam(map((ai)->*(ai, b), batch_a)) + end + end + elseif length(batch_a) == length(batch_b) + return BatchParam(map((ai,bi)->*(ai, bi), batch_a, batch_b)) + else + error("Cannot perform operation * with two non-scalar BatchParams of differing + lengths (received lengths $(length(batch_a)) and $(length(batch_b))).") + end +end + +Base.:*(ba::BatchParam, n::Number) = _batch_mul(ba.batch, n) +Base.:*(n::Number, bb::BatchParam) = _batch_mul(n, bb.batch) +Base.:*(ba::BatchParam, bb::BatchParam) = _batch_mul(ba.batch, bb.batch) + +function _batch_div(batch_a, batch_b) + if batch_a isa Number + if batch_b isa Number + return BatchParam(/(batch_a, batch_b)) + else + let a = batch_a + return BatchParam(map((bi)->/(a, bi), batch_b)) + end + end + elseif batch_b isa Number + if batch_b ≈ 0 # div by 0 gives Inf -> make scalar + return BatchParam(Inf32) + elseif batch_b ≈ 1 # div by 1 gives identity + return BatchParam(batch_a) + else + let b = batch_b + return BatchParam(map((ai)->/(ai, b), batch_a)) + end + end + elseif length(batch_a) == length(batch_b) + return BatchParam(map((ai,bi)->/(ai, bi), batch_a, batch_b)) + else + error("Cannot perform operation / with two non-scalar BatchParams of differing + lengths (received lengths $(length(batch_a)) and $(length(batch_b))).") + end +end + +Base.:/(ba::BatchParam, n::Number) = _batch_div(ba.batch, n) +Base.:/(n::Number, bb::BatchParam) = _batch_div(n, bb.batch) +Base.:/(ba::BatchParam, bb::BatchParam) = _batch_div(ba.batch, bb.batch) + +# for now no special things for pow, unsure if called anywhere. +function _batch_pow(batch_a, batch_b) + if batch_a isa Number + if batch_b isa Number + return BatchParam(^(batch_a, batch_b)) + else + let a = batch_a + return BatchParam(map((bi)->^(a, bi), batch_b)) + end + end + elseif batch_b isa Number + let b = batch_b + return BatchParam(map((ai)->^(ai, b), batch_a)) + end + elseif length(batch_a) == length(batch_b) + return BatchParam(map((ai,bi)->^(ai, bi), batch_a, batch_b)) + else + error("Cannot perform operation ^ with two non-scalar BatchParams of differing + lengths (received lengths $(length(batch_a)) and $(length(batch_b))).") + end +end + +Base.:^(ba::BatchParam, n::Number) = _batch_pow(ba.batch, n) +Base.:^(n::Number, bb::BatchParam) = _batch_pow(n, bb.batch) +Base.:^(ba::BatchParam, bb::BatchParam) = _batch_pow(ba.batch, bb.batch) + +function Base.literal_pow(::typeof(^), ba::BatchParam, ::Val{N}) where {N} + return BatchParam(map(x->Base.literal_pow(^, x, Val{N}()), ba.batch)) +end + +atan2(bpa::BatchParam, bpb::BatchParam) = _batch_atan2(bpa.batch, bpb.batch) + +function _batch_atan2(batch_a, batch_b) + if batch_a isa Number + if batch_b isa Number + return BatchParam(atan2(batch_a, batch_b)) + else + let a = batch_a + return BatchParam(map((bi)->atan2(a, bi), batch_b)) + end + end + elseif batch_b isa Number + let b = batch_b + return BatchParam(map((ai)->atan2(ai, b), batch_a)) + end + elseif length(batch_a) == length(batch_b) + return BatchParam(map((ai,bi)->atan2(ai, bi), batch_a, batch_b)) + else + error("Cannot perform operation ^ with two non-scalar BatchParams of differing + lengths (received lengths $(length(batch_a)) and $(length(batch_b))).") + end +end + +Base.:+(b::BatchParam) = b # identity + +for t = (:-, :sqrt, :exp, :log, :sin, :cos, :tan, :cot, :sinh, :cosh, :tanh, :inv, + :coth, :asin, :acos, :atan, :acot, :asinh, :acosh, :atanh, :acoth, :sinc, :csc, :float, + :csch, :acsc, :acsch, :sec, :sech, :asec, :asech, :conj, :log10, :isnan, :sign, :abs) + @eval begin + Base.$t(b::BatchParam) = BatchParam(map(x->($t)(x), b.batch)) + end +end + +for t = (:unit, :sincu, :sinhc, :sinhcu, :asinc, :asincu, :asinhc, :asinhcu, :erf, + :erfc, :erfcx, :erfi, :wf, :rect) + @eval begin + GTPSA.$t(b::BatchParam) = BatchParam(map(x->($t)(x), b.batch)) + end +end + + +Base.promote_rule(::Type{BatchParam}, ::Type{U}) where {U<:Number} = BatchParam +Base.promote_rule(::Type{BatchParam}, ::Type{TimeDependentParam}) = error("Unable to combine BatchParams with TimeDependentParams") +Base.promote_rule(::Type{TimeDependentParam}, ::Type{BatchParam}) = error("Unable to combine BatchParams with TimeDependentParams") +Base.broadcastable(o::BatchParam) = Ref(o) + +Base.isapprox(b::BatchParam, n::Number; kwargs...) = all(x->isapprox(x, n, kwargs...), b.batch) +Base.isapprox(n::Number, b::BatchParam; kwargs...) = all(x->isapprox(n, x, kwargs...), b.batch) +Base.:(==)(b::BatchParam, n::Number) = all(x->x == n, b.batch) +Base.:(==)(n::Number, b::BatchParam) = all(x->n == x, b.batch) +Base.isinf(b::BatchParam) = all(x->isinf(x), b.batch) + +# Batch lowering should convert types to _LoweredBatchParam +function batch_lower(b::BatchParam) + if b.batch isa AbstractArray + return _LoweredBatchParam(b.batch) # Only arrays are lowered to batchparams + else + return b.batch + end +end + +batch_lower(bp) = bp +# We can use map on the CPU, but not the GPU. This step of batch_lower-ing is on +# the CPU and we are already type unstable here anyways, so we should do this. +batch_lower(bp::T) where {T<:Tuple} = map(bi->batch_lower(bi), bp) + +# Arrays MUST be converted into tuples, for SIMD +batch_lower(bp::SArray{N,BatchParam}) where {N} = batch_lower(Tuple(bp)) + +static_batchcheck(bp) = false +static_batchcheck(::_LoweredBatchParam) = true +@unroll function static_batchcheck(t::Tuple) + @unroll for ti in t + if static_batchcheck(ti) + return true + end + end + return false +end + +@inline beval(b::_LoweredBatchParam{B}, i) where {B} = b.batch[mod1(i, B)] + +@inline function beval(b::_LoweredBatchParam{B}, lane::SIMD.VecRange{N}) where {B,N} + @static if (VERSION < v"1.11" && Sys.ARCH == :x86_64) + error("Julia's explicit SIMD.jl has a compiler bug that appears with batch + parameters on versions < 1.11 AND an x86_64 bit architecture, which we + detected that you have. To get around this, specify the `track!` + keyword argument `use_explicit_SIMD=false`") + end + m = rem(lane2vec(lane), B) + i = vifelse(m == 0, B, m) + return b.batch[i] +end + +""" + lane2vec(lane::SIMD.VecRange{N}) + +Given a SIMD.VecRange, will return an equivalent SIMD.Vec that +can be used in arithmetic operations for mapping integer indices +of particles to a given element in a batch. +""" +function lane2vec(lane::SIMD.VecRange{N}) where {N} + # Try to match with vector register size, but + # only up to UInt32 -> ~4.3 billion particles, + # probably max on CPU... + if Int(pick_vector_width(UInt32)) == N + return SIMD.Vec{N,UInt32}(ntuple(i->lane.i+i-1, Val{N}())) + else + return SIMD.Vec{N,UInt64}(ntuple(i->lane.i+i-1, Val{N}())) + end +end + +@inline beval(b, i) = b + +# === THIS BLOCK WAS WRITTEN BY CLAUDE === +# Generated function for arbitrary-length tuples +@generated function beval(f::T, t) where {T<:Tuple} + N = length(T.parameters) + if N == 0 + return :(()) + end + # Use getfield with literal integer arguments + exprs = [:(beval(Base.getfield(f, $i), t)) for i in 1:N] + return :(tuple($(exprs...))) +end +# === END CLAUDE === \ No newline at end of file diff --git a/src/kernel.jl b/src/kernel.jl index 07794e1..3898b9b 100644 --- a/src/kernel.jl +++ b/src/kernel.jl @@ -9,11 +9,14 @@ blank_kernel!(args...) = nothing kernel::K = blank_kernel! args::A = () function KernelCall(kernel, args) - _args = map(t->time_lower(t), args) + _args = map(t->time_lower(batch_lower(t)), args) new{typeof(kernel),typeof(_args)}(kernel, _args) end end +# In case KernelCall contains batch GPU array +Adapt.@adapt_structure KernelCall + # Store the state of the reference coordinate system # Needed for time-dependent parameters struct RefState{T,U} @@ -28,6 +31,9 @@ struct KernelChain{C<:Tuple{Vararg{<:KernelCall}}, S<:Union{Nothing,RefState}} KernelChain(chain, ref=nothing) = new{typeof(chain), typeof(ref)}(chain, ref) end +# In case KernelChain contains batch GPU array +Adapt.@adapt_structure KernelChain + KernelChain(::Val{N}, ref=nothing) where {N} = KernelChain(ntuple(t->KernelCall(), Val{N}()), ref) push(kc::KernelChain, kcall::Nothing) = kc @@ -55,17 +61,25 @@ _generic_kernel!(i, coords, kc) = __generic_kernel!(i, coords, kc.chain, kc.ref) @unroll function __generic_kernel!(i, coords::Coords, chain, ref) @unroll for kcall in chain - args = process_args(i, coords, kcall.args, ref) + bargs = process_batch_args(i, kcall.args) + args = process_time_args(i, coords, bargs, ref) (kcall.kernel)(i, coords, args...) end return nothing end -function process_args(i, coords, args, ref) +function process_batch_args(i, args) + if static_batchcheck(args) + return beval(args, i) + else + return args + end +end + +function process_time_args(i, coords, args, ref) if !isnothing(ref) && static_timecheck(args) let t = compute_time(coords.v[i,ZI], coords.v[i,PZI], ref) - new_args = map(arg->teval(arg, t), args) - return map(arg->teval(arg, t), args) + return teval(args, t) end else return args @@ -81,7 +95,7 @@ end groupsize::Union{Nothing,Integer}=nothing, #backend isa CPU ? floor(Int,REGISTER_SIZE/sizeof(eltype(v))) : 256 multithread_threshold::Integer=Threads.nthreads() > 1 ? 1750*Threads.nthreads() : typemax(Int), use_KA::Bool=!(get_backend(coords.v) isa CPU && isnothing(groupsize)), - use_explicit_SIMD::Bool=!use_KA && (@static VERSION < v"1.11" || Sys.ARCH != :aarch64) # Default to use explicit SIMD on CPU, excepts for Macs above LTS bc SIMD.jl bug + use_explicit_SIMD::Bool=!use_KA #&& (@static VERSION < v"1.11" || Sys.ARCH != :aarch64) # Default to use explicit SIMD on CPU, excepts for Macs above LTS bc SIMD.jl bug ) where {V} v = coords.v N_particle = size(v, 1) diff --git a/src/kernels/radiation.jl b/src/kernels/radiation.jl index f9bcd20..acbead4 100644 --- a/src/kernels/radiation.jl +++ b/src/kernels/radiation.jl @@ -160,7 +160,7 @@ end sigma2_1 = one(sigma2) sigma = sqrt(vifelse(alive, sigma2, sigma2_1)) - dpz, theta = gaussian_random(sigma, sqrt(13/55)/gamma) + dpz, theta = gaussian_random(v, sigma, sqrt(13/55)/gamma) s, c = sincos(theta) b_perp_hat_x = vifelse(b_perp > 0, b_perp_x / b_perp, 0) diff --git a/src/kernels/yoshida.jl b/src/kernels/yoshida.jl index 6bc1df3..4bed375 100644 --- a/src/kernels/yoshida.jl +++ b/src/kernels/yoshida.jl @@ -2,129 +2,141 @@ # =============== I N T E G R A T O R S =============== # -@makekernel fastgtpsa=true function order_two_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, L) - if !isnothing(edge_params) - a, tilde_m, Ksol, Kn0, e1, e2 = edge_params - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - for step in 1:num_steps - ker(i, coords, params..., ds_step) - if !isnothing(photon_params) && (step < num_steps) - stochastic_radiation!(i, coords, photon_params..., ds_step) +@inline function order_two_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, ::Val{fringe_in}, ::Val{fringe_out}, L) where {fringe_in,fringe_out} + @inbounds begin + if !isnothing(edge_params) && fringe_in + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + for step in 1:num_steps + ker(i, coords, params..., ds_step) + if !isnothing(photon_params) && (step < num_steps) + stochastic_radiation!(i, coords, photon_params..., ds_step) + end + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + if !isnothing(edge_params) && fringe_out + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - if !isnothing(edge_params) - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end end -@makekernel fastgtpsa=true function order_four_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, L) - w0 = -1.7024143839193153215916254339390434324741363525390625*ds_step - w1 = 1.3512071919596577718181151794851757586002349853515625*ds_step - if !isnothing(edge_params) - a, tilde_m, Ksol, Kn0, e1, e2 = edge_params - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - for step in 1:num_steps - ker(i, coords, params..., w1) - ker(i, coords, params..., w0) - ker(i, coords, params..., w1) - if !isnothing(photon_params) && (step < num_steps) - stochastic_radiation!(i, coords, photon_params..., ds_step) +@inline function order_four_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, ::Val{fringe_in}, ::Val{fringe_out}, L) where {fringe_in,fringe_out} + @inbounds begin + w0 = -1.7024143839193153215916254339390434324741363525390625*ds_step + w1 = 1.3512071919596577718181151794851757586002349853515625*ds_step + if !isnothing(edge_params) && fringe_in + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + for step in 1:num_steps + ker(i, coords, params..., w1) + ker(i, coords, params..., w0) + ker(i, coords, params..., w1) + if !isnothing(photon_params) && (step < num_steps) + stochastic_radiation!(i, coords, photon_params..., ds_step) + end + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + if !isnothing(edge_params) && fringe_out + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - if !isnothing(edge_params) - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end end -@makekernel fastgtpsa=true function order_six_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, L) - w0 = 1.315186320683911169737712043570355*ds_step - w1 = -1.17767998417887100694641568096432*ds_step - w2 = 0.235573213359358133684793182978535*ds_step - w3 = 0.784513610477557263819497633866351*ds_step - if !isnothing(edge_params) - a, tilde_m, Ksol, Kn0, e1, e2 = edge_params - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - for step in 1:num_steps - ker(i, coords, params..., w3) - ker(i, coords, params..., w2) - ker(i, coords, params..., w1) - ker(i, coords, params..., w0) - ker(i, coords, params..., w1) - ker(i, coords, params..., w2) - ker(i, coords, params..., w3) - if !isnothing(photon_params) && (step < num_steps) - stochastic_radiation!(i, coords, photon_params..., ds_step) +@inline function order_six_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, ::Val{fringe_in}, ::Val{fringe_out}, L) where {fringe_in,fringe_out} + @inbounds begin + w0 = 1.315186320683911169737712043570355*ds_step + w1 = -1.17767998417887100694641568096432*ds_step + w2 = 0.235573213359358133684793182978535*ds_step + w3 = 0.784513610477557263819497633866351*ds_step + if !isnothing(edge_params) && fringe_in + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + for step in 1:num_steps + ker(i, coords, params..., w3) + ker(i, coords, params..., w2) + ker(i, coords, params..., w1) + ker(i, coords, params..., w0) + ker(i, coords, params..., w1) + ker(i, coords, params..., w2) + ker(i, coords, params..., w3) + if !isnothing(photon_params) && (step < num_steps) + stochastic_radiation!(i, coords, photon_params..., ds_step) + end + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + if !isnothing(edge_params) && fringe_out + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - if !isnothing(edge_params) - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end end -@makekernel fastgtpsa=true function order_eight_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, L) - w0 = 1.7084530707869978*ds_step - w1 = 0.102799849391985*ds_step - w2 = -1.96061023297549*ds_step - w3 = 1.93813913762276*ds_step - w4 = -0.158240635368243*ds_step - w5 = -1.44485223686048*ds_step - w6 = 0.253693336566229*ds_step - w7 = 0.914844246229740*ds_step - if !isnothing(edge_params) - a, tilde_m, Ksol, Kn0, e1, e2 = edge_params - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - for step in 1:num_steps - ker(i, coords, params..., w7) - ker(i, coords, params..., w6) - ker(i, coords, params..., w5) - ker(i, coords, params..., w4) - ker(i, coords, params..., w3) - ker(i, coords, params..., w2) - ker(i, coords, params..., w1) - ker(i, coords, params..., w0) - ker(i, coords, params..., w1) - ker(i, coords, params..., w2) - ker(i, coords, params..., w3) - ker(i, coords, params..., w4) - ker(i, coords, params..., w5) - ker(i, coords, params..., w6) - ker(i, coords, params..., w7) - if !isnothing(photon_params) && (step < num_steps) - stochastic_radiation!(i, coords, photon_params..., ds_step) +@inline function order_eight_integrator!(i, coords::Coords, ker, params, photon_params, ds_step, num_steps, edge_params, ::Val{fringe_in}, ::Val{fringe_out}, L) where {fringe_in,fringe_out} + @inbounds begin + w0 = 1.7084530707869978*ds_step + w1 = 0.102799849391985*ds_step + w2 = -1.96061023297549*ds_step + w3 = 1.93813913762276*ds_step + w4 = -0.158240635368243*ds_step + w5 = -1.44485223686048*ds_step + w6 = 0.253693336566229*ds_step + w7 = 0.914844246229740*ds_step + if !isnothing(edge_params) && fringe_in + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e1, 1) + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + for step in 1:num_steps + ker(i, coords, params..., w7) + ker(i, coords, params..., w6) + ker(i, coords, params..., w5) + ker(i, coords, params..., w4) + ker(i, coords, params..., w3) + ker(i, coords, params..., w2) + ker(i, coords, params..., w1) + ker(i, coords, params..., w0) + ker(i, coords, params..., w1) + ker(i, coords, params..., w2) + ker(i, coords, params..., w3) + ker(i, coords, params..., w4) + ker(i, coords, params..., w5) + ker(i, coords, params..., w6) + ker(i, coords, params..., w7) + if !isnothing(photon_params) && (step < num_steps) + stochastic_radiation!(i, coords, photon_params..., ds_step) + end + end + if !isnothing(photon_params) + stochastic_radiation!(i, coords, photon_params..., ds_step / 2) + end + if !isnothing(edge_params) && fringe_out + a, tilde_m, Ksol, Kn0, e1, e2 = edge_params + linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end - end - if !isnothing(photon_params) - stochastic_radiation!(i, coords, photon_params..., ds_step / 2) - end - if !isnothing(edge_params) - linear_bend_fringe!(i, coords, a, tilde_m, Ksol, Kn0, e2, -1) end end \ No newline at end of file diff --git a/src/time.jl b/src/time.jl index 82c57f0..b7c2e9e 100644 --- a/src/time.jl +++ b/src/time.jl @@ -46,7 +46,7 @@ function Base.literal_pow(::typeof(^), da::TimeDependentParam, ::Val{N}) where { end for t = (:+, :-, :sqrt, :exp, :log, :sin, :cos, :tan, :cot, :sinh, :cosh, :tanh, :inv, - :coth, :asin, :acos, :atan, :acot, :asinh, :acosh, :atanh, :acoth, :sinc, :csc, + :coth, :asin, :acos, :atan, :acot, :asinh, :acosh, :atanh, :acoth, :sinc, :csc, :float, :csch, :acsc, :acsch, :sec, :sech, :asec, :asech, :conj, :log10, :isnan, :sign, :abs) @eval begin Base.$t(d::TimeDependentParam) = (let f = d.f; return TimeDependentParam((t)-> ($t)(f(t))); end) diff --git a/src/tracking_methods.jl b/src/tracking_methods.jl index 483aea7..b3d7034 100644 --- a/src/tracking_methods.jl +++ b/src/tracking_methods.jl @@ -1,3 +1,6 @@ +# ========== Fringe =========================== +@enumx Fringe NoEnd BothEnds EntranceEnd ExitEnd + # ========== Yoshida =========================== abstract type AbstractYoshida end macro def_integrator_struct(name) @@ -8,9 +11,9 @@ macro def_integrator_struct(name) ds_step::Float64 radiation_damping_on::Bool radiation_fluctuations_on::Bool - fringe_on::Bool + fringe_at::Fringe.T - function $(esc(name))(; order::Int=4, num_steps::Int=-1, ds_step::Float64=-1.0, radiation_damping_on::Bool=false, radiation_fluctuations_on::Bool=false, fringe_on::Bool=true) + function $(esc(name))(; order::Int=4, num_steps::Int=-1, ds_step::Float64=-1.0, radiation_damping_on::Bool=false, radiation_fluctuations_on::Bool=false, fringe_at::Fringe.T=Fringe.BothEnds) _order = order _num_steps = num_steps _ds_step = ds_step @@ -27,7 +30,7 @@ macro def_integrator_struct(name) elseif _ds_step > 0 _num_steps = -1 end - return new(_order, _num_steps, _ds_step, radiation_damping_on, radiation_fluctuations_on, fringe_on) + return new(_order, _num_steps, _ds_step, radiation_damping_on, radiation_fluctuations_on, fringe_at) end end end @@ -41,4 +44,10 @@ end # ========== Exact =========================== -struct Exact end \ No newline at end of file +struct Exact + fringe_at::Fringe.T + + function Exact(; fringe_at::Fringe.T=Fringe.BothEnds) + return new(fringe_at) + end +end \ No newline at end of file diff --git a/src/utils/math_simd.jl b/src/utils/math_simd.jl index 6ccbf19..2941601 100644 --- a/src/utils/math_simd.jl +++ b/src/utils/math_simd.jl @@ -157,26 +157,16 @@ end """ This function returns two Gaussian random numbers with -mean 0 and standard deviations sigma1, sigma2 using a -Box-Muller transform. - -This was implemented because CUDA.randn has some horrible -compiler bug, but CUDA.rand seems to be ok. Nonetheless -this may also give CPU performance benefits with radiation -because we already need to compute two randn's anyways. -""" -function gaussian_random(sigma1, sigma2) - s, c = sincospi(2 * rand()) - t = sqrt(-2 * log(rand())) - z0 = c*t*sigma1 - z1 = s*t*sigma2 - return z0, z1 +mean 0 and standard deviations sigma1, sigma2. +""" +function gaussian_random(::Matrix, sigma1, sigma2) + return randn()*sigma1, randn()*sigma2 end """ See gaussian_random, but for SIMD vectors. """ -function gaussian_random(sigma1::SIMD.Vec, sigma2::SIMD.Vec) - return SIMDMathFunctions.vmap(gaussian_random, sigma1, sigma2) +function gaussian_random(::Matrix, sigma1::SIMD.Vec, sigma2::SIMD.Vec) + return SIMDMathFunctions.vmap((s1,s2)->(randn()*s1, randn()*s2), sigma1, sigma2) end diff --git a/src/utils/quaternions.jl b/src/utils/quaternions.jl index cbd44ea..cec8e0a 100644 --- a/src/utils/quaternions.jl +++ b/src/utils/quaternions.jl @@ -31,7 +31,7 @@ function sincos_quaternion(x::TPS{T}) where {T} #sq = one(x) # Using FastGTPSA! for the following makes other kernels run out of temps @FastGTPSA begin - if x < 0.1 + if x < ε #0.1 while !(conv_sin && conv_cos) && N < N_max y = -y*x/((2*N)*(2*N - 1)) result_sin = prev_sin + y/(2*N + 1) diff --git a/test/BeamlinesExt/beamlines_stochastic_test.jl b/test/BeamlinesExt/beamlines_stochastic_test.jl index 9bf854b..e96a96c 100644 --- a/test/BeamlinesExt/beamlines_stochastic_test.jl +++ b/test/BeamlinesExt/beamlines_stochastic_test.jl @@ -3,6 +3,7 @@ using Random @testset "Stochastic radiation" begin Random.seed!(0) + # Here just check that they don't bug out p_over_q_ref = BeamTracking.E_to_R(Species("electron"), 18e9) bend = SBend(g = 0.01, L = 2.0, tracking_method = Yoshida(order = 2, num_steps = 1, @@ -13,109 +14,41 @@ using Random b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.04889165969752571 - 0.021174046860678353 - 0.10556186702264685 - 0.03999912146491459 - 0.04761050793826471 - 0.05997673560385752]' - bend.tracking_method = Yoshida(order = 2, num_steps = 2, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.04889193258818915 - 0.021174302497216673 - 0.10556186827683302 - 0.03999908434104034 - 0.04761050068397177 - 0.059975733285604405]' - bend.tracking_method = Yoshida(order = 4, num_steps = 1, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.048891974265822244 - 0.02117431265208611 - 0.10556186796033079 - 0.03999910506374448 - 0.047610499467786394 - 0.059976284409328576]' - bend.tracking_method = Yoshida(order = 4, num_steps = 2, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.0488920016565342 - 0.021174218561770715 - 0.10556186797663822 - 0.039998878356961864 - 0.04761049873848289 - 0.05997027467880636]' - - bend.tracking_method = Yoshida(order = 6, num_steps = 1, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.04889226281541181 - 0.021174668413673003 - 0.10556186827012498 - 0.03999920072349538 - 0.04761049187637356 - 0.059978810393760046]' - - bend.tracking_method = Yoshida(order = 6, num_steps = 2, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - - @test b0.coords.v ≈ - [0.04889189599231092 - 0.02117408337833361 - 0.10556186775984797 - 0.03999890918412512 - 0.04761050140381129 - 0.059971097099417885]' - bend.tracking_method = Yoshida(order = 8, num_steps = 1, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.04889209188518841 - 0.02117482934639319 - 0.10556186796145814 - 0.039999846102478705 - 0.047610496377937794 - 0.059995919985280796]' - - bend.tracking_method = Yoshida(order = 8, num_steps = 2, radiation_damping_on = true, radiation_fluctuations_on = true) b0 = Bunch(copy(v0), species = line.species_ref, p_over_q_ref = line.p_over_q_ref) track!(b0, line) - @test b0.coords.v ≈ - [0.04889219169051286 - 0.02117499046271834 - 0.10556186830041747 - 0.039999857154632036 - 0.04761049389108041 - 0.05999621385138163]' - # Now just check SIMD , if it doesn't bug out bend.tracking_method = Yoshida(order = 8, num_steps = 2, radiation_damping_on = true, radiation_fluctuations_on = true) diff --git a/test/BeamlinesExt_test.jl b/test/BeamlinesExt_test.jl index c95df85..4abd74f 100644 --- a/test/BeamlinesExt_test.jl +++ b/test/BeamlinesExt_test.jl @@ -421,7 +421,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 1e-14) # Pure bend: - ele = LineElement(L=2.0, g=0.1, tracking_method=Yoshida(order=6, num_steps=10, fringe_on=false)) + ele = LineElement(L=2.0, g=0.1, tracking_method=Yoshida(order=6, num_steps=10, fringe_at=Fringe.NoEnd)) v = [0.01 0.02 0.03 0.04 0.05 0.06] q = [1.0 0.0 0.0 0.0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -432,7 +432,7 @@ || b0.coords.q ≈ -[0.99999555887473 0.00000197685011 0.00297918168991 0.00008187412527]) # Pure solenoid: - ele = LineElement(L=1.0, Ksol=2.0, tracking_method=Yoshida(order=2, fringe_on=false)) + ele = LineElement(L=1.0, Ksol=2.0, tracking_method=Yoshida(order=2, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -444,7 +444,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 6e-9) # Solenoid with quadrupole: - ele = LineElement(L=2.0, Ksol=0.1, Kn1=0.1, tracking_method=SolenoidKick(order=2, fringe_on=false)) + ele = LineElement(L=2.0, Ksol=0.1, Kn1=0.1, tracking_method=SolenoidKick(order=2, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -456,7 +456,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 6e-9) # SK multiple steps: - ele = LineElement(L=2.0, Ksol=0.1, Kn1=0.1, tracking_method=SolenoidKick(order=4, num_steps=2, fringe_on=false)) + ele = LineElement(L=2.0, Ksol=0.1, Kn1=0.1, tracking_method=SolenoidKick(order=4, num_steps=2, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -468,7 +468,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 6e-9) # Step size: - ele = LineElement(L=2.0, Ksol=0.1, Kn1=0.1, tracking_method=SolenoidKick(order=4, ds_step=1.0, fringe_on=false)) + ele = LineElement(L=2.0, Ksol=0.1, Kn1=0.1, tracking_method=SolenoidKick(order=4, ds_step=1.0, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -480,7 +480,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 6e-9) # Straight pure dipole (DK): - ele = LineElement(L=2.0, Kn0=0.1, tilt0=pi/3, tracking_method=DriftKick(order=2)) + ele = LineElement(L=2.0, Kn0=0.1, tilt0=pi/3, tracking_method=DriftKick(order=2, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -492,7 +492,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 6e-9) # Straight pure dipole (BK): - ele = LineElement(L=2.0, Kn0=0.1, tracking_method=BendKick(order=6, num_steps=10)) + ele = LineElement(L=2.0, Kn0=0.1, tracking_method=BendKick(order=6, num_steps=10, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -504,7 +504,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 2e-6) # Straight dipole with quadrupole (DK): - ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.03, tracking_method=DriftKick(order=2)) + ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.03, tracking_method=DriftKick(order=2, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -516,7 +516,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 6e-9) # Straight dipole with quadrupole (BK): - ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.1, tracking_method=BendKick(order=6, num_steps=10)) + ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.1, tracking_method=BendKick(order=6, num_steps=10, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -528,7 +528,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 2e-6) # Straight dipole with quadrupole (MK): - ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.1, tracking_method=Yoshida(order=6, num_steps=10)) + ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.1, tracking_method=Yoshida(order=6, num_steps=10, fringe_at=Fringe.NoEnd)) v = collect(transpose(@vars(D10))) q = TPS64{D10}[1 0 0 0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -612,7 +612,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 2e-7) # Quadrupole with dipole and sextupole (MK): - ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.2, Kn2=0.3, tracking_method=MatrixKick(order=6, num_steps=10)) + ele = LineElement(L=2.0, Kn0=0.1, Kn1=0.2, Kn2=0.3, tracking_method=MatrixKick(order=6, num_steps=10, fringe_at=Fringe.NoEnd)) v = [0.01 0.02 0.03 0.04 0.05 0.06] q = [1.0 0.0 0.0 0.0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -745,7 +745,7 @@ @test quaternion_coeffs_approx_equal(q_expected, q_z, 1e-7) # With solenoid (RK4): - ele = LineElement(L=4.01667, voltage=3321.0942126011, rf_frequency=591142.68014977, Ksol=0.6, tracking_method=Yoshida(order=6, num_steps=2)) + ele = LineElement(L=4.01667, voltage=3321.0942126011, rf_frequency=591142.68014977, Ksol=0.6, tracking_method=Yoshida(order=6, num_steps=2, fringe_at=Fringe.NoEnd)) v = [0.01 0.02 0.03 0.04 0.05 0.06] q = [1.0 0.0 0.0 0.0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -769,7 +769,7 @@ @test b0.coords.q ≈ q_expected || b0.coords.q ≈ -q_expected # With solenoid and quadrupole: - ele = LineElement(L=4.01667, voltage=3321.0942126011, rf_frequency=591142.68014977, Ksol=-0.3, Kn1=0.15, tracking_method=Yoshida(order=6, num_steps=20)) + ele = LineElement(L=4.01667, voltage=3321.0942126011, rf_frequency=591142.68014977, Ksol=-0.3, Kn1=0.15, tracking_method=Yoshida(order=6, num_steps=20, fringe_at=Fringe.NoEnd)) v = [0.01 0.02 0.03 0.04 0.05 0.06] q = [1.0 0.0 0.0 0.0] b0 = Bunch(v, q, p_over_q_ref=p_over_q_ref, species=Species("electron")) @@ -859,7 +859,7 @@ b0 = Bunch([0.4 0.4 0.4 0.4 0.4 -0.5], [1.0 0.0 0.0 0.0], p_over_q_ref=p_over_q_ref, species=Species("electron")) v_init = copy(b0.coords.v) q_init = copy(b0.coords.q) - ele_dipole = LineElement(L=1.0, Kn0=1e-8, Kn1=1e-8, tracking_method=BendKick()) + ele_dipole = LineElement(L=1.0, Kn0=1e-8, Kn1=1e-8, tracking_method=BendKick(fringe_at=Fringe.NoEnd)) track!(b0, Beamline([ele_dipole], p_over_q_ref=p_over_q_ref)) @test b0.coords.state[1] == STATE_LOST @test v_init == b0.coords.v diff --git a/test/IntegrationTracking_test.jl b/test/IntegrationTracking_test.jl index bbdf52a..f07dcba 100644 --- a/test/IntegrationTracking_test.jl +++ b/test/IntegrationTracking_test.jl @@ -83,7 +83,7 @@ ker = BeamTracking.bkb_multipole! num_steps = 10 ds_step = T(0.2) - return ker, params, nothing, ds_step, num_steps, nothing, L + return ker, params, nothing, ds_step, num_steps, nothing, Val{false}(), Val{false}(), L end function integrator_args(::Type{T}) where {T} @@ -105,7 +105,7 @@ ker = BeamTracking.dkd_multipole! num_steps = 1 ds_step = T(2) - return ker, params, nothing, ds_step, num_steps, nothing, L + return ker, params, nothing, ds_step, num_steps, nothing, Val{false}(), Val{false}(), L end function cavity_args(::Type{T}) where {T} diff --git a/test/batch_test.jl b/test/batch_test.jl new file mode 100644 index 0000000..ee4b71c --- /dev/null +++ b/test/batch_test.jl @@ -0,0 +1,172 @@ +using Beamlines +using BeamTracking +using Test +using LinearAlgebra + +# Pushes 4 equal particles with 4 batched parameters +# compares when particle is pushed through 4 different elements +function test_batch( + batch_params, + nonbatch_params=(;); + v=nothing, + q=nothing, + extra_tests=(), +) + + if isnothing(v) + v = repeat((rand(1,6).-0.5)*1e-4, 4, 1) + end + if isnothing(q) + q = repeat(mapslices(normalize, rand(1,4), dims=2), 4, 1) + end + + E_ref = 5e9 + species = Species("electron") + + ele_1 = LineElement() + ele_2 = LineElement() + ele_3 = LineElement() + ele_4 = LineElement() + ele_batch = LineElement() + for (k,val) in pairs(batch_params) + val_1 = -val + val_2 = -0.5*val + val_3 = 0.5*val + val_4 = val + setproperty!(ele_1, k, val_1) + setproperty!(ele_2, k, val_2) + setproperty!(ele_3, k, val_3) + setproperty!(ele_4, k, val_4) + setproperty!(ele_batch, k, BatchParam([val_1, val_2, val_3, val_4])) + end + for (k,val) in pairs(nonbatch_params) + setproperty!(ele_1, k, val) + setproperty!(ele_2, k, val) + setproperty!(ele_3, k, val) + setproperty!(ele_4, k, val) + setproperty!(ele_batch, k, val) + end + + bl_1 = Beamline([ele_1], E_ref=E_ref, species_ref=species) + bl_2 = Beamline([ele_2], E_ref=E_ref, species_ref=species) + bl_3 = Beamline([ele_3], E_ref=E_ref, species_ref=species) + bl_4 = Beamline([ele_4], E_ref=E_ref, species_ref=species) + bl_batch = Beamline([ele_batch], E_ref=E_ref, species_ref=species) + + b0_1 = Bunch(v[1,:]', q[1,:]'; p_over_q_ref=bl_1.p_over_q_ref, species=bl_1.species_ref) + b0_2 = Bunch(v[1,:]', q[1,:]'; p_over_q_ref=bl_2.p_over_q_ref, species=bl_2.species_ref) + b0_3 = Bunch(v[1,:]', q[1,:]'; p_over_q_ref=bl_3.p_over_q_ref, species=bl_3.species_ref) + b0_4 = Bunch(v[1,:]', q[1,:]'; p_over_q_ref=bl_4.p_over_q_ref, species=bl_4.species_ref) + + b0_batch = Bunch(v, q; p_over_q_ref=bl_batch.p_over_q_ref, species=bl_batch.species_ref) + + track!(b0_1, bl_1) + track!(b0_2, bl_2) + track!(b0_3, bl_3) + track!(b0_4, bl_4) + + # Ensure branchlessness of parameters with explicit SIMD + if (VERSION < v"1.11" && Sys.ARCH == :x86_64) + use_explicit_SIMD=false + else + use_explicit_SIMD=true + end + track!(b0_batch, bl_batch; use_explicit_SIMD=use_explicit_SIMD) + + @test b0_batch.coords.v[1,:]' ≈ b0_1.coords.v + @test b0_batch.coords.v[2,:]' ≈ b0_2.coords.v + @test b0_batch.coords.v[3,:]' ≈ b0_3.coords.v + @test b0_batch.coords.v[4,:]' ≈ b0_4.coords.v + + @test b0_batch.coords.q[1,:]' ≈ b0_1.coords.q + @test b0_batch.coords.q[2,:]' ≈ b0_2.coords.q + @test b0_batch.coords.q[3,:]' ≈ b0_3.coords.q + @test b0_batch.coords.q[4,:]' ≈ b0_4.coords.q + + for extra_test in extra_tests + @test extra_test(b0_plus, b0_minus, b0_time) + end +end + +@testset "Batch" begin + # Test each of the splits: DKD, MKM, SKS, BKB + # MKM: + test_batch((;Kn1=0.36), (;L=0.5)) + test_batch((;Kn1=0.36, Ks2=-1.2, Kn12=105.), (;L=3.4)) + + # SKS: + test_batch((;Ksol=1.2), (;L=3.4)) + test_batch((;Ksol=0.23, Kn1=0.36, Ks2=-1.2, Kn12=105.), (;L=3.4)) + + # DKD: + test_batch((;Kn0=1e-2, Kn1=-2, Kn3=14), (;L=5.1)) + + # BKB not working at the moment because bend multipoles not + # implemented and Ks0 (also stored) becomes a TimeDependentParam + # test_batch((;Kn0=1e-4, e1=2e-2, e2=3e-2), (;L=1.4)) + + Kn1 = 0.36 + Ks2 = -1.2 + L = 3.4 + Kn12 = 105. + Kn3 = 14 + Kn0 = 1e-2 + Ksol = -0.23 + test_batch((;Kn1=Kn1), (;L=L)) + test_batch((;Kn1=Kn1, Ks2=Ks2, Kn12=Kn12), (;L=L)) + + # SKS: + test_batch((;Ksol=Ksol), (;L=L)) + test_batch((;Ksol=Ksol, Kn1=Kn1, Ks2=Ks2, Kn12=Kn12), (;L=L)) + + # DKD: + test_batch((;Kn0=Kn0, Kn1=Kn1, Kn3=Kn3), (;L=L)) + + # Now with different types of multipoles entered: + test_batch((;Bn1=Kn1), (;L=0.5)) + test_batch((;Bn1L=Kn1, Ks2L=Ks2, Bn12=Kn12), (;L=L)) + + # SKS: + test_batch((;Bsol=Ksol), (;L=L)) + test_batch((;BsolL=Ksol), (;L=L)) + test_batch((;KsolL=Ksol, Bn1L=Kn1, Bs2=Ks2, Kn12L=Kn12), (;L=L)) + + # DKD: + test_batch((;Bn0L=Kn0, Kn1L=Kn1, Bn3=Kn3), (;L=L)) + test_batch((;Bn0L=Kn0, Kn1L=Kn1, Bn3=Kn3), (;L=L)) + test_batch((;Bn0L=Kn0, Kn1L=Kn1, Bn3=Kn3), (;L=0)) + test_batch((;Bn0L=Kn0, Kn1L=Kn1, Bn3=Kn3), (;L=0)) + + #= + # Aperture: + # let's make a time-dependent aperture which oscillates but will allow both + # particles through + function check_state(p, m, t, state=BeamTracking.STATE_ALIVE) + if !all(p.coords.state .== m.coords.state .== t.coords.state .== state) + error("Test failed: p.coords.state = $(p.coords.state), m.coords.state = $(m.coords.state), + t.coords.state = $(t.coords.state), state = $state") + else + return true + end + end + + test_batch( + (;x2_limit=1), + (;aperture_shape = ApertureShape.Rectangular); + v=[0.5 0 0 0 0 0], + ft=(v)->v*cos(2*Time()), + extra_tests=(check_state,) + ) + + # now one where all particle should die + test_batch( + (;x2_limit=1), + (;aperture_shape = ApertureShape.Rectangular); + v=[0.5 0 0 0 0 0], + ft=(v)->v*sin(2*Time()), + extra_tests=((p,m,t)->check_state(p,m,t,BeamTracking.STATE_LOST_POS_X),) + ) + +=# + +end \ No newline at end of file diff --git a/test/lattices/alignment_lat.jl b/test/lattices/alignment_lat.jl index 356c104..45eb861 100644 --- a/test/lattices/alignment_lat.jl +++ b/test/lattices/alignment_lat.jl @@ -1,7 +1,7 @@ using BeamTracking using Beamlines -@eles begin +@elements begin drift1 = Drift(L = 2.0, x_offset = 0.1, y_offset = 0.2, z_offset = 0.3, x_rot = 0.04, y_rot = 0.05, tilt = 0.06, tracking_method = Exact()) diff --git a/test/lattices/aperture_lat.jl b/test/lattices/aperture_lat.jl index 557bee1..8814784 100644 --- a/test/lattices/aperture_lat.jl +++ b/test/lattices/aperture_lat.jl @@ -1,7 +1,7 @@ using BeamTracking using Beamlines -@eles begin +@elements begin d_error = LineElement(L=1, x1_limit=1, dx=5) d1_rect = Drift(L=1, x1_limit = 1, x2_limit = 2, y1_limit = 3, y2_limit = 5, aperture_shape = ApertureShape.Rectangular, aperture_active = false) diff --git a/test/lattices/esr.jl b/test/lattices/esr.jl index a9d3622..01f6444 100644 --- a/test/lattices/esr.jl +++ b/test/lattices/esr.jl @@ -1,6 +1,6 @@ using BeamTracking, Beamlines -@eles begin +@elements begin IP6__1 = Marker() D000001__1 = Drift( L = 5.3) Q1ER_6 = Quadrupole( L = 1.8, Kn1 = -0.2291420342) diff --git a/test/runtests.jl b/test/runtests.jl index 7fefaea..46935da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,9 @@ using BeamTracking: Coords, KernelCall, Q0, QX, QY, QZ, STATE_ALIVE, STATE_LOST, quat_mul, quat_rotate, gaussian_random using Beamlines: isactive +@show BeamTracking.REGISTER_SIZE +@show Sys.ARCH + BenchmarkTools.DEFAULT_PARAMETERS.gctrial = false BenchmarkTools.DEFAULT_PARAMETERS.evals = 2 @@ -195,9 +198,11 @@ function quaternion_coeffs_approx_equal(q_expected, q_calculated, ϵ) return all_ok end +include("batch_test.jl") +include("time_test.jl") include("BeamlinesExt_test.jl") include("alignment_tracking_test.jl") include("aperture_tracking_test.jl") include("ExactTracking_test.jl") include("IntegrationTracking_test.jl") -include("time_test.jl") \ No newline at end of file +