27 changes: 21 additions & 6 deletions src/spikeinterface/core/job_tools.py
@@ -419,11 +419,24 @@ def __init__(

if pool_engine == "process":
if mp_context is None:
mp_context = recording.get_preferred_mp_context()
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
# auto choice
if platform.system() == "Windows":
mp_context = "spawn"
elif platform.system() == "Linux":
mp_context = "fork"
elif platform.system() == "Darwin":
# We used to force spawn for macOS; this is unfortunate, but in some cases fork on macOS
# is very unstable and leads to crashes.
mp_context = "spawn"
else:
mp_context = "spawn"

preferred_mp_context = recording.get_preferred_mp_context()
if preferred_mp_context is not None and preferred_mp_context != mp_context:
warnings.warn(
f"You processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible."
f"So use mp_context='{preferred_mp_context}' instead")
mp_context = preferred_mp_context

self.mp_context = mp_context

@@ -503,6 +516,8 @@ def run(self, recording_slices=None):
if self.pool_engine == "process":

if self.need_worker_index:

multiprocessing.set_start_method(self.mp_context, force=True)
lock = multiprocessing.Lock()
array_pid = multiprocessing.Array("i", n_jobs)
for i in range(n_jobs):
@@ -530,7 +545,7 @@

if self.progress_bar:
results = tqdm(
results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(recording_slices)
results, desc=f"{self.job_name} (workers: {n_jobs} processes, mp_context: {self.mp_context})", total=len(recording_slices)
)

for res in results:
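For readers skimming the diff, the start-method selection above reduces to a small decision: pick a platform default, then let the recording's preferred context win. A minimal sketch under that assumption (choose_mp_context is a hypothetical helper, not part of the PR):

import platform
import warnings


def choose_mp_context(preferred_mp_context=None):
    # Hypothetical helper mirroring the logic in ChunkRecordingExecutor.__init__:
    # "fork" is only the default on Linux; Windows does not support it and it is
    # unstable on macOS, so everything else falls back to "spawn".
    if platform.system() == "Linux":
        mp_context = "fork"
    else:
        mp_context = "spawn"

    # A recording can declare a preferred context; when it differs from the
    # platform default, it takes priority and a warning is emitted.
    if preferred_mp_context is not None and preferred_mp_context != mp_context:
        warnings.warn(
            f"mp_context='{mp_context}' is not possible for this recording, "
            f"using mp_context='{preferred_mp_context}' instead"
        )
        mp_context = preferred_mp_context
    return mp_context

The run() hunk also calls multiprocessing.set_start_method(self.mp_context, force=True) before creating the shared Lock and Array, so those synchronization objects are created under the same start method the pool will use.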
27 changes: 25 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
@@ -4,6 +4,7 @@
from typing import Optional, Type

import struct
import copy

from pathlib import Path

@@ -71,6 +72,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar

class PeakSource(PipelineNode):

# this is an important hack: it forces a node.compute() before the machinery is started
# this eventually triggers some numba jit compilation and avoids compilation races
# between processes or threads
need_first_call_before_pipeline = False

def get_trace_margin(self):
raise NotImplementedError

@@ -86,6 +92,12 @@ def get_peak_slice(
# not needed for PeakDetector
raise NotImplementedError

def _first_call_before_pipeline(self):
# see need_first_call_before_pipeline = True
margin = self.get_trace_margin()
traces = self.recording.get_traces(start_frame=0, end_frame=margin * 2 + 1, segment_index=0)
self.compute(traces, 0, margin * 2 + 1, 0, margin)


# this is used in sorting components
class PeakDetector(PeakSource):
@@ -601,7 +613,16 @@ def run_node_pipeline(
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

init_args = (recording, nodes, skip_after_n_peaks_per_worker)
node0 = nodes[0]
if isinstance(node0, PeakSource) and node0.need_first_call_before_pipeline:
# See need_first_call_before_pipeline: this triggers numba compilation before the run
node0._first_call_before_pipeline()
Member

Shouldn't this be a common function (by default pass) for all nodes? I assume any nodes using a numba kernel would benefit from this.

Member Author

Other nodes are more complicated because we need to run the entire chain.
This helps for now with peak_detection.
Let's see how we can optimize this.

Member Author

And at the moment this is implemented only in peak source nodes.


if job_kwargs["n_jobs"] != 1 and job_kwargs["pool_engine"] == "thread":
need_shallow_copy = True
else:
need_shallow_copy = False
init_args = (recording, nodes, need_shallow_copy, skip_after_n_peaks_per_worker)

processor = ChunkRecordingExecutor(
recording,
@@ -620,10 +641,12 @@ def run_node_pipeline(
return outs


def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
def _init_peak_pipeline(recording, nodes, need_shallow_copy, skip_after_n_peaks_per_worker):
# create a local dict per worker
worker_ctx = {}
worker_ctx["recording"] = recording
if need_shallow_copy:
nodes = [copy.copy(node) for node in nodes]
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
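As raised in the review thread above, only the first PeakSource node currently gets this warm-up. A minimal sketch of the idea, assuming a node that exposes recording, get_trace_margin() and compute() as in the diff (prewarm_peak_source is a hypothetical name):

def prewarm_peak_source(node, segment_index=0):
    # Run the node once on a tiny chunk in the main process/thread so that any
    # numba kernels it uses are jit-compiled before workers start. This avoids
    # several workers racing to compile the same kernel.
    margin = node.get_trace_margin()
    traces = node.recording.get_traces(
        start_frame=0, end_frame=margin * 2 + 1, segment_index=segment_index
    )
    node.compute(traces, 0, margin * 2 + 1, segment_index, margin)

The same run_node_pipeline change also shallow-copies the node list in each thread worker (need_shallow_copy) so per-node state is not shared between threads.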
2 changes: 2 additions & 0 deletions src/spikeinterface/core/numpyextractors.py
@@ -157,6 +157,8 @@ def __init__(
shm = SharedMemory(shm_name, create=False)
self.shms.append(shm)
traces = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
# Force read only
traces.flags.writeable = False
traces_list.append(traces)

if channel_ids is None:
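The read-only flag above is applied to shared-memory-backed traces. A minimal sketch of the pattern, assuming only NumPy and the standard library:

import numpy as np
from multiprocessing.shared_memory import SharedMemory

shape, dtype = (1_000, 4), np.dtype("float32")
shm = SharedMemory(create=True, size=shape[0] * shape[1] * dtype.itemsize)
traces = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
traces[:] = np.random.randn(*shape).astype(dtype)  # fill the buffer once
traces.flags.writeable = False  # then expose only a read-only view

# Other workers attach with SharedMemory(shm_name, create=False), wrap shm.buf
# the same way and also clear the writeable flag.

shm.close()
shm.unlink()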
3 changes: 3 additions & 0 deletions src/spikeinterface/core/recording_tools.py
@@ -902,6 +902,9 @@ def get_chunk_with_margin(
taper = taper[:, np.newaxis]
traces_chunk2[:margin] *= taper
traces_chunk2[-margin:] *= taper[::-1]
# enforce non-writable when the original chunk was not writable
# (this helps numba keep the same signature and avoid compiling twice)
traces_chunk2.flags.writeable = traces_chunk.flags.writeable
traces_chunk = traces_chunk2
elif add_reflect_padding:
# in this case, we don't want to taper
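The writeable-flag propagation matters because numba specializes compiled functions on array mutability: a read-only chunk and a writable chunk produce two distinct signatures. A small illustration (peak_count is a toy kernel, not from the codebase):

import numpy as np
import numba


@numba.jit(nopython=True, nogil=True)
def peak_count(traces, threshold):
    # toy kernel standing in for a real detector
    n = 0
    for i in range(traces.shape[0]):
        for j in range(traces.shape[1]):
            if traces[i, j] < threshold:
                n += 1
    return n


writable = np.random.randn(1_000, 4).astype("float32")
readonly = writable.copy()
readonly.flags.writeable = False

peak_count(writable, -3.0)
peak_count(readonly, -3.0)
# Two signatures were compiled because the inputs differ only in mutability.
# Keeping the writeable flag consistent across tapered and non-tapered chunks
# avoids this second compilation.
print(len(peak_count.signatures))  # 2 in this illustration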
24 changes: 12 additions & 12 deletions src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py
@@ -47,7 +47,7 @@
##########################
# isocut zone

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def jisotonic5(x, weights):
N = x.shape[0]

@@ -100,7 +100,7 @@ def jisotonic5(x, weights):

return y, MSE

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def updown_arange(num_bins, dtype=np.int_):
num_bins_1 = int(np.ceil(num_bins / 2))
num_bins_2 = num_bins - num_bins_1
@@ -111,7 +111,7 @@ def updown_arange(num_bins, dtype=np.int_):
)
)

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_ks4(counts1, counts2):
c1s = counts1.sum()
c2s = counts2.sum()
@@ -123,7 +123,7 @@ def compute_ks4(counts1, counts2):
ks *= np.sqrt((c1s + c2s) / 2)
return ks

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_ks5(counts1, counts2):
best_ks = -np.inf
length = counts1.size
@@ -138,7 +138,7 @@ def compute_ks5(counts1, counts2):

return best_ks, best_length

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def up_down_isotonic_regression(x, weights=None):
# determine switch point
_, mse1 = jisotonic5(x, weights)
@@ -153,14 +153,14 @@ def up_down_isotonic_regression(x, weights=None):

return np.hstack((y1, y2))

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def down_up_isotonic_regression(x, weights=None):
return -up_down_isotonic_regression(-x, weights=weights)

# num_bins_factor = 1
float_0 = np.array([0.0])

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def isocut(samples):  # sample_weights=None removed, isosplit6 does not handle weights anymore
"""
Compute a dip-test to check if 1-d samples are unimodal or not.
@@ -464,7 +464,7 @@ def ensure_continuous_labels(labels):

if HAVE_NUMBA:

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_centroids_and_covmats(X, centroids, covmats, labels, label_set, to_compute_mask):
## manual loop with numba to be faster

@@ -498,7 +498,7 @@ def compute_centroids_and_covmats(X, centroids, covmats, labels, label_set, to_c
if to_compute_mask[i] and count[i] > 0:
covmats[i, :, :] /= count[i]

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def get_pairs_to_compare(centroids, comparisons_made, active_labels_mask):
n = centroids.shape[0]

Expand Down Expand Up @@ -526,7 +526,7 @@ def get_pairs_to_compare(centroids, comparisons_made, active_labels_mask):

return pairs

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_distances(centroids, comparisons_made, active_labels_mask):
n = centroids.shape[0]
dists = np.zeros((n, n), dtype=centroids.dtype)
@@ -548,7 +548,7 @@ def compute_distances(centroids, comparisons_made, active_labels_mask):

return dists

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def merge_test(X1, X2, centroid1, centroid2, covmat1, covmat2, isocut_threshold):

if X1.size == 0 or X2.size == 0:
@@ -584,7 +584,7 @@ def merge_test(X1, X2, centroid1, centroid2, covmat1, covmat2, isocut_threshold)

return do_merge, L12

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compare_pairs(X, labels, pairs, centroids, covmats, min_cluster_size, isocut_threshold):

clusters_changed_mask = np.zeros(centroids.shape[0], dtype="bool")
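The nogil=True switch on these kernels is what makes pool_engine='thread' worthwhile: once compiled, the kernels release the GIL so several threads can run them concurrently. A minimal sketch with a toy kernel rather than the isocut functions:

import numpy as np
import numba
from concurrent.futures import ThreadPoolExecutor


@numba.jit(nopython=True, nogil=True)
def channel_min(traces):
    # toy kernel: per-channel minimum; with nogil=True the GIL is released
    # while this compiled loop runs, so threads make real progress in parallel.
    out = np.zeros(traces.shape[1], dtype=np.float32)
    for j in range(traces.shape[1]):
        out[j] = traces[0, j]
    for i in range(1, traces.shape[0]):
        for j in range(traces.shape[1]):
            if traces[i, j] < out[j]:
                out[j] = traces[i, j]
    return out


chunks = [np.random.randn(200_000, 16).astype("float32") for _ in range(8)]
channel_min(chunks[0])  # compile once up front, same idea as the pipeline warm-up
with ThreadPoolExecutor(max_workers=4) as pool:
    mins = list(pool.map(channel_min, chunks))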
3 changes: 3 additions & 0 deletions src/spikeinterface/sortingcomponents/matching/nearest.py
@@ -14,6 +14,9 @@ class NearestTemplatesPeeler(BaseTemplateMatching):

name = "nearest"
need_noise_levels = True
# this is because of numba: trigger jit compilation before the pipeline starts
need_first_call_before_pipeline = True

params_doc = """
peak_sign : 'neg' | 'pos' | 'both'
The peak sign to use for detection
8 changes: 5 additions & 3 deletions src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
@@ -48,6 +48,8 @@ class TridesclousPeeler(BaseTemplateMatching):

name = "tdc-peeler"
need_noise_levels = True
# this is because of numba: trigger jit compilation before the pipeline starts
need_first_call_before_pipeline = True
params_doc = """
peak_sign : str
'neg', 'pos' or 'both'
@@ -902,7 +904,7 @@ def fit_one_amplitude_with_neighbors(
if HAVE_NUMBA:
from numba import jit, prange

@jit(nopython=True)
@jit(nopython=True, nogil=True)
def construct_prediction_sparse(
spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive
):
@@ -932,7 +934,7 @@
if template_sparsity_mask[cluster_index, chan]:
chan_in_template += 1

@jit(nopython=True)
@jit(nopython=True, nogil=True)
def numba_sparse_distance(
wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters
):
@@ -968,7 +970,7 @@
distances[i] = sum_dist
return distances

@jit(nopython=True)
@jit(nopython=True, nogil=True)
def numba_best_shift_sparse(
traces, sparse_template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity
):
@@ -39,6 +39,8 @@ class LocallyExclusivePeakDetector(PeakDetector):
engine = "numba"
need_noise_levels = True
preferred_mp_context = None
# this is because of numba: trigger jit compilation before the pipeline starts
need_first_call_before_pipeline = True
params_doc = (
ByChannelPeakDetector.params_doc
+ """
@@ -140,7 +142,7 @@ def detect_peaks_numba_locally_exclusive_on_chunk(

return peak_sample_ind, peak_chan_ind

@numba.jit(nopython=True, parallel=False)
@numba.jit(nopython=True, parallel=False, nogil=True)
def _numba_detect_peak_pos(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
):
@@ -165,7 +167,7 @@ def _numba_detect_peak_pos(
break
return peak_mask

@numba.jit(nopython=True, parallel=False)
@numba.jit(nopython=True, parallel=False, nogil=True)
def _numba_detect_peak_neg(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
):
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_detection/main.py
@@ -86,7 +86,8 @@ def detect_peaks(
method_class = detect_peak_methods[method]

job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs["mp_context"] = method_class.preferred_mp_context
if method_class.preferred_mp_context is not None:
job_kwargs["mp_context"] = method_class.preferred_mp_context

if method_class.need_noise_levels:
from spikeinterface.core.recording_tools import get_noise_levels
@@ -172,7 +172,7 @@ def get_convolved_traces(self, traces):
if HAVE_NUMBA:
import numba

@numba.jit(nopython=True, parallel=False)
@numba.jit(nopython=True, parallel=False, nogil=True)
def _numba_detect_peak_matched_filtering(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels
):