Skip to content

Commit 229a321

Browse files
authored
Merge pull request #16 from uwplasma/development
Add Poincaré plot and sharding
2 parents 0ca8486 + 9cd4cf9 commit 229a321

30 files changed

+1128
-265
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ ESSOS/
9090
│ ├── ESSOS_bio_savart_LandremanPaulQA.json
9191
│ ├── SIMSOPT_bio_savart_LandremanPaulQA.json
9292
│ ├── wout_n3are_R7.75B5.7.nc
93-
│ └── wout_LandremanPaul2021_QA_reactorScale_lowres_reference.nc
93+
│ └── wout_LandremanPaul2021_QA_reactorScale_lowres.nc
9494
├── tests/
9595
│ ├── test_coils.py
9696
│ ├── test_constants.py

essos/__main__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from essos.coils import Coils, CreateEquallySpacedCurves
99
from essos.fields import near_axis, BiotSavart
1010
from essos.dynamics import Tracing
11-
from essos.optimization import optimize_coils_for_nearaxis
11+
from essos.optimization import optimize_loss_function
12+
from essos.objective_functions import loss_coils_for_nearaxis
1213

1314

1415
def main(cl_args=sys.argv[1:]):
@@ -53,9 +54,10 @@ def main(cl_args=sys.argv[1:]):
5354
# Optimize coils
5455
print(f'Optimizing coils with {maximum_function_evaluations} function evaluations.')
5556
time0 = time()
56-
coils_optimized = optimize_coils_for_nearaxis(field, coils_initial, maximum_function_evaluations=maximum_function_evaluations,
57-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,
58-
tolerance_optimization=tolerance_optimization)
57+
coils_optimized = optimize_loss_function(loss_coils_for_nearaxis, initial_dofs=coils_initial.x,
58+
coils=coils_initial, tolerance_optimization=tolerance_optimization,
59+
maximum_function_evaluations=maximum_function_evaluations, field_nearaxis=field,
60+
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
5961
print(f"Optimization took {time()-time0:.2f} seconds")
6062

6163
# Trace fieldlines
@@ -81,10 +83,10 @@ def main(cl_args=sys.argv[1:]):
8183
ax1 = fig.add_subplot(121, projection='3d')
8284
ax2 = fig.add_subplot(122, projection='3d')
8385
coils_initial.plot(ax=ax1, show=False)
84-
field.plot(ax=ax1, show=False)
86+
field.plot(ax=ax1, show=False, alpha=0.2)
8587
tracing_initial.plot(ax=ax1, show=False)
8688
coils_optimized.plot(ax=ax2, show=False)
87-
field.plot(ax=ax2, show=False)
89+
field.plot(ax=ax2, show=False, alpha=0.2)
8890
tracing_optimized.plot(ax=ax2, show=False)
8991
plt.tight_layout()
9092
plt.show()

essos/coils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ def __init__(self, curves: Curves, currents: jnp.ndarray):
314314
assert jnp.size(currents) == jnp.size(curves.dofs, 0)
315315
super().__init__(curves.dofs, curves.n_segments, curves.nfp, curves.stellsym)
316316
self._currents_scale = jnp.mean(jnp.abs(currents))
317-
self._dofs_currents = currents/self.currents_scale
318-
self._currents = apply_symmetries_to_currents(self._dofs_currents*self.currents_scale, self.nfp, self.stellsym)
317+
self._dofs_currents = currents/self._currents_scale
318+
self._currents = apply_symmetries_to_currents(self._dofs_currents*self._currents_scale, self.nfp, self.stellsym)
319319

320320
def __str__(self):
321321
return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\

essos/dynamics.py

Lines changed: 118 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
jax.config.update("jax_enable_x64", True)
33
import jax.numpy as jnp
44
import matplotlib.pyplot as plt
5-
from jax.experimental.shard_map import shard_map
6-
from jax.sharding import Mesh, NamedSharding, PartitionSpec
7-
from jax import jit, vmap, tree_util, random, lax
5+
from jax.sharding import Mesh, PartitionSpec, NamedSharding
6+
from jax import jit, vmap, tree_util, random, lax, device_put
87
from functools import partial
98
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, PIDController, Event
109
from essos.coils import Coils
1110
from essos.fields import BiotSavart, Vmec
1211
from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, FUSION_ALPHA_PARTICLE_ENERGY
1312
from essos.plot import fix_matplotlib_3d
13+
from essos.util import roots
14+
15+
mesh = Mesh(jax.devices(), ("dev",))
16+
sharding = NamedSharding(mesh, PartitionSpec("dev", None))
1417

1518
def gc_to_fullorbit(field, initial_xyz, initial_vparallel, total_speed, mass, charge, phase_angle_full_orbit=0):
1619
"""
@@ -211,6 +214,7 @@ def compute_energy_fo(trajectory):
211214
def trace(self):
212215
@jit
213216
def compute_trajectory(initial_condition) -> jnp.ndarray:
217+
# initial_condition = initial_condition[0]
214218
if self.model == 'FullOrbit_Boris':
215219
dt=self.maxtime / self.timesteps
216220
def update_state(state, _):
@@ -233,6 +237,8 @@ def update_state(state, _):
233237
_, trajectory = lax.scan(update_state, initial_condition, jnp.arange(len(self.times)-1))
234238
trajectory = jnp.vstack([initial_condition, trajectory])
235239
else:
240+
import warnings
241+
warnings.simplefilter("ignore", category=FutureWarning) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation
236242
trajectory = diffeqsolve(
237243
self.ODE_term,
238244
t0=0.0,
@@ -245,15 +251,26 @@ def update_state(state, _):
245251
throw=False,
246252
# adjoint=DirectAdjoint(),
247253
stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size),
248-
# max_steps=1000000,
254+
max_steps=10000000000,
249255
event = Event(self.condition)
250256
).ys
251257
return trajectory
252-
# return jnp.array(vmap(compute_trajectory)(self.initial_conditions))
253-
mesh = Mesh(devices=jax.devices(), axis_names=('workers',))
254-
in_spec = PartitionSpec('workers', None)
255-
return shard_map(vmap(compute_trajectory), mesh,
256-
in_specs=(in_spec,), out_specs=in_spec, check_rep=False)(self.initial_conditions)
258+
259+
# if len(jax.devices())!=len(self.initial_conditions):
260+
# return vmap(compute_trajectory)(self.initial_conditions[:,None,:])
261+
# else:
262+
# # num_devices = len(jax.devices())
263+
# shape = self.initial_conditions.shape
264+
# # distributed_initial_conditions = self.initial_conditions.reshape(num_devices, -1, *shape[1:])
265+
# mesh = Mesh(devices=jax.devices(), axis_names=('workers'))
266+
# in_spec = PartitionSpec('workers') # Distribute along the workers axis
267+
# out_spec = PartitionSpec('workers') # Gather results along the same axis
268+
# return shard_map(compute_trajectory, mesh, in_specs=in_spec, out_specs=out_spec, check_rep=False)(
269+
# self.initial_conditions).reshape((shape[0], self.timesteps, shape[1]))
270+
271+
return jit(vmap(compute_trajectory), in_shardings=sharding, out_shardings=sharding)(
272+
device_put(self.initial_conditions, sharding))
273+
257274
# trajectories = []
258275
# for initial_condition in self.initial_conditions:
259276
# trajectory = compute_trajectory(initial_condition)
@@ -289,12 +306,13 @@ def to_vtk(self, filename):
289306
data = np.array(jnp.concatenate([i*jnp.ones((self.trajectories[i].shape[0], )) for i in range(len(self.trajectories))]))
290307
polyLinesToVTK(filename, x, y, z, pointsPerLine=ppl, pointData={'idx': data})
291308

292-
def plot(self, ax=None, show=True, axis_equal=True, **kwargs):
309+
def plot(self, ax=None, show=True, axis_equal=True, n_trajectories_plot=5, **kwargs):
293310
if ax is None or ax.name != "3d":
294311
fig = plt.figure()
295312
ax = fig.add_subplot(projection='3d')
296313
trajectories_xyz = jnp.array(self.trajectories_xyz)
297-
for i in range(trajectories_xyz.shape[0]):
314+
n_trajectories_plot = jnp.min(jnp.array([n_trajectories_plot, trajectories_xyz.shape[0]]))
315+
for i in random.choice(random.PRNGKey(0), trajectories_xyz.shape[0], (n_trajectories_plot,), replace=False):
298316
ax.plot(trajectories_xyz[i, :, 0], trajectories_xyz[i, :, 1], trajectories_xyz[i, :, 2], linewidth=0.5, **kwargs)
299317
ax.grid(False)
300318
if axis_equal:
@@ -315,101 +333,97 @@ def loss_fraction(self, r_max=0.99):
315333
total_particles_lost = loss_fractions[-1] * len(self.trajectories)
316334
return loss_fractions, total_particles_lost, lost_times
317335

318-
# def get_phi(x, y, phi_last):
319-
# """Compute the toroidal angle phi, ensuring continuity."""
320-
# phi = jnp.arctan2(y, x)
321-
# dphi = phi - phi_last
322-
# return phi - jnp.round(dphi / (2 * jnp.pi)) * (2 * jnp.pi) # Ensure continuity
323-
324-
# @partial(jit, static_argnums=(0, 2))
325-
# def find_poincare_hits(self, traj, phis_poincare):
326-
# """Find points where field lines cross specified phi values."""
327-
# x, y, z = traj[:, 0], traj[:, 1], traj[:, 2]
328-
# phi_values = jnp.unwrap(jnp.arctan2(y, x)) # Ensure continuity
329-
# t_steps = jnp.arange(len(x))
330-
331-
# hits = []
332-
333-
# for phi_target in phis_poincare:
334-
# phi_shifted = phi_values - phi_target # Shifted phi for comparison
335-
# sign_change = (phi_shifted[:-1] * phi_shifted[1:]) < 0 # Detect crossing
336-
337-
# if jnp.any(sign_change):
338-
# crossing_indices = jnp.where(sign_change)[0] # Get indices of crossings
339-
# for idx in crossing_indices:
340-
# # Linear interpolation to estimate exact crossing
341-
# w = (phi_target - phi_values[idx]) / (phi_values[idx + 1] - phi_values[idx])
342-
# t_cross = t_steps[idx] + w * (t_steps[idx + 1] - t_steps[idx])
343-
# x_cross = x[idx] + w * (x[idx + 1] - x[idx])
344-
# y_cross = y[idx] + w * (y[idx + 1] - y[idx])
345-
# z_cross = z[idx] + w * (z[idx + 1] - z[idx])
336+
def poincare_plot(self, shifts = [jnp.pi/2], orientation = 'toroidal', length = 1, ax=None, show=True, color=None, **kwargs):
337+
"""
338+
Plot Poincare plots using scipy to find the roots of an interpolation. Can take particle trace or field lines.
339+
Args:
340+
shifts (list, optional): Apply a linear shift to dependent data. Default is [0].
341+
orientation (str, optional):
342+
'toroidal' - find time values when toroidal angle = shift [0, 2pi].
343+
'z' - find time values where z coordinate = shift. Default is 'toroidal'.
344+
length (float, optional): A way to shorten data. 1 - plot full length, 0.1 - plot 1/10 of data length. Default is 1.
345+
ax (matplotlib.axes._subplots.AxesSubplot, optional): Matplotlib axis to plot on. Default is None.
346+
show (bool, optional): Whether to display the plot. Default is True.
347+
color: Can be time, None or a color to plot Poincaré points
348+
**kwargs: Additional keyword arguments for plotting.
349+
Notes:
350+
- If the data seem ill-behaved, there may not be enough steps in the trace for a good interpolation.
351+
- This will break if there are any NaNs.
352+
- Issues with toroidal interpolation: jnp.arctan2(Y, X) % (2 * jnp.pi) causes distortion in interpolation near phi = 0.
353+
- Maybe determine a lower limit on resolution needed per toroidal turn for "good" results.
354+
To-Do:
355+
- Format colorbars.
356+
"""
357+
kwargs.setdefault('s', 0.5)
358+
if ax is None:
359+
fig = plt.figure()
360+
ax = fig.add_subplot()
361+
shifts = jnp.array(shifts)
362+
plotting_data = []
363+
# from essos.util import roots_scipy
364+
for shift in shifts:
365+
@jit
366+
def compute_trajectory_toroidal(trace):
367+
X,Y,Z = trace[:,:3].T
368+
R = jnp.sqrt(X**2 + Y**2)
369+
phi = jnp.arctan2(Y,X)
370+
phi = jnp.where(shift==0, phi, jnp.abs(phi))
371+
T_slice = roots(self.times, phi, shift = shift)
372+
T_slice = jnp.where(shift==0, jnp.concatenate((T_slice[1::2],T_slice[1::2])), T_slice)
373+
# T_slice = roots_scipy(self.times, phi, shift = shift)
374+
R_slice = jnp.interp(T_slice, self.times, R)
375+
Z_slice = jnp.interp(T_slice, self.times, Z)
376+
return R_slice, Z_slice, T_slice
377+
@jit
378+
def compute_trajectory_z(trace):
379+
X,Y,Z = trace[:,:3].T
380+
T_slice = roots(self.times, Z, shift = shift)
381+
# T_slice = roots_scipy(self.times, Z, shift = shift)
382+
X_slice = jnp.interp(T_slice, self.times, X)
383+
Y_slice = jnp.interp(T_slice, self.times, Y)
384+
return X_slice, Y_slice, T_slice
385+
if orientation == 'toroidal':
386+
# X_slice, Y_slice, T_slice = vmap(compute_trajectory_toroidal)(self.trajectories)
387+
X_slice, Y_slice, T_slice = jit(vmap(compute_trajectory_toroidal), in_shardings=sharding, out_shardings=sharding)(
388+
device_put(self.trajectories, sharding))
389+
elif orientation == 'z':
390+
# X_slice, Y_slice, T_slice = vmap(compute_trajectory_z)(self.trajectories)
391+
X_slice, Y_slice, T_slice = jit(vmap(compute_trajectory_z), in_shardings=sharding, out_shardings=sharding)(
392+
device_put(self.trajectories, sharding))
393+
@partial(jax.vmap, in_axes=(0, 0, 0))
394+
def process_trajectory(X_i, Y_i, T_i):
395+
mask = (T_i[1:] != T_i[:-1])
396+
valid_idx = jnp.nonzero(mask, size=T_i.size - 1)[0] + 1
397+
return X_i[valid_idx], Y_i[valid_idx], T_i[valid_idx]
398+
X_s, Y_s, T_s = process_trajectory(X_slice, Y_slice, T_slice)
399+
length_ = (vmap(len)(X_s) * length).astype(int)
400+
colors = plt.cm.ocean(jnp.linspace(0, 0.8, len(X_s)))
401+
for i in range(len(X_s)):
402+
X_plot, Y_plot = X_s[i][:length_[i]], Y_s[i][:length_[i]]
403+
T_plot = T_s[i][:length_[i]]
404+
plotting_data.append((X_plot, Y_plot, T_plot))
405+
if color == 'time':
406+
hits = ax.scatter(X_plot, Y_plot, c=T_s[i][:length_[i]], **kwargs)
407+
else:
408+
if color is None: c=[colors[i]]
409+
else: c=color
410+
hits = ax.scatter(X_plot, Y_plot, c=c, **kwargs)
346411

347-
# hits.append([t_cross, x_cross, y_cross, z_cross])
348-
349-
# return jnp.array(hits)
350-
351-
# @partial(jit, static_argnums=(0))
352-
# def poincare(self):
353-
# """Compute Poincaré section hits for multiple trajectories."""
354-
# trajectories = self.trajectories # Pass trajectories directly into the function
355-
# phis_poincare = self.phis_poincare # Similarly, use the direct attribute
356-
357-
# # Use vmap to vectorize the calls for each trajectory
358-
# return vmap(self.find_poincare_hits, in_axes=(0, None))(trajectories, tuple(phis_poincare))
359-
360-
# def poincare_plot(self, phis=None, filename=None, res_phi_hits=None, mark_lost=False, aspect='equal', dpi=300, xlims=None,
361-
# ylims=None, s=2, marker='o', show=True):
362-
# import matplotlib.pyplot as plt
363-
364-
# self.phis_poincare = phis
365-
# if res_phi_hits is None:
366-
# res_phi_hits = self.poincare()
367-
# self.res_phi_hits = res_phi_hits
368-
369-
# res_phi_hits = jnp.array(res_phi_hits) # Ensure it's a JAX array
370-
371-
# # Determine number of rows/columns
372-
# nrowcol = int(jnp.ceil(jnp.sqrt(len(phis))))
373-
374-
# # Create subplots
375-
# fig, axs = plt.subplots(nrowcol, nrowcol, figsize=(8, 5))
376-
# axs = axs.ravel() # Flatten for easier indexing
412+
if orientation == 'toroidal':
413+
plt.xlabel('R',fontsize = 18)
414+
plt.ylabel('Z',fontsize = 18)
415+
# plt.title(r'$\phi$ = {:.2f} $\pi$'.format(shift/jnp.pi),fontsize = 20)
416+
elif orientation == 'z':
417+
plt.xlabel('X',fontsize = 18)
418+
plt.xlabel('Y',fontsize = 18)
419+
# plt.title('Z = {:.2f}'.format(shift),fontsize = 20)
420+
plt.axis('equal')
421+
plt.grid()
422+
plt.tight_layout()
423+
if show:
424+
plt.show()
377425

378-
# # Loop over phi values and create plots
379-
# for i, phi in enumerate(phis):
380-
# ax = axs[i]
381-
# ax.set_aspect(aspect)
382-
# ax.set_title(f"$\\phi = {phi/jnp.pi:.2f}\\pi$", loc='left', y=0.0)
383-
# ax.set_xlabel("$r$")
384-
# ax.set_ylabel("$z$")
385-
386-
# if xlims:
387-
# ax.set_xlim(xlims)
388-
# if ylims:
389-
# ax.set_ylim(ylims)
390-
391-
# # Extract points corresponding to this phi
392-
# mask = res_phi_hits[:, 1] == i
393-
# data_this_phi = res_phi_hits[mask]
394-
395-
# if data_this_phi.shape[0] > 0:
396-
# r = jnp.sqrt(data_this_phi[:, 2]**2 + data_this_phi[:, 3]**2)
397-
# z = data_this_phi[:, 4]
398-
399-
# color = 'g' # Default color
400-
# if mark_lost:
401-
# lost = data_this_phi[-1, 1] < 0
402-
# color = 'r' if lost else 'g'
403-
404-
# ax.scatter(r, z, marker=marker, s=s, linewidths=0, c=color)
405-
406-
# ax.grid(True, linewidth=0.5)
407-
408-
# # Adjust layout and save
409-
# plt.tight_layout()
410-
# if filename is not None: plt.savefig(filename, dpi=dpi)
411-
# if show: plt.show()
412-
# plt.close()
426+
return plotting_data
413427

414428
tree_util.register_pytree_node(Tracing,
415429
Tracing._tree_flatten,

0 commit comments

Comments
 (0)