-
Notifications
You must be signed in to change notification settings - Fork 3
Major refactoring for JAX-style classes. #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
- Implemented `fo_integrators.py` for full orbit tracing with various methods and parameters. - Implemented `gc_integrators.py` for guiding center dynamics with adaptative and constant step sizes. - Enhanced `Tracing` class in `dynamics.py` to support multiple methods and step sizes.
…d adjust num_steps based on dt
… class for improved step size handling
…for adaptive step size
…ters for performance
…ance plots, and improve layout for better visualization
…n to scale the modes with different norms, optimization.py slightly changed to accomodate changes in surfaces. The example optimize_coils_and_surfaces.py was also changed to accomodate the changes
…d Coils and Curves into correct PyTrees
…arate gamma computation
|
Tests need fixing; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements a major refactoring to make coils and surfaces proper JAX PyTrees, enabling automatic differentiation. It introduces a new loss wrapper system for gradient-based optimization and adds comprehensive analysis and validation code comparing ESSOS with SIMSOPT.
Key changes:
- Refactored
Coils,Curves,SurfaceRZFourier, andBiotSavartclasses as JAX PyTrees with proper tree flattening/unflattening - Added
essos/losses.pywithcustom_lossandcomposite_lossclasses for differentiable loss functions - Updated API:
Coils_from_json()→Coils.from_json(),tracing.energyproperty →tracing.energy()method - Added extensive analysis scripts for validation against SIMSOPT
Reviewed changes
Copilot reviewed 32 out of 32 changed files in this pull request and generated 39 comments.
Show a summary per file
| File | Description |
|---|---|
essos/losses.py |
New module implementing base_loss, custom_loss, and composite_loss classes for automatic differentiation |
essos/surfaces.py |
Refactored SurfaceRZFourier as PyTree with cached properties and improved initialization |
essos/coils.py |
Refactored Curves and Coils as PyTrees with cached properties, changed to classmethod constructors |
essos/fields.py |
Added MagneticField base class and registered BiotSavart as PyTree |
essos/dynamics.py |
Changed energy from cached property to method, added Particles.join() method |
essos/objective_functions.py |
Removed deprecated loss functions, added new coil separation and curvature losses |
essos/optimization.py |
Updated surface instantiation to include mpol/ntor parameters |
examples/optimize_coils_vmec_surface.py |
Major rewrite using new loss wrapper API instead of old optimization functions |
examples/trace_particles_coils_guidingcenter.py |
Updated imports and API calls (from_json, energy method) |
examples/trace_fieldlines_coils.py |
Updated to use Coils.from_json() |
examples/optimize_coils_particle_confinement_fullorbit.py |
Minor formatting and parameter updates |
examples/optimize_coils_and_surface.py |
Added mpol/ntor parameters, simplified loss calculations |
examples/input_files/*. |
Updated VMEC input file coefficients |
examples/comparisons_SIMSOPT/*.py |
Deleted old comparison scripts |
examples/compare_guidingcenter_fullorbit.py |
Updated particle initialization and energy calculation |
analysis/*.py |
New analysis scripts for validation and benchmarking |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): | ||
| # assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment appears to contain commented-out code.
| # if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): | |
| # assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" | |
| if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): | |
| assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" |
| # if nphi is not None: | ||
| # self.nphi = nphi | ||
| # else: | ||
| # nphi = self.nphi | ||
|
|
||
| # #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) | ||
| # #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) | ||
| # rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | ||
| # zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | ||
| # m_keep = min(mpol_old, mpol) | ||
| # n_keep = min(ntor_old, ntor) | ||
|
|
||
| # xm_old=self.xm | ||
| # xn_old=self.xn | ||
| # self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] | ||
| # self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] | ||
| # # Copy overlapping region | ||
| # for l in range(len(self.xm)): | ||
| # if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: | ||
| # index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp | ||
| # rc_new=rc_new.at[l].set(self.rc[index]) | ||
| # zs_new=zs_new.at[l].set(self.zs[index]) | ||
|
|
||
|
|
||
| # # Update attributes | ||
| # self.mpol, self.ntor = mpol, ntor | ||
| # self.rc, self.zs = rc_new, zs_new | ||
|
|
||
| # self.rmnc_interp = self.rc | ||
| # self.zmns_interp = self.zs | ||
|
|
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment appears to contain commented-out code.
| # if nphi is not None: | |
| # self.nphi = nphi | |
| # else: | |
| # nphi = self.nphi | |
| # #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) | |
| # #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) | |
| # rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | |
| # zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | |
| # m_keep = min(mpol_old, mpol) | |
| # n_keep = min(ntor_old, ntor) | |
| # xm_old=self.xm | |
| # xn_old=self.xn | |
| # self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] | |
| # self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] | |
| # # Copy overlapping region | |
| # for l in range(len(self.xm)): | |
| # if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: | |
| # index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp | |
| # rc_new=rc_new.at[l].set(self.rc[index]) | |
| # zs_new=zs_new.at[l].set(self.zs[index]) | |
| # # Update attributes | |
| # self.mpol, self.ntor = mpol, ntor | |
| # self.rc, self.zs = rc_new, zs_new | |
| # self.rmnc_interp = self.rc | |
| # self.zmns_interp = self.zs |
| # self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) | ||
|
|
||
| # # Recompute angles and geometry | ||
| # if self.range_torus == 'full torus': div = 1 | ||
| # else: div = self.nfp | ||
| # if self.range_torus == 'half period': end_val = 0.5 | ||
| # else: end_val = 1.0 | ||
| # self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) | ||
| # self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) | ||
| # self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) | ||
|
|
||
| # self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) | ||
| # (self._gamma, self._gammadash_theta, self._gammadash_phi, | ||
| # self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) | ||
|
|
||
|
|
||
| # # Recompute AbsB if available | ||
| # if hasattr(self, 'bmnc'): | ||
| # self._AbsB = self._set_AbsB() | ||
|
|
||
| # return self | ||
|
|
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment appears to contain commented-out code.
| # self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) | |
| # # Recompute angles and geometry | |
| # if self.range_torus == 'full torus': div = 1 | |
| # else: div = self.nfp | |
| # if self.range_torus == 'half period': end_val = 0.5 | |
| # else: end_val = 1.0 | |
| # self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) | |
| # self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) | |
| # self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) | |
| # self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) | |
| # (self._gamma, self._gammadash_theta, self._gammadash_phi, | |
| # self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) | |
| # # Recompute AbsB if available | |
| # if hasattr(self, 'bmnc'): | |
| # self._AbsB = self._set_AbsB() | |
| # return self |
| @property | ||
| def dependencies_buffer(self): | ||
| if self._dependencies_buffer is None: | ||
| self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
| self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) | |
| self._dependencies_buffer = tree_util.tree_map(jnp.zeros_like, self.dependencies) |
| json_file_stel = curves_stel | ||
| field_simsopt = load(json_file_stel) | ||
| coils_simsopt = field_simsopt.coils | ||
| curves_simsopt = [coil.curve for coil in coils_simsopt] |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assignment to 'curves_simsopt' is unnecessary as it is redefined before this value is used.
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', | ||
| stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) | ||
| block_until_ready(compile_tracing.trajectories) | ||
|
|
||
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | ||
| num_steps_essos = avg_steps_SIMSOPT_array[index] | ||
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | ||
| start_time = time() | ||
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', | ||
| stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
Keyword argument 'method' is not a supported parameter name of Tracing.init.
Keyword argument 'stepsize' is not a supported parameter name of Tracing.init.
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', | |
| stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) | |
| block_until_ready(compile_tracing.trajectories) | |
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | |
| num_steps_essos = avg_steps_SIMSOPT_array[index] | |
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | |
| start_time = time() | |
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', | |
| stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) | |
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles) | |
| block_until_ready(compile_tracing.trajectories) | |
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | |
| num_steps_essos = avg_steps_SIMSOPT_array[index] | |
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | |
| start_time = time() | |
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles) |
| tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo, | ||
| timesteps=timesteps_fo, tol_step_size=trace_tolerance) |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
| tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, maxtime=tmax_gc, | ||
| timesteps=timesteps_gc, tol_step_size=trace_tolerance) |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
| nfp=number_of_field_periods, stellsym=True) | ||
| coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) | ||
| field_essos = BiotSavart(coils_essos) | ||
| surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Call to SurfaceRZFourier.init with too few arguments; should be no fewer than 5.
| surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) | |
| surface_essos = SurfaceRZFourier_ESSOS(vmec, order_Fourier_series_coils, ntheta=ntheta, nphi=nphi, close=False) |
| EXPORT = False |
Copilot
AI
Dec 6, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This statement is unreachable.
| EXPORT = False | |
| EXPORT = True |
Refactor of coils & surfaces to be proper PyTrees;
Added a loss wrapper to differentiate with respect to the dogs (PyTree leaves);
Added analysis & validation of the code