Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e2875a4
Expose radius params and similarity lags in lupin
samuelgarcia Jan 29, 2026
42f3a07
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jan 29, 2026
b711eee
boundaries for num_shifts
samuelgarcia Jan 29, 2026
b750808
better seed for lupin
samuelgarcia Jan 29, 2026
e81f7e9
oups
samuelgarcia Jan 29, 2026
270d113
lupin save recording lazy in analyzer
samuelgarcia Jan 29, 2026
7240a08
propagate to tdc2
samuelgarcia Feb 2, 2026
268b173
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Feb 3, 2026
1874d04
Lupin use now mop-circus
samuelgarcia Feb 3, 2026
9746c9a
Merge branch 'lupin_updates' of github.com:samuelgarcia/spikeinterfac…
yger Feb 4, 2026
948f146
Handling lags
yger Feb 4, 2026
6743f76
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Feb 4, 2026
4139593
Merge branch 'lupin_updates' of github.com:samuelgarcia/spikeinterfac…
yger Feb 4, 2026
555f547
WIP
yger Feb 20, 2026
74dd3d6
WIP
yger Feb 20, 2026
8ba425d
WIP
yger Feb 20, 2026
ebcbedc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2026
9294619
Fixing tests
yger Feb 20, 2026
9145b56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2026
e1fc209
Apply suggestion from @samuelgarcia
samuelgarcia Mar 11, 2026
6225311
Apply suggestion from @samuelgarcia
samuelgarcia Mar 11, 2026
ff63749
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/spikeinterface/sorters/internal/lupin.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
job_kwargs=job_kwargs,
)

if more_outs["time_shifts"] is not None:
time_shifts = more_outs["time_shifts"]
peaks["sample_index"] += time_shifts

mask = clustering_label >= 0
kept_peaks = peaks[mask]
kept_labels = clustering_label[mask]
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids)

auto_merge = True

analyzer_final = None
if auto_merge:
from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
print("Kept %d raw clusters" % len(labels))

if params["merge_from_templates"] is not None:
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates(
peaks,
peak_labels,
templates.unit_ids,
templates.templates_array,
new_sparse_mask,
**params["merge_from_templates"],
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = (
merge_peak_labels_from_templates(
peaks,
peak_labels,
templates.unit_ids,
templates.templates_array,
new_sparse_mask,
**params["merge_from_templates"],
)
)

templates = Templates(
Expand All @@ -185,6 +187,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
probe=recording.get_probe(),
is_in_uV=False,
)
else:
time_shifts = None

# clean very small cluster before peeler
if (
Expand Down Expand Up @@ -212,6 +216,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
more_outs = dict(
svd_model=svd_model,
peaks_svd=peaks_svd,
time_shifts=time_shifts,
peak_svd_sparse_mask=sparse_mask,
)
return labels, peak_labels, more_outs
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,19 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
num_shifts = params_merge_from_templates["num_shifts"]
num_shifts = min((num_shifts, nbefore, nafter))
params_merge_from_templates["num_shifts"] = num_shifts
post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates(
peaks,
post_merge_label1,
unit_ids,
templates_array,
template_sparse_mask,
**params_merge_from_templates,
post_merge_label2, templates_array, template_sparse_mask, unit_ids, time_shifts = (
merge_peak_labels_from_templates(
peaks,
post_merge_label1,
unit_ids,
templates_array,
template_sparse_mask,
**params_merge_from_templates,
)
)
else:
post_merge_label2 = post_merge_label1.copy()
time_shifts = None

dense_templates = Templates(
templates_array=templates_array,
Expand Down Expand Up @@ -337,7 +340,5 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):

labels_set = templates.unit_ids

more_outs = dict(
templates=templates,
)
more_outs = dict(templates=templates, time_shifts=time_shifts)
return labels_set, final_peak_labels, more_outs
22 changes: 15 additions & 7 deletions src/spikeinterface/sortingcomponents/clustering/merging_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,13 +558,13 @@ def merge_peak_labels_from_templates(
if not use_lags:
lags = None

clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = (
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts = (
_apply_pair_mask_on_labels_and_recompute_templates(
pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask, lags
)
)

return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts


def _apply_pair_mask_on_labels_and_recompute_templates(
Expand All @@ -582,7 +582,10 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
clean_labels = peak_labels.copy()
n_components, group_labels = connected_components(pair_mask, directed=False, return_labels=True)

# print("merges", templates_array.shape[0], "to", n_components)
if lags is not None:
time_shifts = np.zeros(len(peak_labels), dtype=np.int32)
else:
time_shifts = None

merge_template_array = templates_array.copy()
merge_sparsity_mask = template_sparse_mask.copy()
Expand All @@ -605,10 +608,15 @@ def _apply_pair_mask_on_labels_and_recompute_templates(

for i, l in enumerate(merge_group):
label = unit_ids[l]
weights[i] = np.sum(peak_labels == label)
mask = peak_labels == label
weights[i] = np.sum(mask)
if i > 0:
clean_labels[peak_labels == label] = unit_ids[g0]
clean_labels[mask] = unit_ids[g0]
keep_template[l] = False
if lags is not None:
shift = lags[l, g0] # which is the same as -lags[g0, l]
time_shifts[mask] += shift

weights /= weights.sum()

if lags is None:
Expand All @@ -619,7 +627,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
# with shifts
accumulated_template = np.zeros_like(merge_template_array[g0, :, :])
for i, l in enumerate(merge_group):
shift = -lags[g0, l]
shift = lags[l, g0] # which is the same as -lags[g0, l]
if shift > 0:
# template is shifted to right
temp = np.zeros_like(accumulated_template)
Expand All @@ -639,4 +647,4 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
merge_template_array = merge_template_array[keep_template, :, :]
merge_sparsity_mask = merge_sparsity_mask[keep_template, :]

return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
print("Kept %d raw clusters" % len(labels))

if params["merge_from_templates"] is not None:
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates(
peaks,
peak_labels,
unit_ids,
templates_array,
np.ones((len(unit_ids), num_chans), dtype=bool),
**params["merge_from_templates"],
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = (
merge_peak_labels_from_templates(
peaks,
peak_labels,
unit_ids,
templates_array,
np.ones((len(unit_ids), num_chans), dtype=bool),
**params["merge_from_templates"],
)
)

templates = Templates(
Expand All @@ -153,6 +155,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
probe=recording.get_probe(),
is_in_uV=False,
)
else:
time_shifts = None

labels = templates.unit_ids

Expand All @@ -162,4 +166,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
if verbose:
print("Kept %d non-duplicated clusters" % len(labels))

return labels, peak_labels, dict()
return labels, peak_labels, dict(time_shifts=time_shifts, templates=templates)
Loading