bytedance/piano_transcription

训练 note transcription system的问题

WiseTimeC opened this issue · 1 comments

先生你好!十分感谢您抽空来看我的问题!
我是一位python的初学者,最近在使用您的piano_transcription来训练模型,现在遇到了这个问题,在执行train 的时候报错:
root : INFO <class 'main.training'>
root : INFO Using GPU.
root : INFO train segments: 670
root : INFO Evaluate train segments: 670
root : INFO Evaluate validation segments: 0
root : INFO Evaluate test segments: 0
GPU number: 1
root : INFO ------------------------------------
root : INFO Iteration: 0
Traceback (most recent call last):
File "D:/piano_transcription-master/paper/train 1.py", line 18, in
train(training)
File "D:\piano_transcription-master\pytorch\main.py", line 208, in train
validate_statistics = evaluator.evaluate(validate_loader)
File "D:\piano_transcription-master\pytorch\evaluate.py", line 51, in evaluate
output_dict = forward_dataloader(self.model, dataloader, self.batch_size)
File "D:\piano_transcription-master\pytorch\pytorch_utils.py", line 53, in forward_dataloader
for n, batch_data_dict in enumerate(dataloader):
File "C:\Users\HP\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 279, in iter
return _MultiProcessingDataLoaderIter(self)
File "C:\Users\HP\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 746, in init
self._try_put_index()
File "C:\Users\HP\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 861, in _try_put_index
index = self._next_index()
File "C:\Users\HP\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 339, in _next_index
return next(self._sampler_iter) # may raise StopIteration
File "D:\piano_transcription-master\paper../utils\data_generator.py", line 302, in iter
index = self.segment_indexes[pointer]
IndexError: index 0 is out of bounds for axis 0 with size 0
是一开始定义的数组的问题吗?还是其他的问题?谢谢了!

这里应该是需要检查之前生成数据的时候是否生成了对应的数据文件。evaluate的部分会使用验证集获取文件。在Windows平台上,utility.py文件中需要做一个对应修改,否则可能会有获取不到文件路径的问题。我这里做了以下修改:

line 26:
def get_filename(path):
path = os.path.realpath(path)
na_ext = path.split('\' if 'nt' == os.name else '/')[-1]
na = os.path.splitext(na_ext)[0]
return na