Skip to content

Commit 7f43111

Browse files
authored
Merge pull request #7 from uwplasma/rj/nearaxis_plotboundary
Merge Rj/nearaxis plotboundary
2 parents 9ea3edb + 556e52b commit 7f43111

File tree

4 files changed

+243
-32
lines changed

4 files changed

+243
-32
lines changed

essos/dynamics.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ def update_state(state, _):
254254
in_spec = PartitionSpec('workers', None)
255255
return shard_map(vmap(compute_trajectory), mesh,
256256
in_specs=(in_spec,), out_specs=in_spec, check_rep=False)(self.initial_conditions)
257+
# trajectories = []
258+
# for initial_condition in self.initial_conditions:
259+
# trajectory = compute_trajectory(initial_condition)
260+
# trajectories.append(trajectory)
261+
# return jnp.array(trajectories)
257262

258263
@property
259264
def trajectories(self):

essos/fields.py

Lines changed: 234 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
jax.config.update("jax_enable_x64", True)
33
import jax.numpy as jnp
44
from functools import partial
5-
from jax import jit, jacfwd, grad, vmap, tree_util
5+
from jax import jit, jacfwd, grad, vmap, tree_util, lax
66
from essos.surfaces import SurfaceRZFourier
77
from essos.plot import fix_matplotlib_3d
88

@@ -184,12 +184,34 @@ def to_xyz(self, points):
184184
Y = R * jnp.sin(phi)
185185
return jnp.array([X, Y, Z])
186186

187+
def newton(f, x0):
188+
"""Newton's method for root-finding."""
189+
initial_state = (0, x0) # (iteration, x)
190+
191+
def cond(state):
192+
it, x = state
193+
# We fix 30 iterations for simplicity, this is plenty for convergence in our tests.
194+
return (it < 30)
195+
196+
def body(state):
197+
it, x = state
198+
fx, dfx = f(x), jax.grad(f)(x)
199+
step = fx / dfx
200+
new_state = it + 1, x - step
201+
return new_state
202+
203+
return jax.lax.while_loop(
204+
cond,
205+
body,
206+
initial_state,
207+
)[1]
208+
187209
class near_axis():
188210
def __init__(self, rc=jnp.array([1, 0.1]), zs=jnp.array([0, 0.1]), etabar=1.0,
189-
B0=1, sigma0=0, I2=0, nphi=31, spsi=1, sG=1, nfp=2):
211+
B0=1, sigma0=0, I2=0, nphi=31, spsi=1, sG=1, nfp=2, order='r1', B2c=0, p2=0):
190212
assert nphi % 2 == 1, 'nphi must be odd'
191-
self.rc = rc
192-
self.zs = zs
213+
self.rc = jnp.array(rc)
214+
self.zs = jnp.array(zs)
193215
self.etabar = etabar
194216
self.nphi = nphi
195217
self.sigma0 = sigma0
@@ -198,27 +220,36 @@ def __init__(self, rc=jnp.array([1, 0.1]), zs=jnp.array([0, 0.1]), etabar=1.0,
198220
self.sG = sG
199221
self.B0 = B0
200222
self.nfp = nfp
223+
self.order = order # not used
224+
self.B2c = B2c # not used
225+
self.p2 = p2 # not used
201226

202-
self._dofs = jnp.concatenate((jnp.ravel(rc), jnp.ravel(zs), jnp.array([etabar])))
227+
self._dofs = jnp.concatenate((jnp.ravel(self.rc), jnp.ravel(self.zs), jnp.array([etabar])))
203228

204229
self.phi = jnp.linspace(0, 2 * jnp.pi / self.nfp, self.nphi, endpoint=False)
205230
self.nfourier = max(len(self.rc), len(self.zs))
206231

207232
parameters = self.calculate(self.rc, self.zs, self.etabar)
208-
self.R0, self.Z0, self.sigma, self.elongation, self.B_axis, self.grad_B_axis, self.axis_length, self.iota, self.iotaN, self.G0 = parameters
233+
(self.R0, self.Z0, self.sigma, self.elongation, self.B_axis, self.grad_B_axis, self.axis_length, self.iota, self.iotaN, self.G0,
234+
self.helicity, self.X1c_untwisted, self.X1s_untwisted, self.Y1s_untwisted, self.Y1c_untwisted,
235+
self.normal_R, self.normal_phi, self.normal_z, self.binormal_R, self.binormal_phi, self.binormal_z,
236+
self.L_grad_B, self.inv_L_grad_B, self.torsion) = parameters
209237

210238
@property
211239
def dofs(self):
212240
return self._dofs
213241

214242
@dofs.setter
215243
def dofs(self, new_dofs):
216-
self._dofs = new_dofs
217-
self.rc = new_dofs[:len(self.rc)]
218-
self.zs = new_dofs[len(self.rc):len(self.rc)+len(self.zs)]
219-
self.etabar = new_dofs[-1]
244+
self._dofs = jnp.array(new_dofs)
245+
self.rc = self._dofs[:self.nfourier]
246+
self.zs = self._dofs[self.nfourier:2*self.nfourier]
247+
self.etabar = self._dofs[-1]
220248
parameters = self.calculate(self.rc, self.zs, self.etabar)
221-
self.R0, self.Z0, self.sigma, self.elongation, self.B_axis, self.grad_B_axis, self.axis_length, self.iota, self.iotaN, self.G0 = parameters
249+
(self.R0, self.Z0, self.sigma, self.elongation, self.B_axis, self.grad_B_axis, self.axis_length, self.iota, self.iotaN, self.G0,
250+
self.helicity, self.X1c_untwisted, self.X1s_untwisted, self.Y1s_untwisted, self.Y1c_untwisted,
251+
self.normal_R, self.normal_z, self.normal_phi, self.binormal_R, self.binormal_z, self.binormal_phi,
252+
self.L_grad_B, self.inv_L_grad_B, self.torsion) = parameters
222253

223254
@property
224255
def x(self):
@@ -230,7 +261,8 @@ def x(self, new_x):
230261

231262
def _tree_flatten(self):
232263
children = (self.rc, self.zs, self.etabar, self.B0, self.sigma0, self.I2) # arrays / dynamic values
233-
aux_data = {"nphi": self.nphi, "spsi": self.spsi, "sG": self.sG, "nfp": self.nfp} # static values
264+
aux_data = {"nphi": self.nphi, "spsi": self.spsi, "sG": self.sG,
265+
"nfp": self.nfp, "order": self.order, "B2c": self.B2c, "p2": self.p2} # static values
234266
return (children, aux_data)
235267

236268
@classmethod
@@ -440,26 +472,202 @@ def body_fun(i, x):
440472
nablaB[2, 2]]
441473
])
442474

443-
return R0, Z0, sigma, elongation, B_axis, grad_B_axis, axis_length, iota, iotaN, G0
475+
grad_B_colon_grad_B = tn * tn + nt * nt \
476+
+ bb * bb + nn * nn \
477+
+ nb * nb + bn * bn \
478+
+ tt * tt
479+
L_grad_B = self.B0 * jnp.sqrt(2 / grad_B_colon_grad_B)
480+
inv_L_grad_B = 1.0 / L_grad_B
481+
482+
X1c_untwisted = jnp.where(helicity == 0, X1c, X1c * jnp.cos(-helicity * nfp * varphi))
483+
X1s_untwisted = jnp.where(helicity == 0, 0 * X1c, X1c * jnp.sin(-helicity * nfp * varphi))
484+
Y1s_untwisted = jnp.where(helicity == 0, Y1s, Y1s * jnp.cos(-helicity * nfp * varphi) + Y1c * jnp.sin(-helicity * nfp * varphi))
485+
Y1c_untwisted = jnp.where(helicity == 0, Y1c, Y1s * (-jnp.sin(-helicity * nfp * varphi)) + Y1c * jnp.cos(-helicity * nfp * varphi))
486+
487+
normal_R = normal_cylindrical[:,0]
488+
normal_phi = normal_cylindrical[:,1]
489+
normal_z = normal_cylindrical[:,2]
490+
binormal_R = binormal_cylindrical[:,0]
491+
binormal_phi = binormal_cylindrical[:,1]
492+
binormal_z = binormal_cylindrical[:,2]
493+
494+
return (R0, Z0, sigma, elongation, B_axis, grad_B_axis, axis_length, iota, iotaN, G0,
495+
helicity, X1c_untwisted, X1s_untwisted, Y1s_untwisted, Y1c_untwisted,
496+
normal_R, normal_phi, normal_z, binormal_R, binormal_phi, binormal_z,
497+
L_grad_B, inv_L_grad_B, torsion)
498+
499+
@jit
500+
def interpolated_array_at_point(self,array,point):
501+
sp=jnp.interp(jnp.array([point]), jnp.append(self.phi,2*jnp.pi/self.nfp), jnp.append(array,array[0]), period=2*jnp.pi/self.nfp)[0]
502+
# sp=interpax.interp1d(jnp.array([point]), jnp.append(self.phi,2*jnp.pi/self.nfp), jnp.append(array,array[0]), method="cubic", period=2*jnp.pi/self.nfp)[0]
503+
return sp
504+
505+
@jit
506+
def Frenet_to_cylindrical_residual_func(self,phi0, phi_target, X_at_this_theta, Y_at_this_theta):
507+
sinphi0 = jnp.sin(phi0)
508+
cosphi0 = jnp.cos(phi0)
509+
R0_at_phi0 = self.interpolated_array_at_point(self.R0,phi0)
510+
X_at_phi0 = self.interpolated_array_at_point(X_at_this_theta,phi0)
511+
Y_at_phi0 = self.interpolated_array_at_point(Y_at_this_theta,phi0)
512+
normal_R = self.interpolated_array_at_point(self.normal_R,phi0)
513+
normal_phi = self.interpolated_array_at_point(self.normal_phi,phi0)
514+
binormal_R = self.interpolated_array_at_point(self.binormal_R,phi0)
515+
binormal_phi = self.interpolated_array_at_point(self.binormal_phi,phi0)
516+
normal_x = normal_R * cosphi0 - normal_phi * sinphi0
517+
normal_y = normal_R * sinphi0 + normal_phi * cosphi0
518+
binormal_x = binormal_R * cosphi0 - binormal_phi * sinphi0
519+
binormal_y = binormal_R * sinphi0 + binormal_phi * cosphi0
520+
total_x = R0_at_phi0 * cosphi0 + X_at_phi0 * normal_x + Y_at_phi0 * binormal_x
521+
total_y = R0_at_phi0 * sinphi0 + X_at_phi0 * normal_y + Y_at_phi0 * binormal_y
522+
Frenet_to_cylindrical_residual = jnp.arctan2(total_y, total_x) - phi_target
523+
Frenet_to_cylindrical_residual = jnp.where(Frenet_to_cylindrical_residual > jnp.pi, Frenet_to_cylindrical_residual - 2 * jnp.pi, Frenet_to_cylindrical_residual)
524+
Frenet_to_cylindrical_residual = jnp.where(Frenet_to_cylindrical_residual <-jnp.pi, Frenet_to_cylindrical_residual + 2 * jnp.pi, Frenet_to_cylindrical_residual)
525+
return Frenet_to_cylindrical_residual
526+
527+
@jit
528+
def Frenet_to_cylindrical_1_point(self, phi0, X_at_this_theta, Y_at_this_theta):
529+
sinphi0 = jnp.sin(phi0)
530+
cosphi0 = jnp.cos(phi0)
531+
R0_at_phi0 = self.interpolated_array_at_point(self.R0,phi0)
532+
z0_at_phi0 = self.interpolated_array_at_point(self.Z0,phi0)
533+
X_at_phi0 = self.interpolated_array_at_point(X_at_this_theta,phi0)
534+
Y_at_phi0 = self.interpolated_array_at_point(Y_at_this_theta,phi0)
535+
normal_R = self.interpolated_array_at_point(self.normal_R,phi0)
536+
normal_phi = self.interpolated_array_at_point(self.normal_phi,phi0)
537+
normal_z = self.interpolated_array_at_point(self.normal_z,phi0)
538+
binormal_R = self.interpolated_array_at_point(self.binormal_R,phi0)
539+
binormal_phi = self.interpolated_array_at_point(self.binormal_phi,phi0)
540+
binormal_z = self.interpolated_array_at_point(self.binormal_z,phi0)
541+
normal_x = normal_R * cosphi0 - normal_phi * sinphi0
542+
normal_y = normal_R * sinphi0 + normal_phi * cosphi0
543+
binormal_x = binormal_R * cosphi0 - binormal_phi * sinphi0
544+
binormal_y = binormal_R * sinphi0 + binormal_phi * cosphi0
545+
total_x = R0_at_phi0 * cosphi0 + X_at_phi0 * normal_x + Y_at_phi0 * binormal_x
546+
total_y = R0_at_phi0 * sinphi0 + X_at_phi0 * normal_y + Y_at_phi0 * binormal_y
547+
total_z = z0_at_phi0 + X_at_phi0 * normal_z + Y_at_phi0 * binormal_z
548+
total_R = jnp.sqrt(total_x * total_x + total_y * total_y)
549+
total_phi=jnp.arctan2(total_y, total_x)
550+
return total_R, total_z, total_phi
551+
552+
@partial(jit, static_argnames=['ntheta'])
553+
def Frenet_to_cylindrical(self, r, ntheta=20):
554+
nphi_conversion = self.nphi
555+
theta = jnp.linspace(0, 2 * jnp.pi, ntheta, endpoint=False)
556+
phi_conversion = jnp.linspace(0, 2 * jnp.pi / self.nfp, nphi_conversion, endpoint=False)
557+
558+
def compute_for_theta(theta_j):
559+
costheta = jnp.cos(theta_j)
560+
sintheta = jnp.sin(theta_j)
561+
X_at_this_theta = r * (self.X1c_untwisted * costheta + self.X1s_untwisted * sintheta)
562+
Y_at_this_theta = r * (self.Y1c_untwisted * costheta + self.Y1s_untwisted * sintheta)
563+
564+
def compute_for_phi(phi_target):
565+
residual = partial(self.Frenet_to_cylindrical_residual_func, phi_target=phi_target,
566+
X_at_this_theta=X_at_this_theta, Y_at_this_theta=Y_at_this_theta)
567+
phi0_solution = lax.custom_root(residual, phi_target, newton, lambda g, y: y / g(1.0))
568+
final_R, final_Z, _ = self.Frenet_to_cylindrical_1_point(phi0_solution, X_at_this_theta, Y_at_this_theta)
569+
return final_R, final_Z, phi0_solution
570+
571+
return vmap(compute_for_phi)(phi_conversion)
572+
573+
R_2D, Z_2D, phi0_2D = vmap(compute_for_theta)(theta)
574+
return R_2D, Z_2D, phi0_2D
575+
576+
577+
@partial(jit, static_argnames=['mpol', 'ntor'])
578+
def to_Fourier(self, R_2D, Z_2D, nfp, mpol, ntor):
579+
ntheta, nphi_conversion = R_2D.shape
580+
theta = jnp.linspace(0, 2 * jnp.pi, ntheta, endpoint=False)
581+
phi_conversion = jnp.linspace(0, 2 * jnp.pi / nfp, nphi_conversion, endpoint=False)
444582

445-
def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs):
446-
if close: raise NotImplementedError("close=True is not implemented, need to have closed surfaces")
583+
phi2d, theta2d = jnp.meshgrid(phi_conversion, theta, indexing='xy')
584+
factor = 2 / (ntheta * nphi_conversion)
585+
586+
def compute_RBC_ZBS(m, n):
587+
angle = m * theta2d - n * nfp * phi2d
588+
sinangle, cosangle = jnp.sin(angle), jnp.cos(angle)
447589

590+
factor2 = jax.lax.cond(
591+
(ntheta % 2 == 0) & (m == (ntheta / 2)),
592+
lambda _: factor / 2, lambda _: factor,
593+
operand=None)
594+
595+
factor2 = jax.lax.cond(
596+
(nphi_conversion % 2 == 0) & (abs(n) == (nphi_conversion / 2)),
597+
lambda _: factor2 / 2, lambda _: factor2,
598+
operand=None)
599+
600+
return jnp.sum(R_2D * cosangle * factor2), jnp.sum(Z_2D * sinangle * factor2)
601+
602+
m_vals = jnp.arange(mpol + 1)
603+
n_vals = jnp.concatenate([jnp.array([1]), jnp.arange(-ntor, ntor + 1)]) if mpol == 0 else jnp.arange(-ntor, ntor + 1)
604+
RBC, ZBS = vmap(lambda n: vmap(lambda m: compute_RBC_ZBS(m, n))(m_vals))(n_vals)
605+
606+
RBC = RBC.at[ntor, 0].set(jnp.sum(R_2D) / (ntheta * nphi_conversion))
607+
ZBS = ZBS.at[:ntor, 0].set(0)
608+
RBC = RBC.at[:ntor, 0].set(0)
609+
return RBC, ZBS
610+
611+
612+
@partial(jit, static_argnames=['ntheta_fourier', 'mpol', 'ntor', 'ntheta', 'nphi'])
613+
def get_boundary(self, r=0.1, ntheta=30, nphi=130, ntheta_fourier=20, mpol=5, ntor=5):
614+
R_2D, Z_2D, _ = self.Frenet_to_cylindrical(r, ntheta=ntheta_fourier)
615+
RBC, ZBS = self.to_Fourier(R_2D, Z_2D, self.nfp, mpol=mpol, ntor=ntor)
616+
617+
theta1D = jnp.linspace(0, 2 * jnp.pi, ntheta)
618+
phi1D = jnp.linspace(0, 2 * jnp.pi, nphi)
619+
phi2D, theta2D = jnp.meshgrid(phi1D, theta1D, indexing='ij')
620+
621+
def compute_RZ(m, n):
622+
angle = m * theta2D - n * self.nfp * phi2D
623+
return RBC[n + ntor, m] * jnp.cos(angle), ZBS[n + ntor, m] * jnp.sin(angle)
624+
625+
m_vals = jnp.arange(mpol + 1)
626+
n_vals = jnp.arange(-ntor, ntor + 1)
627+
628+
R_2Dnew, Z_2Dnew = vmap(lambda m: vmap(lambda n: compute_RZ(m, n))(n_vals))(m_vals)
629+
R_2Dnew, Z_2Dnew = R_2Dnew.sum(axis=(0, 1)), Z_2Dnew.sum(axis=(0, 1))
630+
631+
x_2D_plot = R_2Dnew.T * jnp.cos(phi1D)
632+
y_2D_plot = R_2Dnew.T * jnp.sin(phi1D)
633+
z_2D_plot = Z_2Dnew.T
634+
return x_2D_plot, y_2D_plot, z_2D_plot, R_2Dnew.T
635+
636+
@partial(jit, static_argnames=['self'])
637+
def B_mag(self, r, theta, phi):
638+
return self.B0*(1 + r * self.etabar * jnp.cos(theta - (self.iota - self.iotaN) * phi))
639+
640+
def plot(self, r=0.1, ntheta=80, nphi=150, ntheta_fourier=20, ax=None, show=True, close=False, axis_equal=True, **kwargs):
641+
kwargs.setdefault('alpha', 1)
448642
import matplotlib.pyplot as plt
643+
from matplotlib import cm
644+
import matplotlib.colors as clr
645+
from matplotlib.colors import LightSource
449646
if ax is None or ax.name != "3d":
450647
fig = plt.figure()
451-
ax = fig.add_subplot(projection='3d')
452-
453-
x_plot = self.R0 * jnp.cos(self.phi)
454-
y_plot = self.R0 * jnp.sin(self.phi)
455-
z_plot = self.Z0
456-
457-
plt.plot(x_plot, y_plot, z_plot)
648+
ax = fig.add_subplot(projection='3d')
649+
x_2D_plot, y_2D_plot, z_2D_plot, _ = self.get_boundary(r=r, ntheta=ntheta, nphi=nphi, ntheta_fourier=ntheta_fourier)
650+
theta1D = jnp.linspace(0, 2 * jnp.pi, ntheta)
651+
phi1D = jnp.linspace(0, 2 * jnp.pi, nphi)
652+
phi2D, theta2D = jnp.meshgrid(phi1D, theta1D)
653+
import numpy as np
654+
Bmag = np.array(self.B_mag(r, theta2D, phi2D))
655+
norm = clr.Normalize(vmin=Bmag.min(), vmax=Bmag.max())
656+
cmap = cm.viridis
657+
ls = LightSource(azdeg=0, altdeg=10)
658+
cmap_plot = ls.shade(Bmag, cmap, norm=norm)
659+
ax.plot_surface(x_2D_plot, y_2D_plot, z_2D_plot, facecolors=cmap_plot,
660+
rstride=1, cstride=1, antialiased=False,
661+
linewidth=0, shade=False, **kwargs)
662+
# ax.dist = 7
663+
# ax.elev = 5
664+
# ax.azim = 45
665+
# cbar_ax = ax.figure.add_axes([0.85, 0.2, 0.03, 0.6])
666+
# m = cm.ScalarMappable(cmap=cmap, norm=norm)
667+
# m.set_array([])
668+
# cbar = plt.colorbar(m, cax=cbar_ax)
669+
# cbar.ax.set_title(r'$|B| [T]$')
458670
ax.grid(False)
459-
# ax.set_xlabel('X (meters)', fontsize=10)
460-
# ax.set_ylabel('Y (meters)', fontsize=10)
461-
# ax.set_zlabel('Z (meters)', fontsize=10)
462-
463671
if axis_equal:
464672
fix_matplotlib_3d(ax)
465673
if show:

examples/optimize_coils_and_nearaxis.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,11 @@
8181
ax1 = fig.add_subplot(121, projection='3d')
8282
ax2 = fig.add_subplot(122, projection='3d')
8383
coils_optimized_initial_nearaxis.plot(ax=ax1, show=False)
84-
field_nearaxis_initial.plot(ax=ax1, show=False)
84+
field_nearaxis_initial.plot(ax=ax1, show=False, alpha=0.1)
8585
tracing_initial.plot(ax=ax1, show=False)
8686
coils_optimized.plot(ax=ax2, show=False)
87-
field_nearaxis_optimized.plot(ax=ax2, show=False)
87+
field_nearaxis_optimized.plot(ax=ax2, show=False, alpha=0.1)
8888
tracing_optimized.plot(ax=ax2, show=False)
89-
plt.tight_layout()
9089
plt.show()
9190

9291
# # Save the coils to a json file

examples/optimize_coils_for_nearaxis.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,11 @@
6565
ax1 = fig.add_subplot(121, projection='3d')
6666
ax2 = fig.add_subplot(122, projection='3d')
6767
coils_initial.plot(ax=ax1, show=False)
68-
field.plot(ax=ax1, show=False)
68+
field.plot(ax=ax1, show=False, alpha=0.1)
6969
tracing_initial.plot(ax=ax1, show=False)
7070
coils_optimized.plot(ax=ax2, show=False)
71-
field.plot(ax=ax2, show=False)
71+
field.plot(ax=ax2, show=False, alpha=0.1)
7272
tracing_optimized.plot(ax=ax2, show=False)
73-
plt.tight_layout()
7473
plt.show()
7574

7675
# # Save the coils to a json file

0 commit comments

Comments
 (0)