bytedance/piano_transcription

For a better user experience, it is recommended to download the model with the `requests` library instead of the `wget` command-line tool, which is not installed on many computers.

kevintsq opened this issue · 0 comments

I wrote some code to demonstrate this approach; it can easily be integrated into the library code. I used f-strings for convenience, so it requires Python 3.6 or later.

import sys

import requests
from pathlib import Path
import hashlib


class ProgressBar:
    """Single-line console progress indicator for a byte-oriented transfer.

    Each call to ``update`` redraws the line in place with a carriage return,
    e.g. ``[Downloading] Model 1.00 MB of 165.00 MB, 0.61% Completed``.

    Args:
        title: Label shown after the status tag.
        total: Expected number of bytes. ``0`` or ``None`` means the size is
            unknown; the "of <total>, <pct>% Completed" part is then omitted.
        running_str: Status tag shown in square brackets.
        completed: Number of bytes already done before the first update.
    """

    def __init__(self, title, total, running_str='Running', completed=0):
        self.title = title
        self.total = total
        # convert() cannot format None, so an unknown total gets a placeholder.
        self.total_str = self.convert(total) if total is not None else 'unknown'
        self.completed = completed
        self.status = running_str

    @staticmethod
    def convert(size: int) -> str:
        """Format a byte count as a human-readable string, e.g. ``'1.50 KB'``.

        Sizes use binary (1024-based) steps. Values too large for the
        biggest unit are expressed in that unit instead of being divided once
        more and mislabelled.
        """
        units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB', 'BB', 'NB', 'DB', 'CB']

        # Never divide past the last unit: there is no larger one to label it.
        for unit in units[:-1]:
            if size < 1024:
                return f'{size:.2f} {unit}'
            size /= 1024

        return f'{size:.2f} {units[-1]}'

    def __str__(self):
        done = self.convert(self.completed)
        if not self.total:
            # Unknown (None) or zero total: a percentage would divide by zero.
            return f'[{self.status}] {self.title} {done}'
        percent = self.completed * 100 / self.total
        return f'[{self.status}] {self.title} {done} of {self.total_str}, {percent:.2f}% Completed'

    def update(self, completed=1):
        """Advance the counter by *completed* bytes and redraw the line."""
        self.completed += completed
        # The line never ends in '\n', so flush explicitly; otherwise a
        # line-buffered stdout may not show the bar until the transfer ends.
        print(f'\r{self}', end='', flush=True)


# Download the pretrained piano-transcription checkpoint from Zenodo with
# `requests` (no dependency on the external `wget` tool), show progress, and
# verify its SHA-1 digest before installing it at the final path.
MODEL_URL = (
    'https://zenodo.org/record/4034264/files/'
    'CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
)
MODEL_SHA1 = 'b06d5ab55ff57beae8ab2a76205d5847add01fec'

checkpoint_path = (
    Path.home() / 'piano_transcription_inference_data' / 'note_F1=0.9677_pedal_F1=0.9186.pth'
)
# Write into a side file so that an interrupted or corrupted transfer never
# leaves something at the final path that looks like a valid checkpoint.
partial_path = checkpoint_path.with_name(checkpoint_path.name + '.part')
# The data directory may not exist yet on a fresh machine.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

chunk_size = 1024 * 1024
sha1 = hashlib.sha1()
# (connect, read) timeouts so a stalled server cannot hang the script forever;
# the context manager releases the connection when done.
with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
    response.raise_for_status()
    # Content-Length is absent for chunked responses; 0 means "unknown size",
    # in which case no percentage-based progress bar is shown.
    content_size = int(response.headers.get('content-length', 0))
    progress = (
        ProgressBar('Model', total=content_size, running_str="Downloading")
        if content_size
        else None
    )
    with open(partial_path, "wb") as file:
        for data in response.iter_content(chunk_size=chunk_size):
            file.write(data)
            sha1.update(data)
            if progress is not None:
                progress.update(completed=len(data))

if sha1.hexdigest() == MODEL_SHA1:
    partial_path.replace(checkpoint_path)
    print('\nSucceeded in downloading model :)')
else:
    # Do not keep a corrupted download around.
    partial_path.unlink()
    print('\nModel Corrupted :( Please download again!', file=sys.stderr)
    sys.exit(1)