22jax .config .update ("jax_enable_x64" , True )
33import jax .numpy as jnp
44from functools import partial
5- from jax import jit , jacfwd , grad , vmap , tree_util
5+ from jax import jit , jacfwd , grad , vmap , tree_util , lax
66from essos .surfaces import SurfaceRZFourier
77from essos .plot import fix_matplotlib_3d
88
@@ -184,12 +184,34 @@ def to_xyz(self, points):
184184 Y = R * jnp .sin (phi )
185185 return jnp .array ([X , Y , Z ])
186186
187+ def newton (f , x0 ):
188+ """Newton's method for root-finding."""
189+ initial_state = (0 , x0 ) # (iteration, x)
190+
191+ def cond (state ):
192+ it , x = state
193+ # We fix 30 iterations for simplicity, this is plenty for convergence in our tests.
194+ return (it < 30 )
195+
196+ def body (state ):
197+ it , x = state
198+ fx , dfx = f (x ), jax .grad (f )(x )
199+ step = fx / dfx
200+ new_state = it + 1 , x - step
201+ return new_state
202+
203+ return jax .lax .while_loop (
204+ cond ,
205+ body ,
206+ initial_state ,
207+ )[1 ]
208+
187209class near_axis ():
188210 def __init__ (self , rc = jnp .array ([1 , 0.1 ]), zs = jnp .array ([0 , 0.1 ]), etabar = 1.0 ,
189- B0 = 1 , sigma0 = 0 , I2 = 0 , nphi = 31 , spsi = 1 , sG = 1 , nfp = 2 ):
211+ B0 = 1 , sigma0 = 0 , I2 = 0 , nphi = 31 , spsi = 1 , sG = 1 , nfp = 2 , order = 'r1' , B2c = 0 , p2 = 0 ):
190212 assert nphi % 2 == 1 , 'nphi must be odd'
191- self .rc = rc
192- self .zs = zs
213+ self .rc = jnp . array ( rc )
214+ self .zs = jnp . array ( zs )
193215 self .etabar = etabar
194216 self .nphi = nphi
195217 self .sigma0 = sigma0
@@ -198,27 +220,36 @@ def __init__(self, rc=jnp.array([1, 0.1]), zs=jnp.array([0, 0.1]), etabar=1.0,
198220 self .sG = sG
199221 self .B0 = B0
200222 self .nfp = nfp
223+ self .order = order # not used
224+ self .B2c = B2c # not used
225+ self .p2 = p2 # not used
201226
202- self ._dofs = jnp .concatenate ((jnp .ravel (rc ), jnp .ravel (zs ), jnp .array ([etabar ])))
227+ self ._dofs = jnp .concatenate ((jnp .ravel (self . rc ), jnp .ravel (self . zs ), jnp .array ([etabar ])))
203228
204229 self .phi = jnp .linspace (0 , 2 * jnp .pi / self .nfp , self .nphi , endpoint = False )
205230 self .nfourier = max (len (self .rc ), len (self .zs ))
206231
207232 parameters = self .calculate (self .rc , self .zs , self .etabar )
208- self .R0 , self .Z0 , self .sigma , self .elongation , self .B_axis , self .grad_B_axis , self .axis_length , self .iota , self .iotaN , self .G0 = parameters
233+ (self .R0 , self .Z0 , self .sigma , self .elongation , self .B_axis , self .grad_B_axis , self .axis_length , self .iota , self .iotaN , self .G0 ,
234+ self .helicity , self .X1c_untwisted , self .X1s_untwisted , self .Y1s_untwisted , self .Y1c_untwisted ,
235+ self .normal_R , self .normal_phi , self .normal_z , self .binormal_R , self .binormal_phi , self .binormal_z ,
236+ self .L_grad_B , self .inv_L_grad_B , self .torsion ) = parameters
209237
210238 @property
211239 def dofs (self ):
212240 return self ._dofs
213241
214242 @dofs .setter
215243 def dofs (self , new_dofs ):
216- self ._dofs = new_dofs
217- self .rc = new_dofs [: len ( self .rc ) ]
218- self .zs = new_dofs [ len ( self .rc ): len ( self .rc ) + len ( self . zs ) ]
219- self .etabar = new_dofs [- 1 ]
244+ self ._dofs = jnp . array ( new_dofs )
245+ self .rc = self . _dofs [: self .nfourier ]
246+ self .zs = self . _dofs [ self .nfourier : 2 * self .nfourier ]
247+ self .etabar = self . _dofs [- 1 ]
220248 parameters = self .calculate (self .rc , self .zs , self .etabar )
221- self .R0 , self .Z0 , self .sigma , self .elongation , self .B_axis , self .grad_B_axis , self .axis_length , self .iota , self .iotaN , self .G0 = parameters
249+ (self .R0 , self .Z0 , self .sigma , self .elongation , self .B_axis , self .grad_B_axis , self .axis_length , self .iota , self .iotaN , self .G0 ,
250+ self .helicity , self .X1c_untwisted , self .X1s_untwisted , self .Y1s_untwisted , self .Y1c_untwisted ,
251+ self .normal_R , self .normal_z , self .normal_phi , self .binormal_R , self .binormal_z , self .binormal_phi ,
252+ self .L_grad_B , self .inv_L_grad_B , self .torsion ) = parameters
222253
223254 @property
224255 def x (self ):
@@ -230,7 +261,8 @@ def x(self, new_x):
230261
231262 def _tree_flatten (self ):
232263 children = (self .rc , self .zs , self .etabar , self .B0 , self .sigma0 , self .I2 ) # arrays / dynamic values
233- aux_data = {"nphi" : self .nphi , "spsi" : self .spsi , "sG" : self .sG , "nfp" : self .nfp } # static values
264+ aux_data = {"nphi" : self .nphi , "spsi" : self .spsi , "sG" : self .sG ,
265+ "nfp" : self .nfp , "order" : self .order , "B2c" : self .B2c , "p2" : self .p2 } # static values
234266 return (children , aux_data )
235267
236268 @classmethod
@@ -440,26 +472,202 @@ def body_fun(i, x):
440472 nablaB [2 , 2 ]]
441473 ])
442474
443- return R0 , Z0 , sigma , elongation , B_axis , grad_B_axis , axis_length , iota , iotaN , G0
475+ grad_B_colon_grad_B = tn * tn + nt * nt \
476+ + bb * bb + nn * nn \
477+ + nb * nb + bn * bn \
478+ + tt * tt
479+ L_grad_B = self .B0 * jnp .sqrt (2 / grad_B_colon_grad_B )
480+ inv_L_grad_B = 1.0 / L_grad_B
481+
482+ X1c_untwisted = jnp .where (helicity == 0 , X1c , X1c * jnp .cos (- helicity * nfp * varphi ))
483+ X1s_untwisted = jnp .where (helicity == 0 , 0 * X1c , X1c * jnp .sin (- helicity * nfp * varphi ))
484+ Y1s_untwisted = jnp .where (helicity == 0 , Y1s , Y1s * jnp .cos (- helicity * nfp * varphi ) + Y1c * jnp .sin (- helicity * nfp * varphi ))
485+ Y1c_untwisted = jnp .where (helicity == 0 , Y1c , Y1s * (- jnp .sin (- helicity * nfp * varphi )) + Y1c * jnp .cos (- helicity * nfp * varphi ))
486+
487+ normal_R = normal_cylindrical [:,0 ]
488+ normal_phi = normal_cylindrical [:,1 ]
489+ normal_z = normal_cylindrical [:,2 ]
490+ binormal_R = binormal_cylindrical [:,0 ]
491+ binormal_phi = binormal_cylindrical [:,1 ]
492+ binormal_z = binormal_cylindrical [:,2 ]
493+
494+ return (R0 , Z0 , sigma , elongation , B_axis , grad_B_axis , axis_length , iota , iotaN , G0 ,
495+ helicity , X1c_untwisted , X1s_untwisted , Y1s_untwisted , Y1c_untwisted ,
496+ normal_R , normal_phi , normal_z , binormal_R , binormal_phi , binormal_z ,
497+ L_grad_B , inv_L_grad_B , torsion )
498+
499+ @jit
500+ def interpolated_array_at_point (self ,array ,point ):
501+ sp = jnp .interp (jnp .array ([point ]), jnp .append (self .phi ,2 * jnp .pi / self .nfp ), jnp .append (array ,array [0 ]), period = 2 * jnp .pi / self .nfp )[0 ]
502+ # sp=interpax.interp1d(jnp.array([point]), jnp.append(self.phi,2*jnp.pi/self.nfp), jnp.append(array,array[0]), method="cubic", period=2*jnp.pi/self.nfp)[0]
503+ return sp
504+
505+ @jit
506+ def Frenet_to_cylindrical_residual_func (self ,phi0 , phi_target , X_at_this_theta , Y_at_this_theta ):
507+ sinphi0 = jnp .sin (phi0 )
508+ cosphi0 = jnp .cos (phi0 )
509+ R0_at_phi0 = self .interpolated_array_at_point (self .R0 ,phi0 )
510+ X_at_phi0 = self .interpolated_array_at_point (X_at_this_theta ,phi0 )
511+ Y_at_phi0 = self .interpolated_array_at_point (Y_at_this_theta ,phi0 )
512+ normal_R = self .interpolated_array_at_point (self .normal_R ,phi0 )
513+ normal_phi = self .interpolated_array_at_point (self .normal_phi ,phi0 )
514+ binormal_R = self .interpolated_array_at_point (self .binormal_R ,phi0 )
515+ binormal_phi = self .interpolated_array_at_point (self .binormal_phi ,phi0 )
516+ normal_x = normal_R * cosphi0 - normal_phi * sinphi0
517+ normal_y = normal_R * sinphi0 + normal_phi * cosphi0
518+ binormal_x = binormal_R * cosphi0 - binormal_phi * sinphi0
519+ binormal_y = binormal_R * sinphi0 + binormal_phi * cosphi0
520+ total_x = R0_at_phi0 * cosphi0 + X_at_phi0 * normal_x + Y_at_phi0 * binormal_x
521+ total_y = R0_at_phi0 * sinphi0 + X_at_phi0 * normal_y + Y_at_phi0 * binormal_y
522+ Frenet_to_cylindrical_residual = jnp .arctan2 (total_y , total_x ) - phi_target
523+ Frenet_to_cylindrical_residual = jnp .where (Frenet_to_cylindrical_residual > jnp .pi , Frenet_to_cylindrical_residual - 2 * jnp .pi , Frenet_to_cylindrical_residual )
524+ Frenet_to_cylindrical_residual = jnp .where (Frenet_to_cylindrical_residual < - jnp .pi , Frenet_to_cylindrical_residual + 2 * jnp .pi , Frenet_to_cylindrical_residual )
525+ return Frenet_to_cylindrical_residual
526+
527+ @jit
528+ def Frenet_to_cylindrical_1_point (self , phi0 , X_at_this_theta , Y_at_this_theta ):
529+ sinphi0 = jnp .sin (phi0 )
530+ cosphi0 = jnp .cos (phi0 )
531+ R0_at_phi0 = self .interpolated_array_at_point (self .R0 ,phi0 )
532+ z0_at_phi0 = self .interpolated_array_at_point (self .Z0 ,phi0 )
533+ X_at_phi0 = self .interpolated_array_at_point (X_at_this_theta ,phi0 )
534+ Y_at_phi0 = self .interpolated_array_at_point (Y_at_this_theta ,phi0 )
535+ normal_R = self .interpolated_array_at_point (self .normal_R ,phi0 )
536+ normal_phi = self .interpolated_array_at_point (self .normal_phi ,phi0 )
537+ normal_z = self .interpolated_array_at_point (self .normal_z ,phi0 )
538+ binormal_R = self .interpolated_array_at_point (self .binormal_R ,phi0 )
539+ binormal_phi = self .interpolated_array_at_point (self .binormal_phi ,phi0 )
540+ binormal_z = self .interpolated_array_at_point (self .binormal_z ,phi0 )
541+ normal_x = normal_R * cosphi0 - normal_phi * sinphi0
542+ normal_y = normal_R * sinphi0 + normal_phi * cosphi0
543+ binormal_x = binormal_R * cosphi0 - binormal_phi * sinphi0
544+ binormal_y = binormal_R * sinphi0 + binormal_phi * cosphi0
545+ total_x = R0_at_phi0 * cosphi0 + X_at_phi0 * normal_x + Y_at_phi0 * binormal_x
546+ total_y = R0_at_phi0 * sinphi0 + X_at_phi0 * normal_y + Y_at_phi0 * binormal_y
547+ total_z = z0_at_phi0 + X_at_phi0 * normal_z + Y_at_phi0 * binormal_z
548+ total_R = jnp .sqrt (total_x * total_x + total_y * total_y )
549+ total_phi = jnp .arctan2 (total_y , total_x )
550+ return total_R , total_z , total_phi
551+
552+ @partial (jit , static_argnames = ['ntheta' ])
553+ def Frenet_to_cylindrical (self , r , ntheta = 20 ):
554+ nphi_conversion = self .nphi
555+ theta = jnp .linspace (0 , 2 * jnp .pi , ntheta , endpoint = False )
556+ phi_conversion = jnp .linspace (0 , 2 * jnp .pi / self .nfp , nphi_conversion , endpoint = False )
557+
558+ def compute_for_theta (theta_j ):
559+ costheta = jnp .cos (theta_j )
560+ sintheta = jnp .sin (theta_j )
561+ X_at_this_theta = r * (self .X1c_untwisted * costheta + self .X1s_untwisted * sintheta )
562+ Y_at_this_theta = r * (self .Y1c_untwisted * costheta + self .Y1s_untwisted * sintheta )
563+
564+ def compute_for_phi (phi_target ):
565+ residual = partial (self .Frenet_to_cylindrical_residual_func , phi_target = phi_target ,
566+ X_at_this_theta = X_at_this_theta , Y_at_this_theta = Y_at_this_theta )
567+ phi0_solution = lax .custom_root (residual , phi_target , newton , lambda g , y : y / g (1.0 ))
568+ final_R , final_Z , _ = self .Frenet_to_cylindrical_1_point (phi0_solution , X_at_this_theta , Y_at_this_theta )
569+ return final_R , final_Z , phi0_solution
570+
571+ return vmap (compute_for_phi )(phi_conversion )
572+
573+ R_2D , Z_2D , phi0_2D = vmap (compute_for_theta )(theta )
574+ return R_2D , Z_2D , phi0_2D
575+
576+
577+ @partial (jit , static_argnames = ['mpol' , 'ntor' ])
578+ def to_Fourier (self , R_2D , Z_2D , nfp , mpol , ntor ):
579+ ntheta , nphi_conversion = R_2D .shape
580+ theta = jnp .linspace (0 , 2 * jnp .pi , ntheta , endpoint = False )
581+ phi_conversion = jnp .linspace (0 , 2 * jnp .pi / nfp , nphi_conversion , endpoint = False )
444582
445- def plot (self , ax = None , show = True , close = False , axis_equal = True , ** kwargs ):
446- if close : raise NotImplementedError ("close=True is not implemented, need to have closed surfaces" )
583+ phi2d , theta2d = jnp .meshgrid (phi_conversion , theta , indexing = 'xy' )
584+ factor = 2 / (ntheta * nphi_conversion )
585+
586+ def compute_RBC_ZBS (m , n ):
587+ angle = m * theta2d - n * nfp * phi2d
588+ sinangle , cosangle = jnp .sin (angle ), jnp .cos (angle )
447589
590+ factor2 = jax .lax .cond (
591+ (ntheta % 2 == 0 ) & (m == (ntheta / 2 )),
592+ lambda _ : factor / 2 , lambda _ : factor ,
593+ operand = None )
594+
595+ factor2 = jax .lax .cond (
596+ (nphi_conversion % 2 == 0 ) & (abs (n ) == (nphi_conversion / 2 )),
597+ lambda _ : factor2 / 2 , lambda _ : factor2 ,
598+ operand = None )
599+
600+ return jnp .sum (R_2D * cosangle * factor2 ), jnp .sum (Z_2D * sinangle * factor2 )
601+
602+ m_vals = jnp .arange (mpol + 1 )
603+ n_vals = jnp .concatenate ([jnp .array ([1 ]), jnp .arange (- ntor , ntor + 1 )]) if mpol == 0 else jnp .arange (- ntor , ntor + 1 )
604+ RBC , ZBS = vmap (lambda n : vmap (lambda m : compute_RBC_ZBS (m , n ))(m_vals ))(n_vals )
605+
606+ RBC = RBC .at [ntor , 0 ].set (jnp .sum (R_2D ) / (ntheta * nphi_conversion ))
607+ ZBS = ZBS .at [:ntor , 0 ].set (0 )
608+ RBC = RBC .at [:ntor , 0 ].set (0 )
609+ return RBC , ZBS
610+
611+
612+ @partial (jit , static_argnames = ['ntheta_fourier' , 'mpol' , 'ntor' , 'ntheta' , 'nphi' ])
613+ def get_boundary (self , r = 0.1 , ntheta = 30 , nphi = 130 , ntheta_fourier = 20 , mpol = 5 , ntor = 5 ):
614+ R_2D , Z_2D , _ = self .Frenet_to_cylindrical (r , ntheta = ntheta_fourier )
615+ RBC , ZBS = self .to_Fourier (R_2D , Z_2D , self .nfp , mpol = mpol , ntor = ntor )
616+
617+ theta1D = jnp .linspace (0 , 2 * jnp .pi , ntheta )
618+ phi1D = jnp .linspace (0 , 2 * jnp .pi , nphi )
619+ phi2D , theta2D = jnp .meshgrid (phi1D , theta1D , indexing = 'ij' )
620+
621+ def compute_RZ (m , n ):
622+ angle = m * theta2D - n * self .nfp * phi2D
623+ return RBC [n + ntor , m ] * jnp .cos (angle ), ZBS [n + ntor , m ] * jnp .sin (angle )
624+
625+ m_vals = jnp .arange (mpol + 1 )
626+ n_vals = jnp .arange (- ntor , ntor + 1 )
627+
628+ R_2Dnew , Z_2Dnew = vmap (lambda m : vmap (lambda n : compute_RZ (m , n ))(n_vals ))(m_vals )
629+ R_2Dnew , Z_2Dnew = R_2Dnew .sum (axis = (0 , 1 )), Z_2Dnew .sum (axis = (0 , 1 ))
630+
631+ x_2D_plot = R_2Dnew .T * jnp .cos (phi1D )
632+ y_2D_plot = R_2Dnew .T * jnp .sin (phi1D )
633+ z_2D_plot = Z_2Dnew .T
634+ return x_2D_plot , y_2D_plot , z_2D_plot , R_2Dnew .T
635+
636+ @partial (jit , static_argnames = ['self' ])
637+ def B_mag (self , r , theta , phi ):
638+ return self .B0 * (1 + r * self .etabar * jnp .cos (theta - (self .iota - self .iotaN ) * phi ))
639+
640+ def plot (self , r = 0.1 , ntheta = 80 , nphi = 150 , ntheta_fourier = 20 , ax = None , show = True , close = False , axis_equal = True , ** kwargs ):
641+ kwargs .setdefault ('alpha' , 1 )
448642 import matplotlib .pyplot as plt
643+ from matplotlib import cm
644+ import matplotlib .colors as clr
645+ from matplotlib .colors import LightSource
449646 if ax is None or ax .name != "3d" :
450647 fig = plt .figure ()
451- ax = fig .add_subplot (projection = '3d' )
452-
453- x_plot = self .R0 * jnp .cos (self .phi )
454- y_plot = self .R0 * jnp .sin (self .phi )
455- z_plot = self .Z0
456-
457- plt .plot (x_plot , y_plot , z_plot )
648+ ax = fig .add_subplot (projection = '3d' )
649+ x_2D_plot , y_2D_plot , z_2D_plot , _ = self .get_boundary (r = r , ntheta = ntheta , nphi = nphi , ntheta_fourier = ntheta_fourier )
650+ theta1D = jnp .linspace (0 , 2 * jnp .pi , ntheta )
651+ phi1D = jnp .linspace (0 , 2 * jnp .pi , nphi )
652+ phi2D , theta2D = jnp .meshgrid (phi1D , theta1D )
653+ import numpy as np
654+ Bmag = np .array (self .B_mag (r , theta2D , phi2D ))
655+ norm = clr .Normalize (vmin = Bmag .min (), vmax = Bmag .max ())
656+ cmap = cm .viridis
657+ ls = LightSource (azdeg = 0 , altdeg = 10 )
658+ cmap_plot = ls .shade (Bmag , cmap , norm = norm )
659+ ax .plot_surface (x_2D_plot , y_2D_plot , z_2D_plot , facecolors = cmap_plot ,
660+ rstride = 1 , cstride = 1 , antialiased = False ,
661+ linewidth = 0 , shade = False , ** kwargs )
662+ # ax.dist = 7
663+ # ax.elev = 5
664+ # ax.azim = 45
665+ # cbar_ax = ax.figure.add_axes([0.85, 0.2, 0.03, 0.6])
666+ # m = cm.ScalarMappable(cmap=cmap, norm=norm)
667+ # m.set_array([])
668+ # cbar = plt.colorbar(m, cax=cbar_ax)
669+ # cbar.ax.set_title(r'$|B| [T]$')
458670 ax .grid (False )
459- # ax.set_xlabel('X (meters)', fontsize=10)
460- # ax.set_ylabel('Y (meters)', fontsize=10)
461- # ax.set_zlabel('Z (meters)', fontsize=10)
462-
463671 if axis_equal :
464672 fix_matplotlib_3d (ax )
465673 if show :
0 commit comments