diff --git a/src/types.jl b/src/types.jl index c6f065f9..e54d2c52 100644 --- a/src/types.jl +++ b/src/types.jl @@ -20,16 +20,17 @@ const STATE_LOST_PZ = UInt8(7) const STATE_LOST_Z = UInt8(8) # Always SOA -struct Coords{S,V,Q} +struct Coords{S,V,Q,W} state::S # Array of particle states v::V # Matrix of particle coordinates q::Q # Matrix of particle quaternions if spin else nothing - function Coords(state, v, q) + weight::W # Array of particle weights if weighted else nothing + function Coords(state, v, q, weight) if !isnothing(q) && eltype(v) != eltype(q) error("Cannot initialize Coords with orbital coordinates of type $(eltype(v)) and quaternion coordinates of type $(typeof(q)).") end - return new{typeof(state),typeof(v),typeof(q)}(state, v, q) + return new{typeof(state),typeof(v),typeof(q),typeof(weight)}(state, v, q, weight) end end @@ -45,42 +46,44 @@ Adapt.@adapt_structure Coords get_N_particle(bunch::Bunch) = size(bunch.coords.v, 1) -function Bunch(N::Integer; p_over_q_ref=NaN, t_ref=0., species=Species(), spin=false) +function Bunch(N::Integer; p_over_q_ref=NaN, t_ref=0., species=Species(), spin=false, weight=nothing) v = rand(N,6) q = spin ? rand(N,4) : nothing state = similar(v, UInt8, N) state .= STATE_ALIVE - return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q)) + return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q, weight)) end -function Bunch(v::AbstractMatrix, q=nothing; p_over_q_ref=NaN, t_ref=0., species=Species()) +function Bunch(v::AbstractMatrix, q=nothing, weight=nothing; p_over_q_ref=NaN, t_ref=0., species=Species()) size(v, 2) == 6 || error("The number of columns must be equal to 6") N_particle = size(v, 1) state = similar(v, UInt8, N_particle) state .= STATE_ALIVE - return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q)) + return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q, weight)) end -function Bunch(v::AbstractVector, q=nothing; p_over_q_ref=NaN, t_ref=0., species=Species()) +function Bunch(v::AbstractVector, q=nothing, weight=nothing; p_over_q_ref=NaN, t_ref=0., species=Species()) length(v) == 6 || error("Bunch accepts a N x 6 matrix of N particle coordinates, or alternatively a single particle as a vector. Received a vector of length $(length(v))") - return Bunch(reshape(v, (1,6)), q; p_over_q_ref=p_over_q_ref, t_ref=t_ref, species=species) + return Bunch(reshape(v, (1,6)), q, weight; p_over_q_ref=p_over_q_ref, t_ref=t_ref, species=species) end -struct ParticleView{B,T,S,V,Q} +struct ParticleView{B,T,S,V,Q,W} index::Int species::Species p_over_q_ref::B t_ref::T state::S v::V - q::Q + q::Q + weight::W ParticleView(args...) = new{typeof.(args)...}(args...) end function ParticleView(bunch::Bunch, i=1) v = bunch.coords.v q = bunch.coords.q - return ParticleView(i, bunch.species, bunch.p_over_q_ref, bunch.t_ref, bunch.coords.state[i], view(v, :, i), isnothing(q) ? q : view(q, :, i)) + weight = bunch.coords.weight + return ParticleView(i, bunch.species, bunch.p_over_q_ref, bunch.t_ref, bunch.coords.state[i], view(v, :, i), isnothing(q) ? q : view(q, :, i), isnothing(weight) ? weight : weight[i]) end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7fefaea6..e245ae87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,7 +33,7 @@ function test_matrix( v = transpose(@vars(D1)) state = similar(v, UInt8, 1) state .= STATE_ALIVE - coords = Coords(state, v, nothing) + coords = Coords(state, v, nothing, nothing) # Set up kernel chain and launch! BeamTracking.launch!(coords, kernel_call) @@ -60,7 +60,7 @@ function test_matrix( q = repeat([1.0 0.0 0.0 0.0], 2) state = [STATE_ALIVE STATE_ALIVE] @test @ballocated(BeamTracking.launch!(coords, $kernel_call; use_KA=false), - setup=(coords = Coords(copy($state), copy($v), copy($q)))) == 0 + setup=(coords = Coords(copy($state), copy($v), copy($q), nothing))) == 0 end end @@ -104,7 +104,7 @@ function test_map( q = TPS64{D10}[1 0 0 0] state = similar(v, UInt8, 1) state .= STATE_ALIVE - coords = Coords(state, v, q) + coords = Coords(state, v, q, nothing) # Set up kernel chain and launch! BeamTracking.launch!(coords, kernel_call) @@ -121,7 +121,7 @@ function test_map( q = repeat([1.0 0.0 0.0 0.0], 2) state = [STATE_ALIVE STATE_ALIVE] @test @ballocated(BeamTracking.launch!(coords, $kernel_call; use_KA=false), - setup=(coords = Coords(copy($state), copy($v), copy($q)))) == 0 + setup=(coords = Coords(copy($state), copy($v), copy($q), nothing))) == 0 end