Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 162 additions & 4 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@


# Public interface
__all__ = ['Jacobi',
__all__ = ['CardinalBasis',
'Jacobi',
'Legendre',
'Ultraspherical',
'Chebyshev',
Expand Down Expand Up @@ -334,10 +335,166 @@ def enum_indices(tensorsig):
# self._ncc_matrices = [self._ncc_matrix_recursion(ncc.data[ind], ncc.domain.full_bases, operand.domain.full_bases, separability, **kw) for ind in np.ndindex(*tshape)]


class IntervalBasis(Basis):
class CardinalBasis(Basis):
"""Cardinal basis."""

dim = 1
group_shape = (1,)
subaxis_dependence = [False]

def __init__(self, coord, size):
self.coord = coord
self.coordsys = coord
self.size = size
self.shape = (size,)
self.dealias = (1,)
super().__init__(coord)

def __add__(self, other):
if other is None or other is self:
return self
return NotImplemented

def __mul__(self, other):
if other is None or other is self:
return self
return NotImplemented

def __rmatmul__(self, other):
# NCC (other) * operand (self)
if other is None or other is self:
return self
return NotImplemented

def elements_to_groups(self, grid_space, elements):
# No permutations
return elements

def valid_elements(self, tensorsig, grid_space, elements):
# No invalid modes
vshape = tuple(cs.dim for cs in tensorsig) + elements[0].shape
return np.ones(shape=vshape, dtype=bool)

def matrix_dependence(self, matrix_coupling):
return matrix_coupling

def global_grids(self, dist, scales):
"""Global grids."""
return (self.global_grid(dist, scales[0]),)

def global_grid(self, dist, scale):
"""Global grid."""
if scale != 1:
raise NotImplementedError("Cardinal basis only supports scale=1.")
return np.arange(self.size)

def local_grids(self, dist, scales):
"""Local grids."""
return (self.local_grid(dist, scales[0]),)

def local_grid(self, dist, scale):
"""Local grid."""
if scale != 1:
raise NotImplementedError("Cardinal basis only supports scale=1.")
local_elements = dist.grid_layout.local_elements(self.domain(dist), scales=scale)
return np.arange(self.size)[local_elements[dist.get_basis_axis(self)]]

def local_modes(self, dist):
"""Local grid."""
local_elements = dist.coeff_layout.local_elements(self.domain(dist), scales=1)
return reshape_vector(local_elements[dist.get_basis_axis(self)], dim=dist.dim, axis=dist.get_basis_axis(self))

def global_shape(self, grid_space, scales):
return self.shape

def chunk_shape(self, grid_space):
return (1,)

def forward_transform(self, field, axis, gdata, cdata):
"""Forward transform field data."""
np.copyto(cdata, gdata)

def backward_transform(self, field, axis, cdata, gdata):
"""Backward transform field data."""
np.copyto(gdata, cdata)


class ConvertConstantCardinal(operators.ConvertConstant, operators.SpectralOperator1D):
"""Convert constant to Cardinal basis."""

output_basis_type = CardinalBasis
subaxis_dependence = [True]
subaxis_coupling = [True]

@staticmethod
def _full_matrix(input_basis, output_basis):
return np.ones(input_basis.size)[None, :]


class InterpolateCardinal(operators.Interpolate, operators.SpectralOperator1D):
"""Interpolate Cardinal basis."""

input_basis_type = CardinalBasis
basis_subaxis = 0
subaxis_dependence = [True]
subaxis_coupling = [True]

def __init__(self, coord, size, position, out=None):
if not isinstance(position, (int, np.integer)):
raise TypeError("Cardinal interpolation position must be an integer")
super().__init__(coord, size, position, out=out)

@staticmethod
def _output_basis(input_basis, position):
return None

@staticmethod
def _full_matrix(input_basis, output_basis, position):
interp_vector = np.zeros(input_basis.size)
interp_vector[position] = 1
return interp_vector[None, :]


class IntegrateCardinal(operators.Integrate, operators.SpectralOperator1D):
"""Cardinal basis integration."""

input_coord_type = Coordinate
input_basis_type = CardinalBasis
subaxis_dependence = [True]
subaxis_coupling = [True]

@staticmethod
def _output_basis(input_basis):
return None

@staticmethod
def _full_matrix(input_basis, output_basis):
integ_vector = np.ones(input_basis.size)
return integ_vector[None, :]


class AverageCardinal(operators.Average, operators.SpectralOperator1D):
"""Cardinal basis averaging."""

input_coord_type = Coordinate
input_basis_type = CardinalBasis
subaxis_dependence = [True]
subaxis_coupling = [True]

@staticmethod
def _output_basis(input_basis):
return None

@staticmethod
def _full_matrix(input_basis, output_basis):
ave_vector = np.ones(input_basis.size) / input_basis.size
return ave_vector[None, :]


class IntervalBasis(Basis):

dim = 1
subaxis_dependence = [False]

def __init__(self, coord, size, bounds, dealias):
self.coord = coord
Expand Down Expand Up @@ -6084,15 +6241,16 @@ def cfl_spacing(self):
velocity = self.operand
coordsys = velocity.tensorsig[0]
spacing = []
for i, c in enumerate(coordsys.coords):
for c in coordsys.coords:
basis = velocity.domain.get_basis(c)
if basis:
dealias = basis.dealias[0]
axis_spacing = basis.local_grid_spacing(self.dist, dealias) * dealias
N = basis.grid_shape((dealias,))[0]
if isinstance(basis, Jacobi) and basis.a == -1/2 and basis.b == -1/2:
#Special case for ChebyshevT (a=b=-1/2)
local_elements = self.dist.grid_layout.local_elements(basis.domain(self.dist), scales=dealias)[i]
axis = self.dist.get_basis_axis(basis)
local_elements = self.dist.grid_layout.local_elements(basis.domain(self.dist), scales=dealias)[axis]
i = np.arange(N)[local_elements].reshape(axis_spacing.shape)
theta = np.pi * (i + 1/2) / N
axis_spacing[:] = dealias * basis.COV.stretch * np.sin(theta) * np.pi / N
Expand Down