Skip to content

Conversation

@EstevaoMGomes
Copy link
Collaborator

Refactor of coils & surfaces to be proper PyTrees;
Added a loss wrapper to differentiate with respect to the dogs (PyTree leaves);
Added analysis & validation of the code

- Implemented `fo_integrators.py` for full orbit tracing with various methods and parameters.
- Implemented `gc_integrators.py` for guiding center dynamics with adaptative and constant step sizes.
- Enhanced `Tracing` class in `dynamics.py` to support multiple methods and step sizes.
…ance plots, and improve layout for better visualization
@EstevaoMGomes
Copy link
Collaborator Author

Tests need fixing;
All examples should be fixed;
Future (minor) changes may affect some examples, such as turning gamma from a property to a function;
Coils should have the normalized dofs as leaves and should not update the normalization at runtime.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements a major refactoring to make coils and surfaces proper JAX PyTrees, enabling automatic differentiation. It introduces a new loss wrapper system for gradient-based optimization and adds comprehensive analysis and validation code comparing ESSOS with SIMSOPT.

Key changes:

  • Refactored Coils, Curves, SurfaceRZFourier, and BiotSavart classes as JAX PyTrees with proper tree flattening/unflattening
  • Added essos/losses.py with custom_loss and composite_loss classes for differentiable loss functions
  • Updated API: Coils_from_json()Coils.from_json(), tracing.energy property → tracing.energy() method
  • Added extensive analysis scripts for validation against SIMSOPT

Reviewed changes

Copilot reviewed 32 out of 32 changed files in this pull request and generated 39 comments.

Show a summary per file
File Description
essos/losses.py New module implementing base_loss, custom_loss, and composite_loss classes for automatic differentiation
essos/surfaces.py Refactored SurfaceRZFourier as PyTree with cached properties and improved initialization
essos/coils.py Refactored Curves and Coils as PyTrees with cached properties, changed to classmethod constructors
essos/fields.py Added MagneticField base class and registered BiotSavart as PyTree
essos/dynamics.py Changed energy from cached property to method, added Particles.join() method
essos/objective_functions.py Removed deprecated loss functions, added new coil separation and curvature losses
essos/optimization.py Updated surface instantiation to include mpol/ntor parameters
examples/optimize_coils_vmec_surface.py Major rewrite using new loss wrapper API instead of old optimization functions
examples/trace_particles_coils_guidingcenter.py Updated imports and API calls (from_json, energy method)
examples/trace_fieldlines_coils.py Updated to use Coils.from_json()
examples/optimize_coils_particle_confinement_fullorbit.py Minor formatting and parameter updates
examples/optimize_coils_and_surface.py Added mpol/ntor parameters, simplified loss calculations
examples/input_files/*. Updated VMEC input file coefficients
examples/comparisons_SIMSOPT/*.py Deleted old comparison scripts
examples/compare_guidingcenter_fullorbit.py Updated particle initialization and energy calculation
analysis/*.py New analysis scripts for validation and benchmarking

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +373 to +374
# if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'):
# assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same"
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Suggested change
# if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'):
# assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same"
if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'):
assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same"

Copilot uses AI. Check for mistakes.
Comment on lines +514 to +544
# if nphi is not None:
# self.nphi = nphi
# else:
# nphi = self.nphi

# #rc_new = jnp.zeros((mpol, 2 * ntor + 1))
# #zs_new = jnp.zeros((mpol, 2 * ntor + 1))
# rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor))
# zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor))
# m_keep = min(mpol_old, mpol)
# n_keep = min(ntor_old, ntor)

# xm_old=self.xm
# xn_old=self.xn
# self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:]
# self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:]
# # Copy overlapping region
# for l in range(len(self.xm)):
# if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep:
# index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp
# rc_new=rc_new.at[l].set(self.rc[index])
# zs_new=zs_new.at[l].set(self.zs[index])


# # Update attributes
# self.mpol, self.ntor = mpol, ntor
# self.rc, self.zs = rc_new, zs_new

# self.rmnc_interp = self.rc
# self.zmns_interp = self.zs

Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Suggested change
# if nphi is not None:
# self.nphi = nphi
# else:
# nphi = self.nphi
# #rc_new = jnp.zeros((mpol, 2 * ntor + 1))
# #zs_new = jnp.zeros((mpol, 2 * ntor + 1))
# rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor))
# zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor))
# m_keep = min(mpol_old, mpol)
# n_keep = min(ntor_old, ntor)
# xm_old=self.xm
# xn_old=self.xn
# self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:]
# self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:]
# # Copy overlapping region
# for l in range(len(self.xm)):
# if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep:
# index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp
# rc_new=rc_new.at[l].set(self.rc[index])
# zs_new=zs_new.at[l].set(self.zs[index])
# # Update attributes
# self.mpol, self.ntor = mpol, ntor
# self.rc, self.zs = rc_new, zs_new
# self.rmnc_interp = self.rc
# self.zmns_interp = self.zs

Copilot uses AI. Check for mistakes.
Comment on lines +548 to +569
# self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs))))

# # Recompute angles and geometry
# if self.range_torus == 'full torus': div = 1
# else: div = self.nfp
# if self.range_torus == 'half period': end_val = 0.5
# else: end_val = 1.0
# self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False)
# self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False)
# self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi)

# self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d))
# (self._gamma, self._gammadash_theta, self._gammadash_phi,
# self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp)


# # Recompute AbsB if available
# if hasattr(self, 'bmnc'):
# self._AbsB = self._set_AbsB()

# return self

Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Suggested change
# self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs))))
# # Recompute angles and geometry
# if self.range_torus == 'full torus': div = 1
# else: div = self.nfp
# if self.range_torus == 'half period': end_val = 0.5
# else: end_val = 1.0
# self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False)
# self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False)
# self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi)
# self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d))
# (self._gamma, self._gammadash_theta, self._gammadash_phi,
# self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp)
# # Recompute AbsB if available
# if hasattr(self, 'bmnc'):
# self._AbsB = self._set_AbsB()
# return self

Copilot uses AI. Check for mistakes.
@property
def dependencies_buffer(self):
if self._dependencies_buffer is None:
self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies)
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change
self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies)
self._dependencies_buffer = tree_util.tree_map(jnp.zeros_like, self.dependencies)

Copilot uses AI. Check for mistakes.
json_file_stel = curves_stel
field_simsopt = load(json_file_stel)
coils_simsopt = field_simsopt.coils
curves_simsopt = [coil.curve for coil in coils_simsopt]
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assignment to 'curves_simsopt' is unnecessary as it is redefined before this value is used.

Copilot uses AI. Check for mistakes.
Comment on lines +85 to +94
compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5',
stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles)
block_until_ready(compile_tracing.trajectories)

for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array):
num_steps_essos = avg_steps_SIMSOPT_array[index]
print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}')
start_time = time()
tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5',
stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles)
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
Keyword argument 'method' is not a supported parameter name of Tracing.init.
Keyword argument 'stepsize' is not a supported parameter name of Tracing.init.

Suggested change
compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5',
stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles)
block_until_ready(compile_tracing.trajectories)
for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array):
num_steps_essos = avg_steps_SIMSOPT_array[index]
print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}')
start_time = time()
tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5',
stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles)
compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles)
block_until_ready(compile_tracing.trajectories)
for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array):
num_steps_essos = avg_steps_SIMSOPT_array[index]
print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}')
start_time = time()
tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles)

Copilot uses AI. Check for mistakes.
Comment on lines +61 to +62
tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo,
timesteps=timesteps_fo, tol_step_size=trace_tolerance)
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.

Copilot uses AI. Check for mistakes.
Comment on lines +70 to +71
tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, maxtime=tmax_gc,
timesteps=timesteps_gc, tol_step_size=trace_tolerance)
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.

Copilot uses AI. Check for mistakes.
nfp=number_of_field_periods, stellsym=True)
coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period)
field_essos = BiotSavart(coils_essos)
surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False)
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call to SurfaceRZFourier.init with too few arguments; should be no fewer than 5.

Suggested change
surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False)
surface_essos = SurfaceRZFourier_ESSOS(vmec, order_Fourier_series_coils, ntheta=ntheta, nphi=nphi, close=False)

Copilot uses AI. Check for mistakes.
EXPORT = False
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This statement is unreachable.

Suggested change
EXPORT = False
EXPORT = True

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants