Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ const STATE_LOST_PZ = UInt8(7)
const STATE_LOST_Z = UInt8(8)

# Always SOA
struct Coords{S,V,Q}
struct Coords{S,V,Q,W}
state::S # Array of particle states
v::V # Matrix of particle coordinates
q::Q # Matrix of particle quaternions if spin else nothing
function Coords(state, v, q)
weight::W # Array of particle weights if weighted else nothing
function Coords(state, v, q, weight)
if !isnothing(q) && eltype(v) != eltype(q)
error("Cannot initialize Coords with orbital coordinates of type $(eltype(v))
and quaternion coordinates of type $(typeof(q)).")
end
return new{typeof(state),typeof(v),typeof(q)}(state, v, q)
return new{typeof(state),typeof(v),typeof(q),typeof(weight)}(state, v, q, weight)
end
end

Expand All @@ -45,42 +46,44 @@ Adapt.@adapt_structure Coords

get_N_particle(bunch::Bunch) = size(bunch.coords.v, 1)

function Bunch(N::Integer; p_over_q_ref=NaN, t_ref=0., species=Species(), spin=false)
function Bunch(N::Integer; p_over_q_ref=NaN, t_ref=0., species=Species(), spin=false, weight=nothing)
v = rand(N,6)
q = spin ? rand(N,4) : nothing
state = similar(v, UInt8, N)
state .= STATE_ALIVE
return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q))
return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q, weight))
end

function Bunch(v::AbstractMatrix, q=nothing; p_over_q_ref=NaN, t_ref=0., species=Species())
function Bunch(v::AbstractMatrix, q=nothing, weight=nothing; p_over_q_ref=NaN, t_ref=0., species=Species())
size(v, 2) == 6 || error("The number of columns must be equal to 6")
N_particle = size(v, 1)
state = similar(v, UInt8, N_particle)
state .= STATE_ALIVE
return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q))
return Bunch(species, p_over_q_ref, t_ref, Coords(state, v, q, weight))
end

function Bunch(v::AbstractVector, q=nothing; p_over_q_ref=NaN, t_ref=0., species=Species())
function Bunch(v::AbstractVector, q=nothing, weight=nothing; p_over_q_ref=NaN, t_ref=0., species=Species())
length(v) == 6 || error("Bunch accepts a N x 6 matrix of N particle coordinates,
or alternatively a single particle as a vector. Received
a vector of length $(length(v))")
return Bunch(reshape(v, (1,6)), q; p_over_q_ref=p_over_q_ref, t_ref=t_ref, species=species)
return Bunch(reshape(v, (1,6)), q, weight; p_over_q_ref=p_over_q_ref, t_ref=t_ref, species=species)
end

struct ParticleView{B,T,S,V,Q}
struct ParticleView{B,T,S,V,Q,W}
index::Int
species::Species
p_over_q_ref::B
t_ref::T
state::S
v::V
q::Q
q::Q
weight::W
ParticleView(args...) = new{typeof.(args)...}(args...)
end

function ParticleView(bunch::Bunch, i=1)
v = bunch.coords.v
q = bunch.coords.q
return ParticleView(i, bunch.species, bunch.p_over_q_ref, bunch.t_ref, bunch.coords.state[i], view(v, :, i), isnothing(q) ? q : view(q, :, i))
weight = bunch.coords.weight
return ParticleView(i, bunch.species, bunch.p_over_q_ref, bunch.t_ref, bunch.coords.state[i], view(v, :, i), isnothing(q) ? q : view(q, :, i), isnothing(weight) ? weight : weight[i])
end
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function test_matrix(
v = transpose(@vars(D1))
state = similar(v, UInt8, 1)
state .= STATE_ALIVE
coords = Coords(state, v, nothing)
coords = Coords(state, v, nothing, nothing)

# Set up kernel chain and launch!
BeamTracking.launch!(coords, kernel_call)
Expand All @@ -60,7 +60,7 @@ function test_matrix(
q = repeat([1.0 0.0 0.0 0.0], 2)
state = [STATE_ALIVE STATE_ALIVE]
@test @ballocated(BeamTracking.launch!(coords, $kernel_call; use_KA=false),
setup=(coords = Coords(copy($state), copy($v), copy($q)))) == 0
setup=(coords = Coords(copy($state), copy($v), copy($q), nothing))) == 0
end
end

Expand Down Expand Up @@ -104,7 +104,7 @@ function test_map(
q = TPS64{D10}[1 0 0 0]
state = similar(v, UInt8, 1)
state .= STATE_ALIVE
coords = Coords(state, v, q)
coords = Coords(state, v, q, nothing)

# Set up kernel chain and launch!
BeamTracking.launch!(coords, kernel_call)
Expand All @@ -121,7 +121,7 @@ function test_map(
q = repeat([1.0 0.0 0.0 0.0], 2)
state = [STATE_ALIVE STATE_ALIVE]
@test @ballocated(BeamTracking.launch!(coords, $kernel_call; use_KA=false),
setup=(coords = Coords(copy($state), copy($v), copy($q)))) == 0
setup=(coords = Coords(copy($state), copy($v), copy($q), nothing))) == 0
end


Expand Down