22jax .config .update ("jax_enable_x64" , True )
33import jax .numpy as jnp
44import matplotlib .pyplot as plt
5- from jax .experimental .shard_map import shard_map
6- from jax .sharding import Mesh , NamedSharding , PartitionSpec
7- from jax import jit , vmap , tree_util , random , lax
5+ from jax .sharding import Mesh , PartitionSpec , NamedSharding
6+ from jax import jit , vmap , tree_util , random , lax , device_put
87from functools import partial
98from diffrax import diffeqsolve , ODETerm , SaveAt , Tsit5 , PIDController , Event
109from essos .coils import Coils
1110from essos .fields import BiotSavart , Vmec
1211from essos .constants import ALPHA_PARTICLE_MASS , ALPHA_PARTICLE_CHARGE , FUSION_ALPHA_PARTICLE_ENERGY
1312from essos .plot import fix_matplotlib_3d
13+ from essos .util import roots
14+
15+ mesh = Mesh (jax .devices (), ("dev" ,))
16+ sharding = NamedSharding (mesh , PartitionSpec ("dev" , None ))
1417
1518def gc_to_fullorbit (field , initial_xyz , initial_vparallel , total_speed , mass , charge , phase_angle_full_orbit = 0 ):
1619 """
@@ -211,6 +214,7 @@ def compute_energy_fo(trajectory):
211214 def trace (self ):
212215 @jit
213216 def compute_trajectory (initial_condition ) -> jnp .ndarray :
217+ # initial_condition = initial_condition[0]
214218 if self .model == 'FullOrbit_Boris' :
215219 dt = self .maxtime / self .timesteps
216220 def update_state (state , _ ):
@@ -233,6 +237,8 @@ def update_state(state, _):
233237 _ , trajectory = lax .scan (update_state , initial_condition , jnp .arange (len (self .times )- 1 ))
234238 trajectory = jnp .vstack ([initial_condition , trajectory ])
235239 else :
240+ import warnings
241+ warnings .simplefilter ("ignore" , category = FutureWarning ) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation
236242 trajectory = diffeqsolve (
237243 self .ODE_term ,
238244 t0 = 0.0 ,
@@ -245,15 +251,26 @@ def update_state(state, _):
245251 throw = False ,
246252 # adjoint=DirectAdjoint(),
247253 stepsize_controller = PIDController (pcoeff = 0.4 , icoeff = 0.3 , dcoeff = 0 , rtol = self .tol_step_size , atol = self .tol_step_size ),
248- # max_steps=1000000 ,
254+ max_steps = 10000000000 ,
249255 event = Event (self .condition )
250256 ).ys
251257 return trajectory
252- # return jnp.array(vmap(compute_trajectory)(self.initial_conditions))
253- mesh = Mesh (devices = jax .devices (), axis_names = ('workers' ,))
254- in_spec = PartitionSpec ('workers' , None )
255- return shard_map (vmap (compute_trajectory ), mesh ,
256- in_specs = (in_spec ,), out_specs = in_spec , check_rep = False )(self .initial_conditions )
258+
259+ # if len(jax.devices())!=len(self.initial_conditions):
260+ # return vmap(compute_trajectory)(self.initial_conditions[:,None,:])
261+ # else:
262+ # # num_devices = len(jax.devices())
263+ # shape = self.initial_conditions.shape
264+ # # distributed_initial_conditions = self.initial_conditions.reshape(num_devices, -1, *shape[1:])
265+ # mesh = Mesh(devices=jax.devices(), axis_names=('workers'))
266+ # in_spec = PartitionSpec('workers') # Distribute along the workers axis
267+ # out_spec = PartitionSpec('workers') # Gather results along the same axis
268+ # return shard_map(compute_trajectory, mesh, in_specs=in_spec, out_specs=out_spec, check_rep=False)(
269+ # self.initial_conditions).reshape((shape[0], self.timesteps, shape[1]))
270+
271+ return jit (vmap (compute_trajectory ), in_shardings = sharding , out_shardings = sharding )(
272+ device_put (self .initial_conditions , sharding ))
273+
257274 # trajectories = []
258275 # for initial_condition in self.initial_conditions:
259276 # trajectory = compute_trajectory(initial_condition)
@@ -289,12 +306,13 @@ def to_vtk(self, filename):
289306 data = np .array (jnp .concatenate ([i * jnp .ones ((self .trajectories [i ].shape [0 ], )) for i in range (len (self .trajectories ))]))
290307 polyLinesToVTK (filename , x , y , z , pointsPerLine = ppl , pointData = {'idx' : data })
291308
292- def plot (self , ax = None , show = True , axis_equal = True , ** kwargs ):
309+ def plot (self , ax = None , show = True , axis_equal = True , n_trajectories_plot = 5 , ** kwargs ):
293310 if ax is None or ax .name != "3d" :
294311 fig = plt .figure ()
295312 ax = fig .add_subplot (projection = '3d' )
296313 trajectories_xyz = jnp .array (self .trajectories_xyz )
297- for i in range (trajectories_xyz .shape [0 ]):
314+ n_trajectories_plot = jnp .min (jnp .array ([n_trajectories_plot , trajectories_xyz .shape [0 ]]))
315+ for i in random .choice (random .PRNGKey (0 ), trajectories_xyz .shape [0 ], (n_trajectories_plot ,), replace = False ):
298316 ax .plot (trajectories_xyz [i , :, 0 ], trajectories_xyz [i , :, 1 ], trajectories_xyz [i , :, 2 ], linewidth = 0.5 , ** kwargs )
299317 ax .grid (False )
300318 if axis_equal :
@@ -315,101 +333,97 @@ def loss_fraction(self, r_max=0.99):
315333 total_particles_lost = loss_fractions [- 1 ] * len (self .trajectories )
316334 return loss_fractions , total_particles_lost , lost_times
317335
318- # def get_phi(x, y, phi_last):
319- # """Compute the toroidal angle phi, ensuring continuity."""
320- # phi = jnp.arctan2(y, x)
321- # dphi = phi - phi_last
322- # return phi - jnp.round(dphi / (2 * jnp.pi)) * (2 * jnp.pi) # Ensure continuity
323-
324- # @partial(jit, static_argnums=(0, 2))
325- # def find_poincare_hits(self, traj, phis_poincare):
326- # """Find points where field lines cross specified phi values."""
327- # x, y, z = traj[:, 0], traj[:, 1], traj[:, 2]
328- # phi_values = jnp.unwrap(jnp.arctan2(y, x)) # Ensure continuity
329- # t_steps = jnp.arange(len(x))
330-
331- # hits = []
332-
333- # for phi_target in phis_poincare:
334- # phi_shifted = phi_values - phi_target # Shifted phi for comparison
335- # sign_change = (phi_shifted[:-1] * phi_shifted[1:]) < 0 # Detect crossing
336-
337- # if jnp.any(sign_change):
338- # crossing_indices = jnp.where(sign_change)[0] # Get indices of crossings
339- # for idx in crossing_indices:
340- # # Linear interpolation to estimate exact crossing
341- # w = (phi_target - phi_values[idx]) / (phi_values[idx + 1] - phi_values[idx])
342- # t_cross = t_steps[idx] + w * (t_steps[idx + 1] - t_steps[idx])
343- # x_cross = x[idx] + w * (x[idx + 1] - x[idx])
344- # y_cross = y[idx] + w * (y[idx + 1] - y[idx])
345- # z_cross = z[idx] + w * (z[idx + 1] - z[idx])
336+ def poincare_plot (self , shifts = [jnp .pi / 2 ], orientation = 'toroidal' , length = 1 , ax = None , show = True , color = None , ** kwargs ):
337+ """
338+ Plot Poincare plots using scipy to find the roots of an interpolation. Can take particle trace or field lines.
339+ Args:
340+ shifts (list, optional): Apply a linear shift to dependent data. Default is [0].
341+ orientation (str, optional):
342+ 'toroidal' - find time values when toroidal angle = shift [0, 2pi].
343+ 'z' - find time values where z coordinate = shift. Default is 'toroidal'.
344+ length (float, optional): A way to shorten data. 1 - plot full length, 0.1 - plot 1/10 of data length. Default is 1.
345+ ax (matplotlib.axes._subplots.AxesSubplot, optional): Matplotlib axis to plot on. Default is None.
346+ show (bool, optional): Whether to display the plot. Default is True.
347+ color: Can be time, None or a color to plot Poincaré points
348+ **kwargs: Additional keyword arguments for plotting.
349+ Notes:
350+ - If the data seem ill-behaved, there may not be enough steps in the trace for a good interpolation.
351+ - This will break if there are any NaNs.
352+ - Issues with toroidal interpolation: jnp.arctan2(Y, X) % (2 * jnp.pi) causes distortion in interpolation near phi = 0.
353+ - Maybe determine a lower limit on resolution needed per toroidal turn for "good" results.
354+ To-Do:
355+ - Format colorbars.
356+ """
357+ kwargs .setdefault ('s' , 0.5 )
358+ if ax is None :
359+ fig = plt .figure ()
360+ ax = fig .add_subplot ()
361+ shifts = jnp .array (shifts )
362+ plotting_data = []
363+ # from essos.util import roots_scipy
364+ for shift in shifts :
365+ @jit
366+ def compute_trajectory_toroidal (trace ):
367+ X ,Y ,Z = trace [:,:3 ].T
368+ R = jnp .sqrt (X ** 2 + Y ** 2 )
369+ phi = jnp .arctan2 (Y ,X )
370+ phi = jnp .where (shift == 0 , phi , jnp .abs (phi ))
371+ T_slice = roots (self .times , phi , shift = shift )
372+ T_slice = jnp .where (shift == 0 , jnp .concatenate ((T_slice [1 ::2 ],T_slice [1 ::2 ])), T_slice )
373+ # T_slice = roots_scipy(self.times, phi, shift = shift)
374+ R_slice = jnp .interp (T_slice , self .times , R )
375+ Z_slice = jnp .interp (T_slice , self .times , Z )
376+ return R_slice , Z_slice , T_slice
377+ @jit
378+ def compute_trajectory_z (trace ):
379+ X ,Y ,Z = trace [:,:3 ].T
380+ T_slice = roots (self .times , Z , shift = shift )
381+ # T_slice = roots_scipy(self.times, Z, shift = shift)
382+ X_slice = jnp .interp (T_slice , self .times , X )
383+ Y_slice = jnp .interp (T_slice , self .times , Y )
384+ return X_slice , Y_slice , T_slice
385+ if orientation == 'toroidal' :
386+ # X_slice, Y_slice, T_slice = vmap(compute_trajectory_toroidal)(self.trajectories)
387+ X_slice , Y_slice , T_slice = jit (vmap (compute_trajectory_toroidal ), in_shardings = sharding , out_shardings = sharding )(
388+ device_put (self .trajectories , sharding ))
389+ elif orientation == 'z' :
390+ # X_slice, Y_slice, T_slice = vmap(compute_trajectory_z)(self.trajectories)
391+ X_slice , Y_slice , T_slice = jit (vmap (compute_trajectory_z ), in_shardings = sharding , out_shardings = sharding )(
392+ device_put (self .trajectories , sharding ))
393+ @partial (jax .vmap , in_axes = (0 , 0 , 0 ))
394+ def process_trajectory (X_i , Y_i , T_i ):
395+ mask = (T_i [1 :] != T_i [:- 1 ])
396+ valid_idx = jnp .nonzero (mask , size = T_i .size - 1 )[0 ] + 1
397+ return X_i [valid_idx ], Y_i [valid_idx ], T_i [valid_idx ]
398+ X_s , Y_s , T_s = process_trajectory (X_slice , Y_slice , T_slice )
399+ length_ = (vmap (len )(X_s ) * length ).astype (int )
400+ colors = plt .cm .ocean (jnp .linspace (0 , 0.8 , len (X_s )))
401+ for i in range (len (X_s )):
402+ X_plot , Y_plot = X_s [i ][:length_ [i ]], Y_s [i ][:length_ [i ]]
403+ T_plot = T_s [i ][:length_ [i ]]
404+ plotting_data .append ((X_plot , Y_plot , T_plot ))
405+ if color == 'time' :
406+ hits = ax .scatter (X_plot , Y_plot , c = T_s [i ][:length_ [i ]], ** kwargs )
407+ else :
408+ if color is None : c = [colors [i ]]
409+ else : c = color
410+ hits = ax .scatter (X_plot , Y_plot , c = c , ** kwargs )
346411
347- # hits.append([t_cross, x_cross, y_cross, z_cross])
348-
349- # return jnp.array(hits)
350-
351- # @partial(jit, static_argnums=(0))
352- # def poincare(self):
353- # """Compute Poincaré section hits for multiple trajectories."""
354- # trajectories = self.trajectories # Pass trajectories directly into the function
355- # phis_poincare = self.phis_poincare # Similarly, use the direct attribute
356-
357- # # Use vmap to vectorize the calls for each trajectory
358- # return vmap(self.find_poincare_hits, in_axes=(0, None))(trajectories, tuple(phis_poincare))
359-
360- # def poincare_plot(self, phis=None, filename=None, res_phi_hits=None, mark_lost=False, aspect='equal', dpi=300, xlims=None,
361- # ylims=None, s=2, marker='o', show=True):
362- # import matplotlib.pyplot as plt
363-
364- # self.phis_poincare = phis
365- # if res_phi_hits is None:
366- # res_phi_hits = self.poincare()
367- # self.res_phi_hits = res_phi_hits
368-
369- # res_phi_hits = jnp.array(res_phi_hits) # Ensure it's a JAX array
370-
371- # # Determine number of rows/columns
372- # nrowcol = int(jnp.ceil(jnp.sqrt(len(phis))))
373-
374- # # Create subplots
375- # fig, axs = plt.subplots(nrowcol, nrowcol, figsize=(8, 5))
376- # axs = axs.ravel() # Flatten for easier indexing
412+ if orientation == 'toroidal' :
413+ plt .xlabel ('R' ,fontsize = 18 )
414+ plt .ylabel ('Z' ,fontsize = 18 )
415+ # plt.title(r'$\phi$ = {:.2f} $\pi$'.format(shift/jnp.pi),fontsize = 20)
416+ elif orientation == 'z' :
417+ plt .xlabel ('X' ,fontsize = 18 )
418+ plt .xlabel ('Y' ,fontsize = 18 )
419+ # plt.title('Z = {:.2f}'.format(shift),fontsize = 20)
420+ plt .axis ('equal' )
421+ plt .grid ()
422+ plt .tight_layout ()
423+ if show :
424+ plt .show ()
377425
378- # # Loop over phi values and create plots
379- # for i, phi in enumerate(phis):
380- # ax = axs[i]
381- # ax.set_aspect(aspect)
382- # ax.set_title(f"$\\phi = {phi/jnp.pi:.2f}\\pi$", loc='left', y=0.0)
383- # ax.set_xlabel("$r$")
384- # ax.set_ylabel("$z$")
385-
386- # if xlims:
387- # ax.set_xlim(xlims)
388- # if ylims:
389- # ax.set_ylim(ylims)
390-
391- # # Extract points corresponding to this phi
392- # mask = res_phi_hits[:, 1] == i
393- # data_this_phi = res_phi_hits[mask]
394-
395- # if data_this_phi.shape[0] > 0:
396- # r = jnp.sqrt(data_this_phi[:, 2]**2 + data_this_phi[:, 3]**2)
397- # z = data_this_phi[:, 4]
398-
399- # color = 'g' # Default color
400- # if mark_lost:
401- # lost = data_this_phi[-1, 1] < 0
402- # color = 'r' if lost else 'g'
403-
404- # ax.scatter(r, z, marker=marker, s=s, linewidths=0, c=color)
405-
406- # ax.grid(True, linewidth=0.5)
407-
408- # # Adjust layout and save
409- # plt.tight_layout()
410- # if filename is not None: plt.savefig(filename, dpi=dpi)
411- # if show: plt.show()
412- # plt.close()
426+ return plotting_data
413427
414428tree_util .register_pytree_node (Tracing ,
415429 Tracing ._tree_flatten ,
0 commit comments