src/spikeinterface/metrics/template/metrics.py (65 changes: 44 additions & 21 deletions)

@@ -6,7 +6,7 @@
 from spikeinterface.core.analyzer_extension_core import BaseMetric


-def get_trough_and_peak_idx(template):
+def get_trough_and_peak_idx(template, peak_sign="neg"):
     """
     Return the indices into the input template of the detected trough
     (minimum of template) and peak (maximum of template, after trough).
@@ -25,6 +25,15 @@ def get_trough_and_peak_idx(template):
         The index of the peak
     """
     assert template.ndim == 1
+
+    # If peak_sign is 'pos', invert the template
+    if peak_sign == "pos":
+        template = -template
+    elif peak_sign == "both":
+        max_idx = np.abs(template).argmax()
+        if template[max_idx] > 0:
+            template = -template
+
     trough_idx = np.argmin(template)
     peak_idx = trough_idx + np.argmax(template[trough_idx:])
     return trough_idx, peak_idx
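
The inversion trick keeps the rest of the function trough-oriented: for positive-going (or mixed-polarity) units the template is flipped so that np.argmin still lands on the dominant deflection. A quick standalone check, reusing the exact logic from the hunk above on a toy positive spike:

import numpy as np

def get_trough_and_peak_idx(template, peak_sign="neg"):
    # same logic as the hunk above
    assert template.ndim == 1
    if peak_sign == "pos":
        template = -template
    elif peak_sign == "both":
        max_idx = np.abs(template).argmax()
        if template[max_idx] > 0:
            template = -template
    trough_idx = np.argmin(template)
    peak_idx = trough_idx + np.argmax(template[trough_idx:])
    return trough_idx, peak_idx

# toy positive-going spike: main peak at sample 3, after-dip at sample 6
template = np.array([0.0, 0.2, 0.8, 1.0, 0.3, -0.1, -0.4, -0.2, 0.0])
print(get_trough_and_peak_idx(template, peak_sign="pos"))   # (3, 6)
print(get_trough_and_peak_idx(template, peak_sign="both"))  # (3, 6), since the |max| sample is positive

On the inverted template, trough_idx therefore points at the original positive peak.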
@@ -107,26 +116,27 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id
     if trough_idx is None or peak_idx is None:
         trough_idx, peak_idx = get_trough_and_peak_idx(template_single)

     # Edge case: template is flat
     if peak_idx == 0:
         return np.nan

     trough_val = template_single[trough_idx]
     # threshold is half of peak height (assuming baseline is 0)
     threshold = 0.5 * trough_val

-    (cpre_idx,) = np.where(template_single[:trough_idx] < threshold)
-    (cpost_idx,) = np.where(template_single[trough_idx:] < threshold)
+    # Find where the template crosses the threshold before and after the trough
+    threshold_crossings = np.where(np.diff(template_single >= threshold))[0]
+    crossings_before_trough = threshold_crossings[threshold_crossings < trough_idx]
+    crossings_after_trough = threshold_crossings[threshold_crossings >= trough_idx]

-    if len(cpre_idx) == 0 or len(cpost_idx) == 0:
+    if len(crossings_before_trough) == 0 or len(crossings_after_trough) == 0:
         hw = np.nan

     else:
-        # last occurence of template lower than thr, before peak
-        cross_pre_pk = cpre_idx[0] - 1
-        # first occurence of template lower than peak, after peak
-        cross_post_pk = cpost_idx[-1] + 1 + trough_idx
+        last_crossing_before_trough = crossings_before_trough[-1]
+        first_crossing_after_trough = crossings_after_trough[0]
+
+        hw = (first_crossing_after_trough - last_crossing_before_trough) / sampling_frequency

-    hw = (cross_post_pk - cross_pre_pk) / sampling_frequency
     return hw


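The rewritten crossing detection in get_half_width relies on a standard trick: np.diff over a boolean array is truthy exactly where the comparison flips, so it catches downward and upward crossings in one pass (the old np.where(... < threshold) version implicitly assumed a single contiguous sub-threshold segment). A minimal illustration:

import numpy as np

template = np.array([0.0, -0.2, -0.8, -1.0, -0.6, -0.1, 0.1])
trough_idx = 3
threshold = 0.5 * template[trough_idx]              # -0.5

above = template >= threshold                       # [T, T, F, F, F, T, T]
crossings = np.where(np.diff(above))[0]
print(crossings)                                    # [1 4]: flips between samples 1-2 and 4-5
# half width = (4 - 1) / sampling_frequency
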
@@ -163,11 +173,12 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non
     if trough_idx == 0:
         return np.nan

-    (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0)
-    if len(rtrn_idx) == 0:
+    # Find where the template crosses the baseline (0) after the trough
+    baseline_crossings = np.where(np.diff(template_single[trough_idx:] >= 0))[0]
+    if len(baseline_crossings) == 0:
         return np.nan
-    # first time after trough, where template is at baseline
-    return_to_base_idx = rtrn_idx[0] + trough_idx
+    return_to_base_idx = baseline_crossings[0] + trough_idx + 1
+
     if return_to_base_idx - trough_idx < 3:
         return np.nan
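
One detail worth noting: np.diff(a)[i] compares a[i] with a[i + 1], so the flip index reported by np.where is the sample just before the crossing; hence the "+ 1" when converting back to an index on the full template. A two-line check:

import numpy as np

after_trough = np.array([-1.0, -0.6, -0.2, 0.1, 0.3])  # stands in for template_single[trough_idx:]
crossings = np.where(np.diff(after_trough >= 0))[0]
print(crossings[0])      # 2: flip detected between samples 2 and 3
print(crossings[0] + 1)  # 3: first sample at or above baseline
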
@@ -218,6 +229,9 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa
     max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency))
     max_idx = np.min([max_idx, template_single.shape[0]])

+    if max_idx - peak_idx < 3:
+        return np.nan
+
     res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx])
     return res.slope

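The new guard avoids degenerate fits: scipy.stats.linregress cannot fit a single point, and two points always yield a perfect line, so a slope from fewer than three samples says nothing about the recovery. A sketch of the guarded call, with a hypothetical 30 kHz sampling rate:

import numpy as np
import scipy.stats

sampling_frequency = 30_000.0                    # hypothetical
times = np.arange(10) / sampling_frequency
template_single = np.linspace(-1.0, 0.0, 10)

peak_idx, max_idx = 6, 8                         # only 2 samples in the window
if max_idx - peak_idx < 3:
    slope = np.nan                               # too few points for a meaningful fit
else:
    slope = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]).slope
print(slope)                                     # nan
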
@@ -315,6 +329,7 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
         The sampling frequency of the template
     **kwargs: Required kwargs:
         - depth_direction: the direction to compute velocity above and below ("x", "y", or "z")
+        - peak_sign: whether expected peaks are negative, positive, or both ("neg", "pos", "both")
         - min_channels: the minimum number of channels above or below to compute velocity
         - min_r2: the minimum r2 to accept the velocity fit
         - column_range: the range in um in the x-direction to consider channels for velocity
@@ -327,11 +342,13 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
         The velocity below the max channel
     """
     assert "depth_direction" in kwargs, "depth_direction must be given as kwarg"
+    assert "peak_sign" in kwargs, "peak_sign must be given as kwarg"
     assert "min_channels" in kwargs, "min_channels must be given as kwarg"
     assert "min_r2" in kwargs, "min_r2 must be given as kwarg"
     assert "column_range" in kwargs, "column_range must be given as kwarg"

     depth_direction = kwargs["depth_direction"]
+    peak_sign = kwargs["peak_sign"]
     min_channels_for_velocity = kwargs["min_channels"]
     min_r2 = kwargs["min_r2"]
     column_range = kwargs["column_range"]
@@ -340,6 +357,14 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
     template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction)
     template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction)

+    # If peak_sign is 'pos', invert the template
+    if peak_sign == "pos":
+        template = -template
+    elif peak_sign == "both":
+        peak_value = template.flat[np.abs(template).argmax()]
+        if peak_value > 0:
+            template = -template
+
     # find location of max channel
     max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape)
     max_peak_time = max_sample_idx / sampling_frequency * 1000
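
For the 2-D template, the "both" branch uses a flat-index lookup: argmax on a 2-D array returns an index into the flattened array, and .flat consumes it directly, yielding the signed value of the global absolute extremum. A toy check (shape assumed samples x channels):

import numpy as np

template = np.array([[0.1, -0.3],
                     [0.9, -0.4],    # global |max| is +0.9 -> positive-going unit
                     [0.2, -0.2]])

peak_value = template.flat[np.abs(template).argmax()]
print(peak_value)         # 0.9
if peak_value > 0:
    template = -template  # flip so the np.argmin below still finds the main deflection
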
@@ -454,9 +479,10 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> flo
     sampling_frequency : float
         The sampling frequency of the template
     **kwargs: Required kwargs:
-        - depth_direction: the direction to compute velocity above and below ("x", "y", or "z")
-        - spread_threshold: the threshold to compute the spread
-        - column_range: the range in um in the x-direction to consider channels for velocity
+        - depth_direction: the direction to compute spread ("x", "y", or "z")
+        - spread_threshold: the threshold (0-1) to compute spread
+        - spread_smooth_um: the smoothing in um to apply to the amplitude profile before computing spread
+        - column_range: the range in um in the x-direction to consider channels for spread

     Returns
     -------
@@ -666,6 +692,7 @@ def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **m
     channel_locations_multi = tmp_data["channel_locations_multi"]
     sampling_frequency = tmp_data["sampling_frequency"]
     metric_params["depth_direction"] = tmp_data["depth_direction"]
+    metric_params["peak_sign"] = tmp_data["peak_sign"]
     for unit_index, unit_id in enumerate(unit_ids):
         channel_locations = channel_locations_multi[unit_index]
         template = templates_multi[unit_index]
@@ -678,11 +705,7 @@ def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **m
 class VelocityFits(BaseMetric):
     metric_name = "velocity_fits"
     metric_function = _get_velocity_fits_metric_function
-    metric_params = {
-        "min_channels": 3,
-        "min_r2": 0.2,
-        "column_range": None,
-    }
+    metric_params = {"min_channels": 3, "min_r2": 0.2, "column_range": None}
     metric_columns = {"velocity_above": float, "velocity_below": float}
     metric_descriptions = {
         "velocity_above": "Velocity of the spike propagation above the max channel in um/ms",
src/spikeinterface/metrics/template/template_metrics.py (6 changes: 3 additions & 3 deletions)

@@ -70,7 +70,7 @@ class ComputeTemplateMetrics(BaseMetricExtension):
     metric_params : dict of dicts or None, default: None
         Dictionary with parameters for template metrics calculation.
         Default parameters can be obtained with: `si.metrics.template_metrics.get_default_template_metrics_params()`
-    peak_sign : {"neg", "pos"}, default: "neg"
+    peak_sign : {"neg", "pos", "both"}, default: "neg"
         Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels.
     upsampling_factor : int, default: 10
         The upsampling factor to upsample the templates
@@ -209,8 +209,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
             template_upsampled = resample_poly(template_single, up=upsampling_factor, down=1)
         else:
             template_upsampled = template_single
-            sampling_frequency_up = sampling_frequency
-        trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled)
+        trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled, peak_sign=peak_sign)

         templates_single.append(template_upsampled)
         troughs[unit_id] = trough_idx
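
For context on this hunk: resample_poly(x, up=upsampling_factor, down=1) returns a template with upsampling_factor times as many samples, so the trough and peak indices computed here live on the finer grid, where the effective rate is sampling_frequency * upsampling_factor. A standalone illustration with hypothetical values:

import numpy as np
from scipy.signal import resample_poly

sampling_frequency = 30_000.0      # hypothetical
upsampling_factor = 10
template_single = np.sin(np.linspace(0, 2 * np.pi, 90))

template_upsampled = resample_poly(template_single, up=upsampling_factor, down=1)
print(template_upsampled.shape)    # (900,): ten fine-grid samples per original sample
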
@@ -246,6 +245,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
tmp_data["templates_multi"] = templates_multi
tmp_data["channel_locations_multi"] = channel_locations_multi
tmp_data["depth_direction"] = self.params["depth_direction"]
tmp_data["peak_sign"] = self.params["peak_sign"]

return tmp_data
