Skip to content

Commit b77a6e0

Browse files
committed
fix conflicts
2 parents 9f270d1 + 02c940f commit b77a6e0

File tree

64 files changed

+244
-153
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+244
-153
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ repos:
66
- id: end-of-file-fixer
77
- id: trailing-whitespace
88
- repo: https://github.com/psf/black-pre-commit-mirror
9-
rev: 25.12.0
9+
rev: 26.1.0
1010
hooks:
1111
- id: black
1212
files: ^src/

src/spikeinterface/benchmark/benchmark_sorter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from spikeinterface.sorters import run_sorter
99
from spikeinterface.comparison import compare_sorter_to_ground_truth
1010

11-
1211
# TODO later integrate CollisionGTComparison optionally in this class.
1312

1413

src/spikeinterface/benchmark/benchmark_sorter_without_gt.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from spikeinterface.benchmark import analyse_residual
1212

13-
1413
# TODO later integrate CollisionGTComparison optionally in this class.
1514

1615

src/spikeinterface/benchmark/tests/test_residual_analysis.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from spikeinterface.benchmark import analyse_residual
99

10-
1110
job_kwargs = dict(n_jobs=-1, progress_bar=True)
1211

1312

src/spikeinterface/comparison/tests/test_multisortingcomparison.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from spikeinterface.extractors import NumpySorting
1010
from spikeinterface.comparison import compare_multiple_sorters, MultiSortingComparison
1111

12-
1312
ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
1413

1514

src/spikeinterface/comparison/tests/test_templatecomparison.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
66
from spikeinterface.comparison import compare_templates, compare_multiple_templates
77

8-
98
# def setup_module():
109
# if test_dir.is_dir():
1110
# shutil.rmtree(test_dir)

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -961,18 +961,28 @@ def get_metric_column_descriptions(cls, metric_names=None):
961961
)
962962
return metric_column_descriptions
963963

964-
def _cast_metrics(self, metrics_df):
965-
metric_dtypes = {}
966-
for m in self.metric_list:
967-
metric_dtypes.update(m.metric_columns)
968-
969-
for col in metrics_df.columns:
970-
if col in metric_dtypes:
971-
try:
972-
metrics_df[col] = metrics_df[col].astype(metric_dtypes[col])
973-
except Exception as e:
974-
print(f"Error casting column {col}: {e}")
975-
return metrics_df
964+
@classmethod
965+
def get_optional_dependencies(cls, **params):
966+
metric_names = params.get("metric_names", None)
967+
if metric_names is None:
968+
metric_names = [m.metric_name for m in cls.metric_list]
969+
else:
970+
for metric_name in metric_names:
971+
if metric_name not in [m.metric_name for m in cls.metric_list]:
972+
raise ValueError(
973+
f"Metric {metric_name} not in available metrics {[m.metric_name for m in cls.metric_list]}"
974+
)
975+
metric_depend_on = set()
976+
for metric_name in metric_names:
977+
metric = [m for m in cls.metric_list if m.metric_name == metric_name][0]
978+
for dep in metric.depend_on:
979+
if "|" in dep:
980+
dep_options = dep.split("|")
981+
metric_depend_on.update(dep_options)
982+
else:
983+
metric_depend_on.add(dep)
984+
depend_on = list(cls.depend_on) + list(metric_depend_on)
985+
return depend_on
976986

977987
def _set_params(
978988
self,
@@ -994,6 +1004,8 @@ def _set_params(
9941004
If None, default parameters for all metrics are used.
9951005
delete_existing_metrics : bool, default: False
9961006
If True, existing metrics in the extension will be deleted before computing new ones.
1007+
metrics_to_compute : list[str] | None
1008+
List of metric names to compute. If None, all metrics in `metric_names` are computed.
9971009
other_params : dict
9981010
Additional parameters for metric computation.
9991011
@@ -1208,15 +1220,18 @@ def _get_data(self):
12081220
# convert to correct dtype
12091221
return self.data["metrics"]
12101222

1211-
def set_data(self, ext_data_name, data):
1212-
import pandas as pd
1223+
def _cast_metrics(self, metrics_df):
1224+
metric_dtypes = {}
1225+
for m in self.metric_list:
1226+
metric_dtypes.update(m.metric_columns)
12131227

1214-
if ext_data_name != "metrics":
1215-
return
1216-
if not isinstance(data, pd.DataFrame):
1217-
return
1218-
metrics = self._cast_metrics(data)
1219-
self.data[ext_data_name] = metrics
1228+
for col in metrics_df.columns:
1229+
if col in metric_dtypes:
1230+
try:
1231+
metrics_df[col] = metrics_df[col].astype(metric_dtypes[col])
1232+
except Exception as e:
1233+
print(f"Error casting column {col}: {e}")
1234+
return metrics_df
12201235

12211236
def _select_extension_data(self, unit_ids: list[int | str]):
12221237
"""
@@ -1331,6 +1346,16 @@ def _split_extension_data(
13311346
new_data = dict(metrics=metrics)
13321347
return new_data
13331348

1349+
def set_data(self, ext_data_name, data):
1350+
import pandas as pd
1351+
1352+
if ext_data_name != "metrics":
1353+
return
1354+
if not isinstance(data, pd.DataFrame):
1355+
return
1356+
metrics = self._cast_metrics(data)
1357+
self.data[ext_data_name] = metrics
1358+
13341359

13351360
class BaseSpikeVectorExtension(AnalyzerExtension):
13361361
"""

src/spikeinterface/core/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
)
2828
from .job_tools import _shared_job_kwargs_doc
2929

30-
3130
# base dtypes used throughout spikeinterface
3231
base_peak_dtype = [
3332
("sample_index", "int64"),

src/spikeinterface/core/baserecordingsnippets.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):
237237
probes_info = []
238238
for probe in probegroup.probes:
239239
probes_info.append(probe.annotations)
240-
self.annotate(probes_info=probes_info)
240+
sub_recording.annotate(probes_info=probes_info)
241241

242242
return sub_recording
243243

@@ -264,6 +264,12 @@ def get_probegroup(self):
264264
probegroup.add_probe(probe)
265265
else:
266266
probegroup = ProbeGroup.from_numpy(arr)
267+
268+
if "probes_info" in self.get_annotation_keys():
269+
probes_info = self.get_annotation("probes_info")
270+
for probe, probe_info in zip(probegroup.probes, probes_info):
271+
probe.annotations = probe_info
272+
267273
for probe_index, probe in enumerate(probegroup.probes):
268274
contour = self.get_annotation(f"probe_{probe_index}_planar_contour")
269275
if contour is not None:

src/spikeinterface/core/job_tools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str
2020

21-
2221
_shared_job_kwargs_doc = """**job_kwargs : keyword arguments for parallel processing:
2322
* chunk_duration or chunk_size or chunk_memory or total_memory
2423
- chunk_size : int

0 commit comments

Comments
 (0)