Skip to content

Commit 9cd4cf9

Browse files
committed
Now optimizing surfaces, near-axis and coils altogether
1 parent 2cbc36c commit 9cd4cf9

File tree

6 files changed

+196
-66
lines changed

6 files changed

+196
-66
lines changed

essos/fields.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jax.numpy as jnp
44
from functools import partial
55
from jax import jit, jacfwd, grad, vmap, tree_util, lax
6-
from essos.surfaces import SurfaceRZFourier
6+
from essos.surfaces import SurfaceRZFourier, BdotN_over_B
77
from essos.plot import fix_matplotlib_3d
88
from essos.util import newton
99

@@ -620,7 +620,7 @@ def compute_RZ(m, n):
620620
def B_mag(self, r, theta, phi):
621621
return self.B0*(1 + r * self.etabar * jnp.cos(theta - (self.iota - self.iotaN) * phi))
622622

623-
def plot(self, r=0.1, ntheta=80, nphi=150, ntheta_fourier=20, ax=None, show=True, close=False, axis_equal=True, **kwargs):
623+
def plot(self, r=0.1, ntheta=40, nphi=120, ntheta_fourier=20, ax=None, show=True, close=False, axis_equal=True, **kwargs):
624624
kwargs.setdefault('alpha', 1)
625625
import matplotlib.pyplot as plt
626626
from matplotlib import cm
@@ -657,6 +657,29 @@ def plot(self, r=0.1, ntheta=80, nphi=150, ntheta_fourier=20, ax=None, show=True
657657
if show:
658658
plt.show()
659659

660+
def to_vtk(self, filename, r=0.1, ntheta=40, nphi=120, ntheta_fourier=20, extra_data=None, field=None):
661+
try: import numpy as np
662+
except ImportError: raise ImportError("The 'numpy' library is required. Please install it using 'pip install numpy'.")
663+
try: from pyevtk.hl import gridToVTK
664+
except ImportError: raise ImportError("The 'pyevtk' library is required. Please install it using 'pip install pyevtk'.")
665+
x, y, z, _ = self.get_boundary(r=r, ntheta=ntheta, nphi=nphi, ntheta_fourier=ntheta_fourier)
666+
x = np.array(x.T.reshape((1, nphi, ntheta)).copy())
667+
y = np.array(y.T.reshape((1, nphi, ntheta)).copy())
668+
z = np.array(z.T.reshape((1, nphi, ntheta)).copy())
669+
pointData = {}
670+
if field is not None:
671+
boundary = np.array([x, y, z]).transpose(1, 2, 3, 0)[0]
672+
B_BiotSavart = np.array(vmap(lambda surf: vmap(lambda x: field.AbsB(x))(surf))(boundary)).reshape((1, nphi, ntheta)).copy()
673+
pointData["B_BiotSavart"] = B_BiotSavart
674+
theta1D = jnp.linspace(0, 2 * jnp.pi, ntheta)
675+
phi1D = jnp.linspace(0, 2 * jnp.pi, nphi)
676+
phi2D, theta2D = jnp.meshgrid(phi1D, theta1D)
677+
Bmag = np.array(self.B_mag(r, theta2D, phi2D)).T.reshape((1, nphi, ntheta)).copy()
678+
pointData["B_NearAxis"]=Bmag
679+
if extra_data is not None:
680+
pointData = {**pointData, **extra_data}
681+
gridToVTK(str(filename), x, y, z, pointData=pointData)
682+
660683
tree_util.register_pytree_node(near_axis,
661684
near_axis._tree_flatten,
662685
near_axis._tree_unflatten)

essos/objective_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp,
4242

4343
return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss
4444

45-
@partial(jit, static_argnums=(0, 1))
45+
# @partial(jit, static_argnums=(0, 1))
4646
def difference_B_gradB_onaxis(nearaxis_field, coils_field):
4747
Raxis = nearaxis_field.R0
4848
Zaxis = nearaxis_field.Z0
@@ -102,15 +102,15 @@ def loss_particle_drift(field, particles, maxtime=1e-5, num_steps=300, trace_tol
102102
angular_drift=(jnp.sum(jnp.diff(angular_drift,axis=1),axis=1))/num_steps
103103
return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)), jnp.ravel(jnp.abs(vertical_factor))))
104104

105-
@partial(jit, static_argnums=(0))
105+
# @partial(jit, static_argnums=(0))
106106
def loss_coil_length(field):
107107
return jnp.ravel(field.coils.length)
108108

109-
@partial(jit, static_argnums=(0))
109+
# @partial(jit, static_argnums=(0))
110110
def loss_coil_curvature(field):
111111
return jnp.mean(field.coils.curvature, axis=1)
112112

113-
@partial(jit, static_argnums=(0, 1))
113+
# @partial(jit, static_argnums=(0, 1))
114114
def loss_normB_axis(field, npoints=15):
115115
R_axis = jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(field.coils.dofs_curves)))
116116
phi_array = jnp.linspace(0, 2 * jnp.pi, npoints)

essos/optimization.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e-
3434
loss_partial = partial(func, dofs_curves=coils.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym, **kwargs)
3535

3636
## Without JAX gradients, using finite differences
37-
# result = least_squares(loss_partial, x0=initial_dofs, verbose=2, diff_step=1e-2,
37+
# result = least_squares(loss_partial, x0=initial_dofs, verbose=2, diff_step=1e-4,
3838
# ftol=tolerance_optimization, gtol=tolerance_optimization,
3939
# xtol=1e-14, max_nfev=maximum_function_evaluations)
4040

@@ -60,13 +60,23 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e-
6060
new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(result.x[-len(kwargs['field_nearaxis'].x):], kwargs['field_nearaxis'])
6161
return new_coils, new_field_nearaxis
6262
elif 'surface_all' in kwargs and len(initial_dofs) == len(coils.x) + len(kwargs['surface_all'].x):
63-
dofs_currents = result.x[len_dofs_curves:-len(kwargs['surface_all'].x)]
63+
surface_all = kwargs['surface_all']
64+
dofs_currents = result.x[len_dofs_curves:-len(surface_all.x)]
6465
curves = Curves(dofs_curves, n_segments, nfp, stellsym)
6566
new_coils = Coils(curves=curves, currents=dofs_currents * coils.currents_scale)
66-
surface_all = kwargs['surface_all']
6767
new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta)
6868
new_surface.dofs = result.x[-len(surface_all.x):]
6969
return new_coils, new_surface
70+
elif 'surface_all' in kwargs and 'field_nearaxis' in kwargs and len(initial_dofs) == len(coils.x) + len(kwargs['surface_all'].x) + len(kwargs['field_nearaxis'].x):
71+
surface_all = kwargs['surface_all']
72+
field_nearaxis = kwargs['field_nearaxis']
73+
dofs_currents = result.x[len_dofs_curves:-len(surface_all.x)-len(field_nearaxis.x)]
74+
curves = Curves(dofs_curves, n_segments, nfp, stellsym)
75+
new_coils = Coils(curves=curves, currents=dofs_currents * coils.currents_scale)
76+
new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta)
77+
new_surface.dofs = result.x[-len(surface_all.x)-len(field_nearaxis.x):-len(field_nearaxis.x)]
78+
new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(result.x[-len(field_nearaxis.x):], field_nearaxis)
79+
return new_coils, new_surface, new_field_nearaxis
7080
except Exception as e:
7181
jax.debug.print("Error: {}", e)
7282
return None

examples/input_files/input.rotating_ellipse

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
NFP = 0002
55
MPOL = 002
66
NTOR = 002
7-
!----- Boundary Parameters -----
8-
RBC( 000,000) = 1.012658382028E+001 ZBS( 000,000) = 0.000000000000E+000
9-
RBC( 001,000) = 1.656442246637E+000 ZBS( 001,000) = -1.269527884010E+000
10-
RBC(-001,001) = 3.742783485244E-001 ZBS(-001,001) = 3.400206123943E-001
11-
RBC( 000,001) = 1.679774143201E+000 ZBS( 000,001) = 2.181203441748E+000
12-
RBC( 001,001) = -1.102571811326E+000 ZBS( 001,001) = 8.105522673578E-001
13-
RBC(-002,002) = 6.572506820813E-003 ZBS(-002,002) = 4.002721688287E-003
7+
!----- Boundary Parameters (n,m) -----
8+
RBC( 000,000) = 10 ZBS( 000,000) = 0
9+
RBC( 001,000) = 1 ZBS( 001,000) = -1
10+
RBC(-001,001) = 0.1 ZBS(-001,001) = 0.1
11+
RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5
12+
RBC( 001,001) = -1 ZBS( 001,001) = 1
13+
RBC(-002,002) = 1E-4 ZBS(-002,002) = 1E-4
1414
/

0 commit comments

Comments
 (0)