Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/pylibsparseir/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class FiniteTempBasis(AbstractBasis):
"""Finite temperature basis for intermediate representation."""

def __init__(self, statistics: str, beta: float, wmax: float, eps: float, sve_result: Optional[SVEResult] = None):
def __init__(self, statistics: str, beta: float, wmax: float, eps: float, sve_result: Optional[SVEResult] = None, max_size: int =-1):
"""
Initialize finite temperature basis.

Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(self, statistics: str, beta: float, wmax: float, eps: float, sve_re

# Create basis
stats_int = STATISTICS_FERMIONIC if statistics == 'F' else STATISTICS_BOSONIC
self._ptr = basis_new(stats_int, self._beta, self._wmax, self._kernel._ptr, self._sve._ptr)
self._ptr = basis_new(stats_int, self._beta, self._wmax, self._kernel._ptr, self._sve._ptr, max_size)

u_funcs = FunctionSet(basis_get_u(self._ptr))
v_funcs = FunctionSet(basis_get_v(self._ptr))
Expand Down
6 changes: 5 additions & 1 deletion src/pylibsparseir/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,8 @@

# Make sure these are available at module level
SPIR_ORDER_ROW_MAJOR = 0
SPIR_ORDER_COLUMN_MAJOR = 1
SPIR_ORDER_COLUMN_MAJOR = 1

# SVE Twork constants
SPIR_TWORK_FLOAT64 = 0
SPIR_TWORK_FLOAT64X2 = 1
23 changes: 16 additions & 7 deletions src/pylibsparseir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from .ctypes_wrapper import spir_kernel, spir_sve_result, spir_basis, spir_funcs, spir_sampling
from .constants import COMPUTATION_SUCCESS, ORDER_ROW_MAJOR
from .constants import COMPUTATION_SUCCESS, ORDER_ROW_MAJOR, SPIR_TWORK_FLOAT64, SPIR_TWORK_FLOAT64X2

def _find_library():
"""Find the SparseIR shared library."""
Expand Down Expand Up @@ -68,7 +68,7 @@ def _setup_prototypes():
_lib.spir_kernel_domain.restype = c_int

# SVE result functions
_lib.spir_sve_result_new.argtypes = [spir_kernel, c_double, POINTER(c_int)]
_lib.spir_sve_result_new.argtypes = [spir_kernel, c_double, c_double, c_int, c_int, c_int, POINTER(c_int)]
_lib.spir_sve_result_new.restype = spir_sve_result

_lib.spir_sve_result_get_size.argtypes = [spir_sve_result, POINTER(c_int)]
Expand All @@ -79,7 +79,7 @@ def _setup_prototypes():

# Basis functions
_lib.spir_basis_new.argtypes = [
c_int, c_double, c_double, spir_kernel, spir_sve_result, POINTER(c_int)
c_int, c_double, c_double, spir_kernel, spir_sve_result, c_int, POINTER(c_int)
]
_lib.spir_basis_new.restype = spir_basis

Expand Down Expand Up @@ -262,14 +262,23 @@ def reg_bose_kernel_new(lambda_val):
raise RuntimeError(f"Failed to create regularized bosonic kernel: {status.value}")
return kernel

def sve_result_new(kernel, epsilon):
def sve_result_new(kernel, epsilon, cutoff=None, lmax=None, n_gauss=None, Twork=None):
"""Create a new SVE result."""
# Validate epsilon
if epsilon <= 0:
raise RuntimeError(f"Failed to create SVE result: epsilon must be positive, got {epsilon}")

if cutoff is None:
cutoff = -1.0
if lmax is None:
lmax = -1
if n_gauss is None:
n_gauss = -1
if Twork is None:
Twork = SPIR_TWORK_FLOAT64X2

status = c_int()
sve = _lib.spir_sve_result_new(kernel, epsilon, byref(status))
sve = _lib.spir_sve_result_new(kernel, epsilon, cutoff, lmax, n_gauss, Twork, byref(status))
if status.value != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to create SVE result: {status.value}")
return sve
Expand All @@ -291,11 +300,11 @@ def sve_result_get_svals(sve):
raise RuntimeError(f"Failed to get singular values: {status}")
return svals

def basis_new(statistics, beta, omega_max, kernel, sve):
def basis_new(statistics, beta, omega_max, kernel, sve, max_size):
"""Create a new basis."""
status = c_int()
basis = _lib.spir_basis_new(
statistics, beta, omega_max, kernel, sve, byref(status)
statistics, beta, omega_max, kernel, sve, max_size, byref(status)
)
if status.value != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to create basis: {status.value}")
Expand Down
Loading