Skip to content

Commit 885e592

Browse files
rogeriojorgeEstevaoMGomes
authored andcommitted
Refactor optimization functions to use a unified loss function approach across examples
1 parent 4f27d6e commit 885e592

7 files changed

+66
-74
lines changed

essos/objective_functions.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,10 @@
44
from jax import jit, vmap
55
from functools import partial
66
from essos.dynamics import Tracing
7-
from essos.fields import BiotSavart, near_axis
7+
from essos.fields import BiotSavart
88
from essos.surfaces import BdotN_over_B, BdotN
99
from essos.coils import Curves, Coils
10-
11-
def new_nearaxis_from_x_and_old_nearaxis(new_field_nearaxis_x, field_nearaxis):
12-
len_rc = len(field_nearaxis.rc)
13-
len_zs = len(field_nearaxis.zs)
14-
# # keeping the first rc and zs the same
15-
# new_field_nearaxis_rc = jnp.concatenate((jnp.array([field_nearaxis.rc[0]]),new_field_nearaxis_x[:len_rc][1:]))
16-
# new_field_nearaxis_zs = jnp.concatenate((jnp.array([field_nearaxis.zs[0]]),new_field_nearaxis_x[len_rc:len_rc+len_zs][1:]))
17-
new_field_nearaxis_rc = new_field_nearaxis_x[:len_rc]
18-
new_field_nearaxis_zs = new_field_nearaxis_x[len_rc:len_rc+len_zs]
19-
new_field_nearaxis_etabar = new_field_nearaxis_x[-1]
20-
21-
new_field_nearaxis = near_axis(rc=new_field_nearaxis_rc, zs=new_field_nearaxis_zs, etabar=new_field_nearaxis_etabar,
22-
B0=field_nearaxis.B0, sigma0=field_nearaxis.sigma0, I2=field_nearaxis.I2,
23-
nphi=field_nearaxis.nphi, spsi=field_nearaxis.spsi, sG=field_nearaxis.sG, nfp=field_nearaxis.nfp)
24-
return new_field_nearaxis
10+
from essos.optimization import new_nearaxis_from_x_and_old_nearaxis
2511

2612
@partial(jit, static_argnums=(1, 4, 5, 6, 7, 8))
2713
def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, max_coil_length=42,

essos/optimization.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,22 @@
55
from functools import partial
66
from essos.coils import Curves, Coils
77
from scipy.optimize import least_squares, minimize
8-
from essos.objective_functions import (loss_optimize_coils_for_particle_confinement,
9-
loss_BdotN, loss_coils_for_nearaxis, loss_coils_and_nearaxis,
10-
new_nearaxis_from_x_and_old_nearaxis)
8+
from essos.fields import near_axis
9+
10+
def new_nearaxis_from_x_and_old_nearaxis(new_field_nearaxis_x, field_nearaxis):
11+
len_rc = len(field_nearaxis.rc)
12+
len_zs = len(field_nearaxis.zs)
13+
# # keeping the first rc and zs the same
14+
# new_field_nearaxis_rc = jnp.concatenate((jnp.array([field_nearaxis.rc[0]]),new_field_nearaxis_x[:len_rc][1:]))
15+
# new_field_nearaxis_zs = jnp.concatenate((jnp.array([field_nearaxis.zs[0]]),new_field_nearaxis_x[len_rc:len_rc+len_zs][1:]))
16+
new_field_nearaxis_rc = new_field_nearaxis_x[:len_rc]
17+
new_field_nearaxis_zs = new_field_nearaxis_x[len_rc:len_rc+len_zs]
18+
new_field_nearaxis_etabar = new_field_nearaxis_x[-1]
19+
20+
new_field_nearaxis = near_axis(rc=new_field_nearaxis_rc, zs=new_field_nearaxis_zs, etabar=new_field_nearaxis_etabar,
21+
B0=field_nearaxis.B0, sigma0=field_nearaxis.sigma0, I2=field_nearaxis.I2,
22+
nphi=field_nearaxis.nphi, spsi=field_nearaxis.spsi, sG=field_nearaxis.sG, nfp=field_nearaxis.nfp)
23+
return new_field_nearaxis
1124

1225
def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e-4, maximum_function_evaluations=30, **kwargs):
1326
len_dofs_curves = len(jnp.ravel(coils.dofs_curves))
@@ -48,34 +61,3 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e-
4861
except Exception as e:
4962
jax.debug.print("Error: {}", e)
5063
return None
51-
52-
def optimize_coils_for_particle_confinement(coils, particles, target_B_on_axis=5.7, max_coil_length=22, model='GuidingCenter',
53-
maxtime=5e-6, num_steps=500, trace_tolerance=1e-5, tolerance_optimization=1e-4,
54-
maximum_function_evaluations=30, max_coil_curvature=0.1):
55-
return optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils.x, coils=coils,
56-
tolerance_optimization=tolerance_optimization, particles=particles,
57-
maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature,
58-
target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, model=model,
59-
maxtime=maxtime, num_steps=num_steps, trace_tolerance=trace_tolerance)
60-
61-
def optimize_coils_for_vmec_surface(vmec, coils, tolerance_optimization=1e-10,
62-
maximum_function_evaluations=30,
63-
max_coil_length=42, max_coil_curvature=0.1):
64-
return optimize_loss_function(loss_BdotN, initial_dofs=coils.x, coils=coils, tolerance_optimization=tolerance_optimization,
65-
maximum_function_evaluations=maximum_function_evaluations, vmec=vmec,
66-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
67-
68-
def optimize_coils_for_nearaxis(field_nearaxis, coils, tolerance_optimization=1e-10,
69-
maximum_function_evaluations=30,
70-
max_coil_length=42, max_coil_curvature=0.1):
71-
return optimize_loss_function(loss_coils_for_nearaxis, initial_dofs=coils.x, coils=coils, tolerance_optimization=tolerance_optimization,
72-
maximum_function_evaluations=maximum_function_evaluations, field_nearaxis=field_nearaxis,
73-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
74-
75-
def optimize_coils_and_nearaxis(field_nearaxis, coils, tolerance_optimization=1e-10,
76-
maximum_function_evaluations=30,
77-
max_coil_length=42, max_coil_curvature=0.1):
78-
initial_dofs = jnp.concatenate((coils.x, field_nearaxis.x))
79-
return optimize_loss_function(loss_coils_and_nearaxis, initial_dofs=initial_dofs, coils=coils, tolerance_optimization=tolerance_optimization,
80-
maximum_function_evaluations=maximum_function_evaluations, field_nearaxis=field_nearaxis,
81-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)

examples/optimize_coils_and_nearaxis.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from essos.coils import Coils, CreateEquallySpacedCurves
55
from essos.fields import near_axis, BiotSavart
66
from essos.dynamics import Tracing
7-
from essos.optimization import optimize_coils_and_nearaxis, optimize_coils_for_nearaxis
8-
from essos.objective_functions import difference_B_gradB_onaxis
7+
from essos.optimization import optimize_loss_function
8+
from essos.objective_functions import (difference_B_gradB_onaxis,
9+
loss_coils_and_nearaxis, loss_coils_for_nearaxis)
910

1011
# Optimization parameters
1112
max_coil_length = 5.0
@@ -38,17 +39,19 @@
3839
# Optimize coils
3940
print(f'Optimizing coils for initial near=axis with {maximum_function_evaluations} function evaluations.')
4041
time0 = time()
41-
coils_optimized_initial_nearaxis = optimize_coils_for_nearaxis(field_nearaxis_initial, coils_initial, maximum_function_evaluations=maximum_function_evaluations,
42-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,
43-
tolerance_optimization=tolerance_optimization)
42+
initial_dofs = coils_initial.x
43+
coils_optimized_initial_nearaxis = optimize_loss_function(loss_coils_for_nearaxis, initial_dofs=coils_initial.x, coils=coils_initial, tolerance_optimization=tolerance_optimization,
44+
maximum_function_evaluations=maximum_function_evaluations, field_nearaxis=field_nearaxis_initial,
45+
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
4446
print(f"Optimization took {time()-time0:.2f} seconds")
4547

4648
# Optimize coils
4749
print(f'Optimizing coils and near-axis with {maximum_function_evaluations} function evaluations.')
4850
time0 = time()
49-
coils_optimized, field_nearaxis_optimized = optimize_coils_and_nearaxis(field_nearaxis_initial, coils_initial, maximum_function_evaluations=maximum_function_evaluations,
50-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,
51-
tolerance_optimization=tolerance_optimization)
51+
initial_dofs = jnp.concatenate((coils_initial.x, field_nearaxis_initial.x))
52+
coils_optimized, field_nearaxis_optimized = optimize_loss_function(loss_coils_and_nearaxis, initial_dofs=initial_dofs, coils=coils_initial, tolerance_optimization=tolerance_optimization,
53+
maximum_function_evaluations=maximum_function_evaluations, field_nearaxis=field_nearaxis_initial,
54+
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
5255
print(f"Optimization took {time()-time0:.2f} seconds")
5356

5457
B_difference_initial, gradB_difference_initial = difference_B_gradB_onaxis(field_nearaxis_initial, BiotSavart(coils_optimized_initial_nearaxis))

examples/optimize_coils_for_nearaxis.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from essos.coils import Coils, CreateEquallySpacedCurves
55
from essos.fields import near_axis, BiotSavart
66
from essos.dynamics import Tracing
7-
from essos.optimization import optimize_coils_for_nearaxis
7+
from essos.optimization import optimize_loss_function
8+
from essos.objective_functions import loss_coils_for_nearaxis
89

910
# Optimization parameters
1011
max_coil_length = 5.0
@@ -37,9 +38,11 @@
3738
# Optimize coils
3839
print(f'Optimizing coils with {maximum_function_evaluations} function evaluations.')
3940
time0 = time()
40-
coils_optimized = optimize_coils_for_nearaxis(field, coils_initial, maximum_function_evaluations=maximum_function_evaluations,
41-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,
42-
tolerance_optimization=tolerance_optimization)
41+
initial_dofs = coils_initial.x
42+
coils_optimized = optimize_loss_function(loss_coils_for_nearaxis, initial_dofs=coils_initial.x,
43+
coils=coils_initial, tolerance_optimization=tolerance_optimization,
44+
maximum_function_evaluations=maximum_function_evaluations, field_nearaxis=field,
45+
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
4346
print(f"Optimization took {time()-time0:.2f} seconds")
4447

4548
# Trace fieldlines

examples/optimize_coils_particle_confinement_fullorbit.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from essos.dynamics import Particles, Tracing
99
from essos.fields import BiotSavart
1010
from essos.coils import Coils, CreateEquallySpacedCurves
11-
from essos.optimization import optimize_coils_for_particle_confinement
11+
from essos.optimization import optimize_loss_function
12+
from essos.objective_functions import loss_optimize_coils_for_particle_confinement
1213

1314
# Optimization parameters
1415
target_B_on_axis = 5.7
@@ -17,7 +18,7 @@
1718
nparticles = 8
1819
order_Fourier_series_coils = 4
1920
number_coil_points = 80
20-
maximum_function_evaluations = 100
21+
maximum_function_evaluations = 60
2122
maxtime_tracing = 1e-5
2223
number_coils_per_half_field_period = 3
2324
number_of_field_periods = 2
@@ -45,8 +46,11 @@
4546
# Optimize coils
4647
print(f'Optimizing coils with {maximum_function_evaluations} function evaluations and maxtime_tracing={maxtime_tracing}')
4748
time0 = time()
48-
coils_optimized = optimize_coils_for_particle_confinement(coils_initial, particles, target_B_on_axis=target_B_on_axis, maxtime=maxtime_tracing, model=model,
49-
max_coil_length=max_coil_length, maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature)
49+
coils_optimized = optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils_initial.x,
50+
coils=coils_initial, tolerance_optimization=1e-4, particles=particles,
51+
maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature,
52+
target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, model=model,
53+
maxtime=maxtime_tracing, num_steps=timesteps, trace_tolerance=1e-5)
5054
print(f" Optimization took {time()-time0:.2f} seconds")
5155
particles.to_full_orbit(BiotSavart(coils_optimized))
5256
tracing_optimized = Tracing(field=coils_optimized, particles=particles, maxtime=maxtime_tracing, model=model, timesteps=timesteps)

examples/optimize_coils_particle_confinement_guidingcenter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import matplotlib.pyplot as plt
88
from essos.dynamics import Particles, Tracing
99
from essos.coils import Coils, CreateEquallySpacedCurves
10-
from essos.optimization import optimize_coils_for_particle_confinement
10+
from essos.optimization import optimize_loss_function
11+
from essos.objective_functions import loss_optimize_coils_for_particle_confinement
1112

1213
# Optimization parameters
1314
target_B_on_axis = 5.7
@@ -42,8 +43,13 @@
4243
# Optimize coils
4344
print(f'Optimizing coils with {maximum_function_evaluations} function evaluations and maxtime_tracing={maxtime_tracing}')
4445
time0 = time()
45-
coils_optimized = optimize_coils_for_particle_confinement(coils_initial, particles, target_B_on_axis=target_B_on_axis, maxtime=maxtime_tracing, model=model,
46-
max_coil_length=max_coil_length, maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature)
46+
coils_optimized = optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils_initial.x, coils=coils_initial,
47+
tolerance_optimization=1e-4, particles=particles,
48+
maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature,
49+
target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, model=model,
50+
maxtime=maxtime_tracing, num_steps=500, trace_tolerance=1e-5)
51+
# coils_optimized = optimize_coils_for_particle_confinement(coils_initial, particles, target_B_on_axis=target_B_on_axis, maxtime=maxtime_tracing, model=model,
52+
# max_coil_length=max_coil_length, maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature)
4753
print(f" Optimization took {time()-time0:.2f} seconds")
4854
tracing_optimized = Tracing(field=coils_optimized, particles=particles, maxtime=maxtime_tracing, model=model)
4955

examples/optimize_coils_vmec_surface.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
number_of_processors_to_use = 8 # Parallelization, this should divide ntheta*nphi
33
os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}'
44
from time import time
5+
import jax.numpy as jnp
56
import matplotlib.pyplot as plt
7+
from essos.surfaces import BdotN_over_B
68
from essos.coils import Coils, CreateEquallySpacedCurves
7-
from essos.fields import Vmec
8-
from essos.optimization import optimize_coils_for_vmec_surface
9+
from essos.fields import Vmec, BiotSavart
10+
from essos.objective_functions import loss_BdotN
11+
from essos.optimization import optimize_loss_function
912

1013
# Optimization parameters
1114
max_coil_length = 40
@@ -38,11 +41,16 @@
3841
# Optimize coils
3942
print(f'Optimizing coils with {maximum_function_evaluations} function evaluations.')
4043
time0 = time()
41-
coils_optimized = optimize_coils_for_vmec_surface(vmec, coils_initial, maximum_function_evaluations=maximum_function_evaluations,
42-
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,
43-
tolerance_optimization=tolerance_optimization)
44+
coils_optimized = optimize_loss_function(loss_BdotN, initial_dofs=coils_initial.x, coils=coils_initial, tolerance_optimization=tolerance_optimization,
45+
maximum_function_evaluations=maximum_function_evaluations, vmec=vmec,
46+
max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,)
4447
print(f"Optimization took {time()-time0:.2f} seconds")
4548

49+
BdotN_over_B_initial = BdotN_over_B(vmec.surface, BiotSavart(coils_initial))
50+
BdotN_over_B_optimized = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized))
51+
print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}")
52+
print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}")
53+
4654
# Plot coils, before and after optimization
4755
fig = plt.figure(figsize=(8, 4))
4856
ax1 = fig.add_subplot(121, projection='3d')

0 commit comments

Comments
 (0)