lucidrains/musiclm-pytorch

NameError in musiclm_pytorch.py

jodog0412 opened this issue · 1 comment

Hi, I'm implementing this in Google Colab, and I ran into a strange error.

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

Running this raises the following error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 1>:1                                                                              │
│ in __init__:52                                                                                   │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/musiclm_pytorch/musiclm_pytorch.py:673 in __init__        │
│                                                                                                  │
│   670 │   │   self.text_to_latents = nn.Linear(self.text.dim, dim_latent)                        │
│   671 │   │   self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)                      │
│   672 │   │                                                                                      │
│ ❱ 673 │   │   klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(Soft   │
│   674 │   │   self.contrast = klass()                                                            │
│   675 │   │                                                                                      │
│   676 │   │   self.multi_layer_contrastive_learning = None                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'partial' is not defined
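The traceback shows the constructor choosing between a class and a `partial(...)` call, but `partial` was never imported into the module, so Python raises `NameError` the moment the second branch is evaluated. The minimal sketch below reproduces the same selection pattern with the missing `from functools import partial` in place. The traceback is cut off after `partial(Soft`, so the stub class names and the `temperature` argument here are hypothetical stand-ins, not the library's actual classes or parameters.

```python
from functools import partial  # the import whose absence triggers the NameError


# Hypothetical stand-ins for the two contrastive-loss classes in the traceback.
class SigmoidContrastiveStub:
    def __init__(self):
        self.kind = "sigmoid"


class SoftContrastiveStub:
    def __init__(self, temperature=1.0):
        self.kind = "soft"
        self.temperature = temperature


def build_contrast(sigmoid_contrastive_loss: bool):
    # Select a zero-argument constructor: either the class itself, or a
    # functools.partial that pre-binds a keyword argument. Both can then
    # be called the same way, as in `self.contrast = klass()`.
    klass = (
        SigmoidContrastiveStub
        if sigmoid_contrastive_loss
        else partial(SoftContrastiveStub, temperature=0.07)
    )
    return klass()


if __name__ == "__main__":
    print(build_contrast(True).kind)
    print(build_contrast(False).kind)
```

Without the import on the first line, calling `build_contrast(False)` would fail with the same `NameError: name 'partial' is not defined`. Calling `build_contrast(True)` would not, because the `partial` branch is never evaluated.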

Fixed in 0.2.1.
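To confirm that a Colab runtime already has the fix, compare the installed version against 0.2.1. The following is a minimal sketch. It assumes the package is published on PyPI as `musiclm-pytorch` (an assumption about the distribution name), and it uses a hypothetical helper, `has_partial_fix`, that compares only the leading numeric release components.

```python
import re
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (0, 2, 1)  # release that the reply says contains the fix


def parse_release(v: str) -> tuple:
    # Keep only the leading numeric part, e.g. "0.2.1" -> (0, 2, 1).
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(x) for x in m.group(0).split(".")) if m else ()


def has_partial_fix(v: str) -> bool:
    # Tuple comparison orders releases numerically, so "0.10.0" > "0.2.1".
    return parse_release(v) >= FIXED_IN


if __name__ == "__main__":
    try:
        installed = version("musiclm-pytorch")  # assumed distribution name
    except PackageNotFoundError:
        print("musiclm-pytorch is not installed")
    else:
        if has_partial_fix(installed):
            print(installed, "already includes the fix")
        else:
            print(installed, "is older; run: pip install -U musiclm-pytorch")
```

In Colab, restart the runtime after upgrading so that the old module is not still loaded.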