From 816416d55962ee333262893c5b8a20d1810041bb Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Thu, 8 Jan 2026 13:16:01 +0100
Subject: [PATCH 1/8] WIP

---
 .../postprocessing/correlograms.py            | 341 ++++++++++++++++++
 1 file changed, 341 insertions(+)

diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index ce3d1cd4a9..72aa016635 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -207,7 +207,10 @@ def _get_data(self):
 
 register_result_extension(ComputeCorrelograms)
+#register_result_extension(ComputeAutoCorrelograms)
 compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory()
+#compute_auto_correlograms_sorting_analyzer = ComputeAutoCorrelograms.function_factory()
+
 
 
 def compute_correlograms(
@@ -234,6 +237,7 @@ def compute_correlograms(
 
 compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__
+#compute_auto_correlograms.__doc__ = compute_auto_correlograms_sorting_analyzer.__doc__
 
 
 def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]:
@@ -602,7 +606,344 @@ def _compute_correlograms_one_segment_numba(
             bin = diff // bin_size
 
             correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1
+
+
+    @numba.jit(
+        nopython=True,
+        nogil=True,
+        cache=False,
+    )
+    def _compute_auto_correlograms_one_segment_numba(
+        correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins
+    ):
+        """
+        Compute the auto-correlograms using `numba` for speed.
+
+        The algorithm works by brute-force iteration through all
+        pairs of spikes (skipping those when outside of the window).
+        The spike-time difference and its time bin are computed
+        and stored in a (num_units, num_bins) correlogram.
+        The correlogram must be passed as an
+        argument and is filled in-place.
+
+        Parameters
+        ----------
+        correlograms: np.array
+            A (num_units, num_bins) array of auto-correlograms,
+            one row per unit, at each lag time bin. This is passed
+            as counts for all segments are added to it.
+        spike_times : np.ndarray
+            An array of spike times (in samples, not seconds).
+            This contains spikes from all units.
+        spike_unit_indices : np.ndarray
+            An array of labels indicating the unit of the corresponding
+            spike in `spike_times`.
+        window_size : int
+            The window size over which to perform the auto-correlation, in samples
+        bin_size : int
+            The size of which to bin lags, in samples.
+        """
+        start_j = 0
+        for i in range(spike_times.size):
+            for j in range(start_j, spike_times.size):
+                if i == j:
+                    continue
+
+                if spike_unit_indices[i] != spike_unit_indices[j]:
+                    continue
+
+                diff = spike_times[i] - spike_times[j]
+
+                # When the diff is exactly the window size, keep going
+                # without iterating start_j in case this spike also has
+                # other diffs with other spikes that == window size.
+                if diff == window_size:
+                    continue
+
+                # if the time of spike i is more than window size later than
+                # spike j, then spike i + 1 will also be more than a window size
+                # later than spike j. Iterate the start_j and check the next spike.
+                if diff > window_size:
+                    start_j += 1
+                    continue
+
+                # If the time of spike i is more than a window size earlier
+                # than spike j, then all following j spikes will be even later
+                # i spikes and so all more than a window size earlier. So move
+                # onto the next i.
+                if diff < -window_size:
+                    break
+
+                bin = diff // bin_size
+
+                correlograms[spike_unit_indices[i], num_half_bins + bin] += 1
+
+
+###### ACG area ######
+
+def compute_auto_correlograms(
+    sorting_analyzer_or_sorting,
+    window_ms: float = 50.0,
+    bin_ms: float = 1.0,
+    method: str = "auto",
+):
+    """
+    Compute auto-correlograms using Numba or Numpy.
+    See ComputeAutoCorrelograms() for details.
+    """
+    if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor):
+        sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting
+
+    if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
+        return _compute_auto_correlograms_sorting_analyzer(
+            sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
+        )
+    else:
+        return _compute_auto_correlograms_on_sorting(
+            sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
+        )
+
+def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
+    """
+    Computes auto-correlograms from multiple units.
+
+    Entry function to compute auto-correlograms across all units in a `Sorting`
+    object (i.e. spike trains at all determined offsets will be computed
+    for each unit against itself).
+
+    Parameters
+    ----------
+    sorting : Sorting
+        A SpikeInterface Sorting object
+    window_ms : float
+        The window size over which to perform the auto-correlation, in ms
+    bin_ms : float
+        The size of which to bin lags, in ms.
+    method : str
+        To use "numpy" or "numba". "auto" will use numba if available,
+        otherwise numpy.
+
+    Returns
+    -------
+    correlograms : np.array
+        A (num_units, num_bins) array where each row holds one unit's
+        auto-correlogram across all determined time bins. Note the true
+        correlation is not returned but instead the count of number of matches.
+    bins : np.array
+        The bins edges in ms
+    """
+    assert method in ("auto", "numba", "numpy"), "method must be 'auto', 'numba' or 'numpy'"
+
+    if method == "auto":
+        method = "numba" if HAVE_NUMBA else "numpy"
+
+    bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms)
+
+    if method == "numpy":
+        correlograms = _compute_auto_correlograms_numpy(sorting, window_size, bin_size)
+    if method == "numba":
+        correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size)
+
+    return correlograms, bins
+
+
+# LOW-LEVEL IMPLEMENTATIONS
+def _compute_auto_correlograms_numpy(sorting, window_size, bin_size):
+    """
+    Computes auto-correlograms for all units in a sorting object.
+
+    This very elegant implementation is copied from the phy package, written by Cyrille Rossant.
+    https://github.com/cortex-lab/phylib/blob/master/phylib/stats/ccg.py
+
+    The main modification is the way positive and negative lags are handled
+    explicitly for rounding reasons.
+
+    Other slight modifications have been made to fit the SpikeInterface
+    data model (e.g. adding the ability to handle multiple segments).
+
+    Adaptation: Samuel Garcia
+    """
+    num_seg = sorting.get_num_segments()
+    num_units = len(sorting.unit_ids)
+    spikes = sorting.to_spike_vector(concatenated=False)
+
+    num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
+
+    correlograms = np.zeros((num_units, num_bins), dtype="int64")
+
+    for seg_index in range(num_seg):
+        spike_times = spikes[seg_index]["sample_index"]
+        spike_unit_indices = spikes[seg_index]["unit_index"]
+
+        c0 = auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size)
+
+        correlograms += c0
+
+    return correlograms
+
+
+def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size):
+    """
+    A very well optimized algorithm for the auto-correlation of
+    spike trains, copied from the Phy package, written by Cyrille Rossant.
+
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        An array of spike times (in samples, not seconds).
+        This contains spikes from all units.
+    spike_unit_indices : np.ndarray
+        An array of labels indicating the unit of the corresponding
+        spike in `spike_times`.
+    window_size : int
+        The window size over which to perform the auto-correlation, in samples
+    bin_size : int
+        The size of which to bin lags, in samples.
+
+    Returns
+    -------
+    correlograms : np.array
+        A (num_units, num_bins) array of auto-correlograms,
+        one row per unit, at each lag time bin.
+
+    Notes
+    -----
+    For all spikes, the time difference between this spike and
+    every other spike of the same unit within the window is directly computed
+    and stored as a count in the relevant lag time bin.
+
+    Initially, the spike_times array is shifted by 1 position, and the difference
+    computed. This gives the time differences between the closest spikes
+    (skipping the zero-lag case). Next, the differences between
+    spike times in samples are converted into units relative to
+    bin_size ('binarized'). Spikes in which the binarized difference to
+    their closest neighbouring spike is greater than half the bin-size are
+    masked.
+
+    Finally, the indices of the (num_units, num_bins) correlogram
+    that need incrementing are done so with `ravel_multi_index()`. This repeats
+    for all shifts along the spike_train until no spikes have a corresponding
+    match within the window size.
+    """
+    num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
+    num_units = len(np.unique(spike_unit_indices))
+
+    correlograms = np.zeros((num_units, num_bins), dtype="int64")
+
+    for unit_id in range(num_units):
+        unit_mask = spike_unit_indices == unit_id
+        spike_times_unit = spike_times[unit_mask]
+        spike_unit_indices_unit = spike_unit_indices[unit_mask]
+
+        # At a given shift, the mask precises which spikes have matching spikes
+        # within the correlogram time window.
+        mask = np.ones_like(spike_times, dtype="bool")
+
+        # The loop continues as long as there is at least one
+        # spike with a matching spike.
+        shift = 1
+        while mask[:-shift].any():
+            # Number of time samples between spike i and spike i+shift.
+            spike_diff = spike_times[shift:] - spike_times[:-shift]
+
+            for sign in (-1, 1):
+                # Binarize the delays between spike i and spike i+shift for negative and positive
+                # the operator // is np.floor_divide
+                spike_diff_b = (spike_diff * sign) // bin_size
+
+                # Spikes with no matching spikes are masked.
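+                # (With sign == -1 this masks spikes whose binarized lag falls
+                # below -num_half_bins, i.e. beyond the left window edge; with
+                # sign == +1, spikes at or beyond num_half_bins, past the
+                # half-open right edge.)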
+                if sign == -1:
+                    mask[:-shift][spike_diff_b < -num_half_bins] = False
+                else:
+                    mask[:-shift][spike_diff_b >= num_half_bins] = False
+
+                m = mask[:-shift]
+
+                # Find the indices in the raveled correlograms array that need
+                # to be incremented, taking into account the spike unit labels.
+                if sign == 1:
+                    indices = np.ravel_multi_index(
+                        (spike_unit_indices[+shift:][m], spike_unit_indices[:-shift][m], spike_diff_b[m] + num_half_bins),
+                        correlograms.shape,
+                    )
+                else:
+                    indices = np.ravel_multi_index(
+                        (spike_unit_indices[:-shift][m], spike_unit_indices[+shift:][m], spike_diff_b[m] + num_half_bins),
+                        correlograms.shape,
+                    )
+
+                # Increment the matching spikes in the correlograms array.
+                bbins = np.bincount(indices)
+                correlograms.ravel()[: len(bbins)] += bbins
+
+                if sign == 1:
+                    # For positive sign, the end bin is < num_half_bins (e.g.
+                    # bin = 29, num_half_bins = 30, will go to index 59 (i.e. the
+                    # last bin). For negative sign, the first bin is == num_half_bins
+                    # e.g. bin = -30, with num_half_bins = 30 will go to bin 0. Therefore
+                    # sign == 1 must mask spike_diff_b <= num_half_bins but sign == -1
+                    # must count all (possibly repeating across units) cases of
+                    # spike_diff_b == num_half_bins. So we turn it back on here
+                    # for the next loop that starts with the -1 case.
+                    mask[:-shift][spike_diff_b == num_half_bins] = True
+
+            shift += 1
+
+    return correlograms
+
+
+def _compute_auto_correlograms_numba(sorting, window_size, bin_size):
+    """
+    Computes auto-correlograms for all units in `sorting`.
+
+    This is a "brute force" method using compiled code (numba)
+    to accelerate the computation. See
+    `_compute_auto_correlograms_one_segment_numba()` for details.
+
+    Parameters
+    ----------
+    sorting : Sorting
+        A SpikeInterface Sorting object
+    window_size : int
+        The window size over which to perform the auto-correlation, in samples
+    bin_size : int
+        The size of which to bin lags, in samples.
+
+    Returns
+    -------
+    correlograms: np.array
+        A (num_units, num_bins) array of auto-correlograms,
+        one row per unit, at each lag time bin.
+
+    Implementation: Aurélien Wyngaard
+    """
+    assert HAVE_NUMBA, "numba version of this function requires installation of numba"
+
+    num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
+    num_units = len(sorting.unit_ids)
+
+    spikes = sorting.to_spike_vector(concatenated=False)
+    correlograms = np.zeros((num_units, num_bins), dtype=np.int64)
+
+    for seg_index in range(sorting.get_num_segments()):
+        spike_times = spikes[seg_index]["sample_index"]
+        spike_unit_indices = spikes[seg_index]["unit_index"]
+
+        _compute_auto_correlograms_one_segment_numba(
+            correlograms,
+            spike_times.astype(np.int64, copy=False),
+            spike_unit_indices.astype(np.int32, copy=False),
+            window_size,
+            bin_size,
+            num_half_bins,
+        )
+
+    return correlograms
+
+
+###### 3D ACG area ######
 
 class ComputeACG3D(AnalyzerExtension):
     """

From bfec01fbdada818cca35cc5d6a20816dba55295d Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Thu, 8 Jan 2026 16:42:34 +0100
Subject: [PATCH 2/8] WIP

---
 .../postprocessing/correlograms.py            | 206 ++++++++++++++++--
 1 file changed, 184 insertions(+), 22 deletions(-)

diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 72aa016635..7bdae61355 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -206,10 +206,183 @@ def _get_data(self):
 
         return self.data["ccgs"], self.data["bins"]
 
+
+class ComputeAutoCorrelograms(AnalyzerExtension):
+    """
+    Compute only auto correlograms of unit spike times.
+
+    Parameters
+    ----------
+    window_ms : float, default: 50.0
+        The window around the spike to compute the correlation in ms. For example,
+        if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms.
+    bin_ms : float, default: 1.0
+        The bin size in ms. This determines the bin size over which to
+        combine lags. For example, with a window size of -25 ms to 25 ms, and
+        bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
+    method : "auto" | "numpy" | "numba", default: "auto"
+        If "auto" and numba is installed, numba is used, otherwise numpy is used.
+
+    Returns
+    -------
+    correlogram : np.array
+        Auto-correlograms with shape (num_units, num_bins)
+    bins : np.array
+        The bin edges in ms
+
+    Notes
+    -----
+    In the extracellular electrophysiology context, a correlogram
+    is a visualisation of the results of a cross-correlation
+    between two spike trains. The cross-correlation slides one spike train
+    along another sample-by-sample, taking the correlation at each 'lag'. This results
+    in a plot with 'lag' (i.e. time offset) on the x-axis and 'correlation'
+    (i.e. how similar the two spike trains are) on the y-axis. In this
+    implementation, the y-axis result is the 'counts' of spike matches per
+    time bin (rather than a computed correlation or covariance).
+
+    In the present implementation, a 'window' around spikes is first
+    specified. For example, if a window of 100 ms is taken, we will
+    take the correlation at lags from -50 ms to +50 ms around the spike peak.
+    In theory, we can have as many lags as we have samples. Often, this
+    visualisation is too high resolution and instead the lags are binned
+    (e.g. -50 to -45 ms, ..., -5 to 0 ms, 0 to 5 ms, ...., 45 to 50 ms).
+    When using counts as output, binning the lags involves adding up all counts across
+    a range of lags.
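+
+    For example, with the defaults window_ms=50.0 and bin_ms=1.0, this gives
+    50 one-millisecond bins spanning lags from -25 ms to +25 ms, and each
+    unit's row in the returned (num_units, num_bins) array holds the
+    spike-count per lag bin of that unit's spike train against itself.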
+
+    """
+
+    extension_name = "auto_correlograms"
+    depend_on = []
+    need_recording = False
+    use_nodepipeline = False
+    need_job_kwargs = False
+
+    def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"):
+        params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method)
+
+        return params
+
+    def _select_extension_data(self, unit_ids):
+        # filter the auto-correlograms down to the kept units
+        unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
+        new_acgs = self.data["acgs"][unit_indices]
+        new_bins = self.data["bins"]
+        new_data = dict(acgs=new_acgs, bins=new_bins)
+        return new_data
+
+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs
+    ):
+        """
+        When two units are merged, their cross-correlograms with other units become the sum
+        of the previous cross-correlograms. More precisely, if units i and j get merged into
+        unit k, then the new unit's cross-correlogram with any other unit l is:
+            C_{k,l} = C_{i,l} + C_{j,l}
+            C_{l,k} = C_{l,i} + C_{l,j}
+        Here, we apply this formula to quickly compute correlograms for merged units.
+        """
+
+        can_apply_soft_method = True
+        if censor_ms is not None:
+            # if censor_ms has no effect, can apply "soft" method. Check if any spikes have been removed
+            for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups):
+                num_segments = new_sorting_analyzer.get_num_segments()
+                for segment_index in range(num_segments):
+                    merged_spike_train_length = len(
+                        new_sorting_analyzer.sorting.get_unit_spike_train(new_unit_id, segment_index=segment_index)
+                    )
+
+                    old_spike_train_lengths = len(
+                        np.concatenate(
+                            [
+                                self.sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
+                                for unit_id in merge_unit_group
+                            ]
+                        )
+                    )
+
+                    if merged_spike_train_length != old_spike_train_lengths:
+                        can_apply_soft_method = False
+                        break
+
+        if can_apply_soft_method is False:
+            new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
+            new_data = dict(acgs=new_acgs, bins=new_bins)
+        else:
+            # Make a transformation dict, which tells us how unit_indices from the
+            # old to the new sorter are mapped.
+ old_to_new_unit_index_map = {} + for old_unit in self.sorting_analyzer.unit_ids: + old_unit_index = self.sorting_analyzer.sorting.id_to_index(old_unit) + unit_involved_in_merge = False + for merge_unit_group, new_unit_id in zip(merge_unit_groups, new_unit_ids): + new_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id) + # check if the old_unit is involved in a merge + if old_unit in merge_unit_group: + # check if it is mapped to itself + if old_unit == new_unit_id: + old_to_new_unit_index_map[old_unit_index] = new_unit_index + # or to a unit_id outwith the old ones + elif new_unit_id not in self.sorting_analyzer.unit_ids: + if new_unit_index not in old_to_new_unit_index_map.values(): + old_to_new_unit_index_map[old_unit_index] = new_unit_index + unit_involved_in_merge = True + if unit_involved_in_merge is False: + old_to_new_unit_index_map[old_unit_index] = new_sorting_analyzer.sorting.id_to_index(old_unit) + + correlograms, new_bins = deepcopy(self.get_data()) + + for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): + merge_unit_group_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_group) + + # Sum unit rows of the correlogram matrix: C_{k,l} = C_{i,l} + C_{j,l} + # and place this sum in all indices from the merge group + new_col = np.sum(correlograms[merge_unit_group_indices, :, :], axis=0) + # correlograms[merge_unit_group_indices[0], :, :] = new_col + correlograms[merge_unit_group_indices, :, :] = new_col + # correlograms[merge_unit_group_indices[1:], :, :] = 0 + + # Sum unit columns of the correlogram matrix: C_{l,k} = C_{l,i} + C_{l,j} + # and put this sum in all indices from the merge group + new_row = np.sum(correlograms[:, merge_unit_group_indices, :], axis=1) + + for merge_unit_group_index in merge_unit_group_indices: + correlograms[:, merge_unit_group_index, :] = new_row + + new_correlograms = np.zeros( + (len(new_sorting_analyzer.unit_ids), len(new_sorting_analyzer.unit_ids), correlograms.shape[2]) + ) + for old_index_1, new_index_1 in old_to_new_unit_index_map.items(): + for old_index_2, new_index_2 in old_to_new_unit_index_map.items(): + new_correlograms[new_index_1, new_index_2, :] = correlograms[old_index_1, old_index_2, :] + new_correlograms[new_index_2, new_index_1, :] = correlograms[old_index_2, old_index_1, :] + + new_data = dict(ccgs=new_correlograms, bins=new_bins) + return new_data + + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # TODO: for now we just copy + new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) + new_data = dict(acgs=new_acgs, bins=new_bins) + return new_data + + def _run(self, verbose=False): + acgs, bins = _compute_auto_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) + self.data["acgs"] = acgs + self.data["bins"] = bins + + def _get_data(self): + return self.data["acgs"], self.data["bins"] + + + + register_result_extension(ComputeCorrelograms) -#register_result_extension(ComputeAutoCorrelograms) +register_result_extension(ComputeAutoCorrelograms) compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory() -#compute_auto_correlograms_sorting_analyzer = ComputeAutoCorrelograms.function_factory() +compute_auto_correlograms_sorting_analyzer = ComputeAutoCorrelograms.function_factory() @@ -235,11 +408,6 @@ def compute_correlograms( sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) - 
-compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__ -#compute_auto_correlograms.__doc__ = compute_auto_correlograms_sorting_analyzer.__doc__ - - def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]: """ Create the bins for the correlogram, in samples. @@ -831,22 +999,20 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz correlograms = np.zeros((num_units, num_bins), dtype="int64") - for unit_id in range(num_units): - unit_mask = spike_unit_indices == unit_id + for unit_ind in range(num_units): + unit_mask = spike_unit_indices == unit_ind spike_times_unit = spike_times[unit_mask] - spike_unit_indices_unit = spike_unit_indices[unit_mask] # At a given shift, the mask precises which spikes have matching spikes # within the correlogram time window. - mask = np.ones_like(spike_times, dtype="bool") + mask = np.ones_like(spike_times_unit, dtype="bool") # The loop continues as long as there is at least one # spike with a matching spike. shift = 1 while mask[:-shift].any(): # Number of time samples between spike i and spike i+shift. - spike_diff = spike_times[shift:] - spike_times[:-shift] - + spike_diff = spike_times_unit[shift:] - spike_times_unit[:-shift] for sign in (-1, 1): # Binarize the delays between spike i and spike i+shift for negative and positive # the operator // is np.floor_divide @@ -863,19 +1029,13 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz # Find the indices in the raveled correlograms array that need # to be incremented, taking into account the spike unit labels. if sign == 1: - indices = np.ravel_multi_index( - (spike_unit_indices[+shift:][m], spike_unit_indices[:-shift][m], spike_diff_b[m] + num_half_bins), - correlograms.shape, - ) + indices = spike_diff_b[m] + num_half_bins else: - indices = np.ravel_multi_index( - (spike_unit_indices[:-shift][m], spike_unit_indices[+shift:][m], spike_diff_b[m] + num_half_bins), - correlograms.shape, - ) + indices = spike_diff_b[m] + num_half_bins # Increment the matching spikes in the correlograms array. bbins = np.bincount(indices) - correlograms.ravel()[: len(bbins)] += bbins + correlograms[unit_ind, :len(bbins)] += bbins if sign == 1: # For positive sign, the end bin is < num_half_bins (e.g. @@ -1355,3 +1515,5 @@ def compute_acgs_3d( compute_acgs_3d.__doc__ = compute_acgs_3d_sorting_analyzer.__doc__ +compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__ +compute_auto_correlograms.__doc__ = compute_auto_correlograms_sorting_analyzer.__doc__ From d278a9ed868a51aabdc9d7fbbca2ad30f72c441a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 15:44:53 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/postprocessing/correlograms.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 7bdae61355..cd10935483 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -206,7 +206,6 @@ def _get_data(self): return self.data["ccgs"], self.data["bins"] - class ComputeAutoCorrelograms(AnalyzerExtension): """ Compute only auto correlograms of unit spike times. 
@@ -377,15 +376,12 @@ def _get_data(self): return self.data["acgs"], self.data["bins"] - - register_result_extension(ComputeCorrelograms) register_result_extension(ComputeAutoCorrelograms) compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory() compute_auto_correlograms_sorting_analyzer = ComputeAutoCorrelograms.function_factory() - def compute_correlograms( sorting_analyzer_or_sorting, window_ms: float = 50.0, @@ -408,6 +404,7 @@ def compute_correlograms( sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) + def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]: """ Create the bins for the correlogram, in samples. @@ -774,7 +771,6 @@ def _compute_correlograms_one_segment_numba( bin = diff // bin_size correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1 - @numba.jit( nopython=True, @@ -817,7 +813,7 @@ def _compute_auto_correlograms_one_segment_numba( for j in range(start_j, spike_times.size): if i == j: continue - + if spike_unit_indices[i] != spike_unit_indices[j]: continue @@ -850,6 +846,7 @@ def _compute_auto_correlograms_one_segment_numba( ###### ACG area ###### + def compute_auto_correlograms( sorting_analyzer_or_sorting, window_ms: float = 50.0, @@ -872,6 +869,7 @@ def compute_auto_correlograms( sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) + def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): """ Computes auto-correlograms from multiple units. @@ -1035,7 +1033,7 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz # Increment the matching spikes in the correlograms array. bbins = np.bincount(indices) - correlograms[unit_ind, :len(bbins)] += bbins + correlograms[unit_ind, : len(bbins)] += bbins if sign == 1: # For positive sign, the end bin is < num_half_bins (e.g. @@ -1102,9 +1100,9 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size): return correlograms - ###### 3D ACG area ###### + class ComputeACG3D(AnalyzerExtension): """ Computes the 3D Autocorrelograms (3D-ACG) from units spike times to analyze how a neuron's temporal firing From 06463dbbdb767da2ccf23ffac29c710685d1a75c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 8 Jan 2026 16:52:35 +0100 Subject: [PATCH 4/8] With appropriate merging strategy --- .../postprocessing/correlograms.py | 66 +++++-------------- 1 file changed, 18 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 7bdae61355..629f1b452d 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -311,55 +311,25 @@ def _merge_extension_data( new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) new_data = dict(acgs=new_acgs, bins=new_bins) else: - # Make a transformation dict, which tells us how unit_indices from the - # old to the new sorter are mapped. 
- old_to_new_unit_index_map = {} - for old_unit in self.sorting_analyzer.unit_ids: - old_unit_index = self.sorting_analyzer.sorting.id_to_index(old_unit) - unit_involved_in_merge = False - for merge_unit_group, new_unit_id in zip(merge_unit_groups, new_unit_ids): - new_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id) - # check if the old_unit is involved in a merge - if old_unit in merge_unit_group: - # check if it is mapped to itself - if old_unit == new_unit_id: - old_to_new_unit_index_map[old_unit_index] = new_unit_index - # or to a unit_id outwith the old ones - elif new_unit_id not in self.sorting_analyzer.unit_ids: - if new_unit_index not in old_to_new_unit_index_map.values(): - old_to_new_unit_index_map[old_unit_index] = new_unit_index - unit_involved_in_merge = True - if unit_involved_in_merge is False: - old_to_new_unit_index_map[old_unit_index] = new_sorting_analyzer.sorting.id_to_index(old_unit) - - correlograms, new_bins = deepcopy(self.get_data()) - - for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): - merge_unit_group_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_group) - - # Sum unit rows of the correlogram matrix: C_{k,l} = C_{i,l} + C_{j,l} - # and place this sum in all indices from the merge group - new_col = np.sum(correlograms[merge_unit_group_indices, :, :], axis=0) - # correlograms[merge_unit_group_indices[0], :, :] = new_col - correlograms[merge_unit_group_indices, :, :] = new_col - # correlograms[merge_unit_group_indices[1:], :, :] = 0 - - # Sum unit columns of the correlogram matrix: C_{l,k} = C_{l,i} + C_{l,j} - # and put this sum in all indices from the merge group - new_row = np.sum(correlograms[:, merge_unit_group_indices, :], axis=1) - - for merge_unit_group_index in merge_unit_group_indices: - correlograms[:, merge_unit_group_index, :] = new_row - - new_correlograms = np.zeros( - (len(new_sorting_analyzer.unit_ids), len(new_sorting_analyzer.unit_ids), correlograms.shape[2]) - ) - for old_index_1, new_index_1 in old_to_new_unit_index_map.items(): - for old_index_2, new_index_2 in old_to_new_unit_index_map.items(): - new_correlograms[new_index_1, new_index_2, :] = correlograms[old_index_1, old_index_2, :] - new_correlograms[new_index_2, new_index_1, :] = correlograms[old_index_2, old_index_1, :] + new_bins = self.data["bins"] + all_new_units = new_sorting_analyzer.unit_ids + num_dims = len(self.data["bins"]) + arr = self.data["acgs"] + new_acgs = np.zeros((len(all_new_units), num_dims), dtype=np.int64) + + # compute all new isi at once + new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids) + only_new_acgs, _ = _compute_auto_correlograms_on_sorting(new_sorting, **self.params) + + for unit_ind, unit_id in enumerate(all_new_units): + if unit_id not in new_unit_ids: + keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + new_acgs[unit_ind, :] = arr[keep_unit_index, :] + else: + new_unit_index = new_sorting.id_to_index(unit_id) + new_acgs[unit_ind, :] = only_new_acgs[new_unit_index, :] - new_data = dict(ccgs=new_correlograms, bins=new_bins) + new_data = dict(acgs=new_acgs, bins=new_bins) return new_data def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): From d9938cdb27d71413de6be289af6b22c7947c80ad Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 8 Jan 2026 17:43:14 +0100 Subject: [PATCH 5/8] Adding tests --- src/spikeinterface/postprocessing/__init__.py | 3 + .../postprocessing/correlograms.py | 7 +- 
.../postprocessing/tests/test_correlograms.py | 110 +++++++++++++++++- 3 files changed, 113 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index dca9711ccd..fe9dabb727 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -15,9 +15,12 @@ from .correlograms import ( ComputeACG3D, ComputeCorrelograms, + ComputeAutoCorrelograms, compute_acgs_3d, compute_correlograms, + compute_auto_correlograms, correlogram_for_one_segment, + auto_correlogram_for_one_segment, ) from .isi import ( diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6f0bfb0b87..b273056b1b 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -260,7 +260,6 @@ class ComputeAutoCorrelograms(AnalyzerExtension): def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) - return params def _select_extension_data(self, unit_ids): @@ -312,13 +311,12 @@ def _merge_extension_data( else: new_bins = self.data["bins"] all_new_units = new_sorting_analyzer.unit_ids - num_dims = len(self.data["bins"]) arr = self.data["acgs"] - new_acgs = np.zeros((len(all_new_units), num_dims), dtype=np.int64) # compute all new isi at once new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids) - only_new_acgs, _ = _compute_auto_correlograms_on_sorting(new_sorting, **self.params) + only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params) + new_acgs = np.zeros((len(all_new_units), only_new_acgs.shape[1]), dtype=np.int64) for unit_ind, unit_id in enumerate(all_new_units): if unit_id not in new_unit_ids: @@ -434,7 +432,6 @@ def _compute_num_bins(window_size, bin_size): """ num_half_bins = int(window_size // bin_size) num_bins = int(2 * num_half_bins) - return num_bins, num_half_bins diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 523aa4ba05..ed4f0cda4a 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -18,13 +18,15 @@ from pytest import param from spikeinterface import NumpySorting, generate_sorting -from spikeinterface.postprocessing import ComputeACG3D, ComputeCorrelograms +from spikeinterface.postprocessing import ComputeACG3D, ComputeCorrelograms, ComputeAutoCorrelograms from spikeinterface.postprocessing.correlograms import ( _compute_3d_acg_one_unit, _compute_correlograms_on_sorting, + _compute_auto_correlograms_on_sorting, _make_bins, compute_acgs_3d, compute_correlograms, + compute_auto_correlograms ) from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite @@ -55,13 +57,42 @@ def test_sortinganalyzer_correlograms(self, method): params = dict(method=method, window_ms=100, bin_ms=6.5) ext_numpy = sorting_analyzer.compute(ComputeCorrelograms.extension_name, **params) - result_sorting, bins_sorting = compute_correlograms(self.sorting, **params) assert np.array_equal(result_sorting, ext_numpy.data["ccgs"]) assert np.array_equal(bins_sorting, ext_numpy.data["bins"]) +class TestComputeAutoCorrelograms(AnalyzerExtensionCommonTestSuite): + @pytest.mark.parametrize( + "params", + [ + dict(method="numpy"), + 
+            dict(method="auto"),
+            param(dict(method="numba"), marks=SKIP_NUMBA),
+        ],
+    )
+    def test_extension(self, params):
+        self.run_extension_tests(ComputeAutoCorrelograms, params)
+
+    @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
+    def test_sortinganalyzer_auto_correlograms(self, method):
+        """
+        Test the outputs when using SortingAnalyzer against
+        the output passing sorting directly to `compute_auto_correlograms`.
+        Sorting to `compute_auto_correlograms` is tested extensively below
+        so if these match it means `SortingAnalyzer` is working.
+        """
+        sorting_analyzer = self._prepare_sorting_analyzer("memory", sparse=False, extension_class=ComputeCorrelograms)
+
+        params = dict(method=method, window_ms=100, bin_ms=6.5)
+        ext_numpy = sorting_analyzer.compute(ComputeAutoCorrelograms.extension_name, **params)
+        result_sorting, bins_sorting = compute_auto_correlograms(self.sorting, **params)
+
+        assert np.array_equal(result_sorting, ext_numpy.data["acgs"])
+        assert np.array_equal(bins_sorting, ext_numpy.data["bins"])
+
+
 # Unit Tests
 ############
 def test_make_bins():
@@ -102,6 +133,48 @@ def test_equal_results_correlograms(window_and_bin_ms):
     )
 
     assert np.array_equal(result_numpy, result_numba)
+
+
+@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
+@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)])
+def test_equal_results_auto_correlograms(window_and_bin_ms):
+    """
+    Test that the 2 methods have same results with some varied time bins
+    that are not tested in other tests.
+    """
+
+    window_ms, bin_ms = window_and_bin_ms
+    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)
+
+    result_numpy, bins_numpy = _compute_auto_correlograms_on_sorting(
+        sorting, window_ms=window_ms, bin_ms=bin_ms, method="numpy"
+    )
+    result_numba, bins_numba = _compute_auto_correlograms_on_sorting(
+        sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba"
+    )
+
+    assert np.array_equal(result_numpy, result_numba)
 
 
 @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
@@ -263,6 +336,39 @@ def test_compute_correlograms_different_units(method):
     assert np.array_equal(result[0, 1], np.array([1, 0, 1, 1, 1, 0, 0, 0]))
 
 
+@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
+def test_compute_auto_correlograms_different_units(method):
+    """
+    Make a supplementary test to `test_compute_correlograms` in which all
+    units had the same spike train. Test here a simpler and accessible
+    test case with only two neurons with different spike time differences
+    within and across units.
+
+    This case is simple enough to validate by hand. For example, for the
+    result[1] case we are looking at the auto-correlogram of the unit '1'.
+    The spike times are 4 and 16 ms, therefore we expect to see a count in
+    the +/- 10 to 15 ms bin.
+    """
+    sampling_frequency = 30000
+    spike_times = np.array([0, 4, 8, 16]) / 1000 * sampling_frequency
+    spike_times = spike_times.astype(int)
+
+    spike_unit_indices = np.array([0, 1, 0, 1])
+
+    window_ms = 40
+    bin_ms = 5
+
+    sorting = NumpySorting.from_samples_and_labels(
+        samples_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency
+    )
+
+    result, bins = compute_auto_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method)
+
+    assert np.array_equal(result[0], np.array([0, 0, 1, 0, 0, 1, 0, 0]))
+    assert np.array_equal(result[1], np.array([0, 1, 0, 0, 0, 0, 1, 0]))
+
+
 
 def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin_edge):
     """
     This generates a detailed correlogram test and expected outputs, for a number of

From cd34ea88ce54a93c47551768a088c4677c8a4b29 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 8 Jan 2026 16:43:44 +0000
Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/postprocessing/tests/test_correlograms.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py
index ed4f0cda4a..569d6c5d80 100644
--- a/src/spikeinterface/postprocessing/tests/test_correlograms.py
+++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -26,7 +26,7 @@
     _make_bins,
     compute_acgs_3d,
     compute_correlograms,
-    compute_auto_correlograms
+    compute_auto_correlograms,
 )
 from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
 
@@ -133,7 +133,7 @@ def test_equal_results_correlograms(window_and_bin_ms):
     )
 
     assert np.array_equal(result_numpy, result_numba)
- 
+
 
 @pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
@@ -368,7 +368,6 @@ def test_compute_auto_correlograms_different_units(method):
     assert np.array_equal(result[1], np.array([0, 1, 0, 0, 0, 0, 1, 0]))
 
 
-
 def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin_edge):
     """
     This generates a detailed correlogram test and expected outputs, for a number of

From 9e3e54c727e16c62b8d07f9e9b842a0c4e3e8a86 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 9 Jan 2026 09:53:48 +0100
Subject: [PATCH 7/8] WIP

---
 .../postprocessing/correlograms.py | 61 ++++++------------
 1 file changed, 17 insertions(+), 44 deletions(-)

diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index b273056b1b..1e22c1414d 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -282,51 +282,24 @@ def _merge_extension_data(
         Here, we apply this formula to quickly compute correlograms for merged units.
         """
 
-        can_apply_soft_method = True
-        if censor_ms is not None:
-            # if censor_ms has no effect, can apply "soft" method. Check if any spikes have been removed
-            for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups):
-                num_segments = new_sorting_analyzer.get_num_segments()
-                for segment_index in range(num_segments):
-                    merged_spike_train_length = len(
-                        new_sorting_analyzer.sorting.get_unit_spike_train(new_unit_id, segment_index=segment_index)
-                    )
-
-                    old_spike_train_lengths = len(
-                        np.concatenate(
-                            [
-                                self.sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
-                                for unit_id in merge_unit_group
-                            ]
-                        )
-                    )
-
-                    if merged_spike_train_length != old_spike_train_lengths:
-                        can_apply_soft_method = False
-                        break
-
-        if can_apply_soft_method is False:
-            new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
-            new_data = dict(acgs=new_acgs, bins=new_bins)
-        else:
-            new_bins = self.data["bins"]
-            all_new_units = new_sorting_analyzer.unit_ids
-            arr = self.data["acgs"]
-
-            # compute all new isi at once
-            new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids)
-            only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params)
-            new_acgs = np.zeros((len(all_new_units), only_new_acgs.shape[1]), dtype=np.int64)
-
-            for unit_ind, unit_id in enumerate(all_new_units):
-                if unit_id not in new_unit_ids:
-                    keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
-                    new_acgs[unit_ind, :] = arr[keep_unit_index, :]
-                else:
-                    new_unit_index = new_sorting.id_to_index(unit_id)
-                    new_acgs[unit_ind, :] = only_new_acgs[new_unit_index, :]
+        new_bins = self.data["bins"]
+        all_new_units = new_sorting_analyzer.unit_ids
+        arr = self.data["acgs"]
+
+        # compute all new acgs at once
+        new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids)
+        only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params)
+        new_acgs = np.zeros((len(all_new_units), only_new_acgs.shape[1]), dtype=np.int64)
+
+        for unit_ind, unit_id in enumerate(all_new_units):
+            if unit_id not in new_unit_ids:
+                keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
+                new_acgs[unit_ind, :] = arr[keep_unit_index, :]
+            else:
+                new_unit_index = new_sorting.id_to_index(unit_id)
+                new_acgs[unit_ind, :] = only_new_acgs[new_unit_index, :]
 
-            new_data = dict(acgs=new_acgs, bins=new_bins)
+        new_data = dict(acgs=new_acgs, bins=new_bins)
         return new_data
 
     def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):

From 541a8730af4f8b990c82eebc8ab6b2fe32508c32 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 9 Jan 2026 09:54:22 +0100
Subject: [PATCH 8/8] WIP

---
 src/spikeinterface/postprocessing/correlograms.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 1e22c1414d..8d303296df 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -966,10 +966,7 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz
 
                 # Find the indices in the raveled correlograms array that need
                 # to be incremented, taking into account the spike unit labels.
-                if sign == 1:
-                    indices = spike_diff_b[m] + num_half_bins
-                else:
-                    indices = spike_diff_b[m] + num_half_bins
+                indices = spike_diff_b[m] + num_half_bins
 
                 # Increment the matching spikes in the correlograms array.
                 bbins = np.bincount(indices)
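
Usage sketch (assumes this branch is installed; `generate_sorting` is the same
helper the new tests already import from `spikeinterface`):

    from spikeinterface import generate_sorting
    from spikeinterface.postprocessing import compute_auto_correlograms

    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.0])

    # With window_ms=50 and bin_ms=1 there are 50 lag bins, so `acgs` has
    # shape (num_units, num_bins) = (5, 50); `bins` holds the bin edges in ms.
    acgs, bins = compute_auto_correlograms(sorting, window_ms=50.0, bin_ms=1.0, method="numpy")
    assert acgs.shape == (5, 50)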