SpikeInterface/spikeinterface

Can't interpolate traces with dtype('int16')

Closed this issue · 7 comments

Hello!
I've been getting the same error as in [this thread](https://github.com//issues/3146#issue-2391280072). I was happy to switch to float32, as they did, but this seems to cause problems further down our processing chain, in Phy. We run Kilosort4 from SpikeInterface (SI) after performing motion correction. What's the best course of action for us? Should we somehow convert our data back to int16 before running Kilosort?

```
Traceback (most recent call last):
  File "----------", line 122, in <module>
    rec_corrected, motion_info = si.correct_motion(rec, preset='nonrigid_accurate', interpolate_motion_kwargs={'border_mode':'force_extrapolate'},folder=result_folder, output_motion_info=True)
  File "---------/miniconda3/envs/kilosort/lib/python3.9/site-packages/spikeinterface/preprocessing/motion.py", line 433, in correct_motion
    recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
  File "-----------/miniconda3/envs/kilosort/lib/python3.9/site-packages/spikeinterface/sortingcomponents/motion/motion_interpolation.py", line 346, in __init__
    raise ValueError(f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.")
ValueError: Can't interpolate traces of recording with non-floating dtype=recording.dtype=dtype('int16').
srun: error: gpu-n40: task 0: Exited with exit code 1
```

zm711 commented

Could you post your full SpikeInterface script? Kilosort 1-3 require int16. I think KS4 can handle other dtypes but still prefers int16, so our wrappers should handle this.

```python
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import shutil
import spikeinterface.full as si
import spikeinterface.sorters as ss
from spikeinterface.sorters import run_sorter
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

folder = '----------'

global_job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
si.set_global_job_kwargs(**global_job_kwargs)

raw_rec = si.read_spikeglx(folder_path=folder, stream_id="imec0.ap")

# rec = preprocess_chain(raw_rec)
rec = si.bandpass_filter(raw_rec, freq_min=300., freq_max=6000., dtype='int16')  # previously changed dtype to float32, resolving issue
bad_channels, channel_labels = si.detect_bad_channels(rec)  # bad channels are only reported here, not removed
print('bad_channel_ids', bad_channels)
rec = si.phase_shift(recording=rec)
rec = si.common_reference(rec, reference='global', operator='median')  # instead of highpass_spatial_filter in IBL

job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)

# Estimate the noise on the scaled traces (microV) or on the raw ones (int16 in our case).
noise_levels_microV = si.get_noise_levels(rec, return_scaled=True)
noise_levels_int16 = si.get_noise_levels(rec, return_scaled=False)

# Detect peaks
peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16,
                     detect_threshold=5, radius_um=50., **job_kwargs)

# Localize peaks
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs)

# NOTE: result_folder is not defined in the posted script; set it to an output directory first.
rec_corrected, motion_info = si.correct_motion(rec, preset='nonrigid_accurate',
                                               interpolate_motion_kwargs={'border_mode': 'force_extrapolate'},
                                               folder=result_folder, output_motion_info=True)

preprocess_folder = folder + 'preprocess'
rec_corrected = rec_corrected.save(folder=preprocess_folder, format='binary', dtype='int16', **job_kwargs)

# Run Kilosort4 without its own drift correction
params = si.get_default_sorter_params(sorter_name_or_class='kilosort4')
params_kilosort4 = {
    'do_correction': False,
    'bad_channels': None,  # would need to change if we choose not to delete bad channels
}

# Sort the motion-corrected (and int16-saved) recording
sorting = si.run_sorter('kilosort4', rec_corrected, folder=folder + 'kilosort4_output',
                        docker_image=False, verbose=True, **params_kilosort4)
```
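The script computes noise levels twice, on scaled (µV) and on raw (int16) traces, and passes the int16 ones to `detect_peaks`, so the threshold is in the same units as the stored traces. The two estimates differ only by the gain: a median-absolute-deviation (MAD) estimate ignores a constant offset and scales linearly with the gain. The NumPy sketch below illustrates this on synthetic data; `mad_noise`, the gain, and the offset are hypothetical, and this is not SpikeInterface's own `get_noise_levels` implementation.

```python
import numpy as np

def mad_noise(traces):
    """Per-channel noise: median absolute deviation scaled to a Gaussian std."""
    med = np.median(traces, axis=0)
    return np.median(np.abs(traces - med), axis=0) / 0.6744897501960817

rng = np.random.default_rng(0)
# Synthetic traces (samples x channels) stored as int16, standing in for the filtered data.
raw = rng.normal(0.0, 20.0, size=(30_000, 4)).round().astype(np.int16)

gain_uV = 2.5     # hypothetical gain: microvolts per integer step
offset_uV = 10.0  # hypothetical offset in microvolts
scaled = raw.astype(np.float64) * gain_uV + offset_uV

noise_int16 = mad_noise(raw.astype(np.float64))  # same units as the raw traces
noise_uV = mad_noise(scaled)                      # same units as the scaled traces

# The offset cancels and the gain factors out, so the ratio equals the gain on every channel.
print(noise_uV / noise_int16)
```

Mixing the two (µV noise levels with int16 traces) would shift the effective detection threshold by the gain factor.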

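As the error says, the interpolation step of `correct_motion` requires floating-point input. You can fix this with a single cast before calling it (`spre` is `spikeinterface.preprocessing`):

```python
import spikeinterface.preprocessing as spre

rec = spre.astype(rec, "float")
```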

Honestly, I think we should do this cast by default when the input is int16; I find the current behavior pretty annoying! @samuelgarcia what do you think?
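Why float matters: motion interpolation produces each output channel as a weighted mix of neighbouring channels, so results are generally fractional. The toy sketch below uses a simple two-channel linear interpolation (an assumption for illustration, not SpikeInterface's actual interpolation kernel) to show that writing the result into an int16 buffer silently truncates the fractions, while a float32 buffer keeps them.

```python
import numpy as np

def interpolate_shift(traces, frac):
    """Toy spatial interpolation: output channel c mixes channels c and c+1,
    as if the probe had drifted by `frac` of the channel pitch."""
    w0, w1 = 1.0 - frac, frac
    out = np.empty_like(traces)  # output buffer keeps the input dtype
    out[:, :-1] = w0 * traces[:, :-1] + w1 * traces[:, 1:]  # float result cast into `out`
    out[:, -1] = traces[:, -1]  # no right-hand neighbour: copy (illustrative border rule)
    return out

traces = np.array([[3, 4, 10]], dtype=np.int16)
print(interpolate_shift(traces, 0.5))                     # int16 buffer: 3.5 is truncated to 3
print(interpolate_shift(traces.astype(np.float32), 0.5))  # float32 buffer keeps 3.5
```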

@jackrwaters sorry, I read too quickly! You could do this: cast to float, run `correct_motion`, and then use `astype` again to cast back to int16.

Can you give it a try?
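When casting the motion-corrected float traces back to int16 (with `astype`, or via `save(..., dtype='int16')` as in the script above), note that a plain NumPy float-to-int cast truncates toward zero and does not clip out-of-range values. A minimal NumPy sketch of a rounding-and-saturating conversion follows; `to_int16` is a hypothetical helper, and how SpikeInterface's own `astype` rounds may depend on the version you use.

```python
import numpy as np

INT16_MIN, INT16_MAX = np.iinfo(np.int16).min, np.iinfo(np.int16).max

def to_int16(traces):
    """Round to the nearest integer and clip to the int16 range before casting."""
    rounded = np.rint(traces)
    return np.clip(rounded, INT16_MIN, INT16_MAX).astype(np.int16)

interp = np.array([3.6, -3.6, 40000.0, -40000.0], dtype=np.float32)
print(interp[:2].astype(np.int16))  # plain cast truncates toward zero: 3, -3
print(to_int16(interp))             # rounds and saturates: 4, -4, 32767, -32768
```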

zm711 commented

Hey Alessio, shouldn't we just convert with `astype` for the user? If KS1-4 prefer (or require) int16, we should check the dtype and run `astype` ourselves, no? I can check the wrappers later and add it if we think we should.

> Hey Alessio, shouldn't we just convert with `astype` for the user? If KS1-4 prefer (or require) int16, we should check the dtype and run `astype` ourselves, no? I can check the wrappers later and add it if we think we should.

That is done for KS1-3, but KS4 accepts floats too, so we shouldn't cast to int16 IMO.
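The wrapper rule described here (cast to int16 for KS1-3, leave KS4 input alone) fits in a small decision function. This is a hypothetical sketch of that rule, not the code in SpikeInterface's sorter wrappers; the sorter name strings are assumed to match SpikeInterface's identifiers.

```python
import numpy as np

# Per the thread: Kilosort 1-3 require int16, Kilosort 4 also accepts floats.
INT16_ONLY_SORTERS = {"kilosort", "kilosort2", "kilosort2_5", "kilosort3"}

def sorter_input_dtype(sorter_name, recording_dtype):
    """Return the dtype a wrapper would hand to the sorter (hypothetical policy)."""
    dtype = np.dtype(recording_dtype)
    if sorter_name in INT16_ONLY_SORTERS and dtype != np.int16:
        return np.dtype(np.int16)  # cast only where int16 is required
    return dtype                   # e.g. kilosort4: pass float32 through unchanged

print(sorter_input_dtype("kilosort3", "float32"))
print(sorter_input_dtype("kilosort4", "float32"))
```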

Might it be better to check the dtype at the start of `correct_motion`? Hitting this error only after the time-consuming `localize_peaks` and `estimate_motion` steps is frustrating, especially when `motion_info` hasn't been saved to the folder.
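A fail-fast version of the suggested check validates the dtype before any expensive step runs. The sketch below is hypothetical: `check_interpolation_dtype`, `run_pipeline`, and the `auto_cast` flag are illustrative names, and the step strings are only labels standing in for the real work inside `correct_motion`.

```python
import numpy as np

def check_interpolation_dtype(dtype, auto_cast=False):
    """Return the dtype interpolation should use, or raise before any work starts."""
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.floating):
        return dtype
    if auto_cast:
        return np.dtype(np.float32)  # caller would cast the recording to float32 first
    raise ValueError(
        f"Motion interpolation needs floating-point traces, got {dtype}; "
        "cast the recording to float32 before correcting motion."
    )

steps_run = []

def run_pipeline(dtype):
    """Toy stand-in for a motion-correction pipeline: validate first, then the slow steps."""
    check_interpolation_dtype(dtype)  # fails before any step is recorded
    for step in ("detect_peaks", "localize_peaks", "estimate_motion", "interpolate_motion"):
        steps_run.append(step)        # placeholder for the real work

try:
    run_pipeline("int16")
except ValueError as err:
    print(err)
print(steps_run)  # empty: nothing ran before the error
```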