lucidrains/imagen-pytorch

Text to video: no attention layers

axel588 opened this issue · 2 comments

I don't understand why the Unet3D doesn't use attention layers for text conditioning (sorry if this is a dumb question).

I mean these settings: `layer_attns = (False, False, False, True)` and `layer_cross_attns = False`.

@axel588 hmm, if you are training with text conditioning but have no cross attention layers set, it should error out (does it not?). I can add it if you show me a script where this is not the case.
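A check of the kind described here could look like the following minimal sketch. It is hypothetical: `expand_per_stage` and `check_text_conditioning` are illustrative names, not imagen-pytorch's code, and broadcasting a single bool such as `layer_cross_attns = False` to every U-Net stage is an assumption about how per-layer flags are interpreted.

```python
def expand_per_stage(flag, num_stages):
    """Hypothetical helper: broadcast a single bool to one flag per stage; pass tuples through."""
    if isinstance(flag, bool):
        return (flag,) * num_stages
    flags = tuple(flag)
    if len(flags) != num_stages:
        raise ValueError(f"expected {num_stages} flags, got {len(flags)}")
    return flags


def check_text_conditioning(condition_on_text, layer_cross_attns, num_stages):
    """Hypothetical guard: reject text conditioning when no stage has cross attention."""
    flags = expand_per_stage(layer_cross_attns, num_stages)
    if condition_on_text and not any(flags):
        raise ValueError(
            "text conditioning is enabled but no layer has cross attention; "
            "set layer_cross_attns to True for at least one stage"
        )
    return flags


if __name__ == "__main__":
    # The settings quoted in the issue: cross attention off on every stage.
    try:
        check_text_conditioning(True, False, num_stages=4)
    except ValueError as err:
        print("rejected:", err)
    # Cross attention on every stage except the first passes the check.
    print(check_text_conditioning(True, (False, True, True, True), num_stages=4))
```

Raising early like this surfaces the misconfiguration at construction time rather than after a long training run produces prompt-agnostic samples.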

@lucidrains
I applied the attention layers at first, but even with a dimension of 8 (very low, yes) and a batch size of 1, it overflows my 24 GB graphics card. The configuration below takes 23 GB of VRAM with a batch size of 2. How can I solve this memory issue?
This code works for text conditioning without attention layers and gives no error, but the samples seem random relative to the prompt:

```python
from imagen_pytorch import Unet3D, ElucidatedImagen

# `config` and `reshaped_m` are defined elsewhere in the training script (not shown)
unet = Unet3D(
    dim = config.dim,                # base channel count, i.e. the number of filters output by the first layer
    # cond_dim = config.cond_dim,
    channels = 5,
    dim_mults = config.dim_mults,    # channel multipliers for each stage (multiplied by dim)
    # num_resnet_blocks = config.num_resnet_blocks,
    # attention disabled here; these were the settings that ran out of memory:
    # layer_attns = (False,) + (True,) * (len(config.dim_mults) - 1),
    # layer_cross_attns = (False,) + (True,) * (len(config.dim_mults) - 1)
)

imagen = ElucidatedImagen(
    unets = (unet,),                 # a one-element tuple; `(unet)` without the comma is not a tuple
    image_sizes = reshaped_m,        # output size for the single unet
    cond_drop_prob = 0.1,
    text_encoder_name = 't5-base',
    channels = 5,
    num_sample_steps = 64,           # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
    sigma_min = 0.002,               # min noise level
    sigma_max = 80,                  # max noise level, @crowsonkb recommends double the max noise level for upsampler
    sigma_data = 0.5,                # standard deviation of data distribution
    rho = 7,                         # controls the sampling schedule
    P_mean = -1.2,                   # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                     # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                    # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()
```
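One likely reason a `dim` of 8 does not rescue the attention configuration: for standard self-attention over n tokens, the score matrix has n × n entries per head regardless of the channel count, so memory is driven by resolution and frame count rather than `dim`. The rough sketch below estimates one fp32 score tensor per stage for the commented-out `layer_attns` flags. It assumes, without checking imagen-pytorch's source, that spatial attention runs per frame over height × width tokens, that each stage halves the spatial size, and that there are 8 heads; the image size, frame count, and head count in the usage lines are hypothetical.

```python
def stage_flags(num_stages):
    """Same pattern as the commented-out line: no attention on stage 0, attention elsewhere."""
    return (False,) + (True,) * (num_stages - 1)


def attention_score_bytes(batch, frames, heads, height, width, bytes_per_elem=4):
    """Size of one attention score tensor, assuming per-frame attention over height * width tokens."""
    tokens = height * width
    return batch * frames * heads * tokens * tokens * bytes_per_elem


def estimate(image_size, frames, batch, heads, num_stages):
    """Rough per-stage estimate, assuming each stage halves the spatial resolution."""
    report = []
    for stage, use_attn in enumerate(stage_flags(num_stages)):
        size = image_size // (2 ** stage)
        nbytes = attention_score_bytes(batch, frames, heads, size, size) if use_attn else 0
        report.append((stage, size, use_attn, nbytes))
    return report


if __name__ == "__main__":
    # Hypothetical sizes: 64x64 frames, 16 frames per clip, batch of 2, 8 heads, 4 stages.
    for stage, size, use_attn, nbytes in estimate(64, 16, 2, 8, 4):
        print(f"stage {stage}: {size}x{size}, attention={use_attn}, scores ~ {nbytes / 2**20:.0f} MiB")
```

With those hypothetical numbers, the 32×32 stage alone needs 1 GiB for a single score tensor, before counting the activations kept for the backward pass. Halving the frame count or the batch halves that figure, while halving the resolution at that stage cuts it by 16×.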