mechanicalsea/lighthubert

10-hour ASR Fine-tuning

sausage-333 opened this issue · 1 comment

Hello, I have a question about the 10-hour ASR fine-tuning experiment in your paper.

Could you describe the procedure for this experiment (or point me to a link I can refer to)?
I want to run my own 10-hour ASR fine-tuning experiments using fairseq.

Thanks!

Thanks for your interest in our work, @sausage-333!

We used a setup similar to HuBERT's. You can start from fairseq's examples/hubert/config/finetune/base_10h.yaml and follow the corresponding fine-tuning instructions in fairseq's examples/hubert/README.md. In addition, you need to set:

  • task.normalize: true, since LightHuBERT was pre-trained on normalized waveforms
  • w2v_path: /path/to/lighthubert (model.w2v_path in base_10h.yaml), pointing to the pre-trained LightHuBERT checkpoint
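As a minimal sketch, assuming fairseq's HuBERT fine-tuning entry point fairseq_cli/hydra_train.py and the base_10h config, the two settings above can be passed as Hydra overrides on the command line. The helper build_finetune_command and all paths are hypothetical placeholders; task.data and task.label_dir follow fairseq's HuBERT fine-tuning instructions.

```python
import shlex


def build_finetune_command(fairseq_root, data_dir, label_dir, lighthubert_ckpt):
    """Return an argv list for 10-hour CTC fine-tuning (hypothetical helper).

    Assumes fairseq's HuBERT recipe layout; every path is a placeholder.
    """
    config_dir = f"{fairseq_root}/examples/hubert/config/finetune"
    return [
        "python", f"{fairseq_root}/fairseq_cli/hydra_train.py",
        "--config-dir", config_dir,
        "--config-name", "base_10h",
        f"task.data={data_dir}",        # directory with the .tsv audio manifests
        f"task.label_dir={label_dir}",  # directory with the letter (.ltr) transcriptions
        # LightHuBERT was pre-trained on normalized waveforms
        "task.normalize=true",
        # load the LightHuBERT checkpoint instead of a HuBERT one
        f"model.w2v_path={lighthubert_ckpt}",
    ]


if __name__ == "__main__":
    cmd = build_finetune_command(
        "/path/to/fairseq", "/path/to/data", "/path/to/trans", "/path/to/lighthubert.pt"
    )
    # Print a shell-ready command; subprocess.run(cmd, check=True) would launch it.
    print(shlex.join(cmd))
```

Keeping the overrides on the command line leaves base_10h.yaml untouched, so the same config can still be used for plain HuBERT runs.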

Loading the model requires some modifications, which we plan to add to this repo soon. In the meantime, here is a quick temporary workaround: in HubertEncoder's __init__ in fairseq/models/hubert/hubert_asr.py, comment out the original model-building code and replace it as follows:

```python
        # Comment out the original HuBERT model construction:
        # pretrain_task = tasks.setup_task(w2v_args.task)
        # if state is not None and "task_state" in state:
        #     # This will load the stored "dictionaries" object
        #     pretrain_task.load_state_dict(state["task_state"])
        # else:
        #     pretrain_task.load_state_dict(task.state_dict())

        # model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        # if state is not None and not cfg.no_pretrained_weights:
        #     # set strict=False because we omit some modules
        #     model.load_state_dict(state["model"], strict=False)

        # model.remove_pretraining_modules()

        # super().__init__(pretrain_task.source_dictionary)

        # d = w2v_args.model.encoder_embed_dim

        # ...and build LightHuBERT from the loaded checkpoint instead.
        from lighthubert import LightHuBERT, LightHuBERTConfig

        # `state` is the checkpoint loaded earlier in __init__ from cfg.w2v_path
        # (it is None when cfg.w2v_args is given, so this path needs w2v_path)
        lighthubert_cfg = LightHuBERTConfig(state['cfg']['model'])
        lighthubert_cfg.supernet_type = "base"
        model = LightHuBERT(lighthubert_cfg)
        # strict=False tolerates missing or unexpected keys in the checkpoint
        model.load_state_dict(state["model"], strict=False)
        model.remove_pretraining_modules()

        # Base subnet: 12 layers, 640-dim embeddings, 10 heads per layer,
        # 2560-dim FFN, global attention window in every layer
        subnet = {
            'layer_num': 12,
            'embed_dim': 640,
            'heads_num': [10,] * 12,
            'atten_dim': [640,] * 12,
            'ffn_embed': [2560,] * 12,
            'slide_wsz': ['global'] * 12,
        }
        model.set_sample_config(subnet)
        total_params = model.calc_sampled_param_num()
        print(f"target pre-trained subnet ({total_params:,} Params): {subnet}")

        # No pre-training task is built, so pass None as the source dictionary
        super().__init__(None)
        # Output width of the selected subnet, used by the encoder's output projection
        d = subnet['embed_dim']
```

The code above uses the Base subnet setting. You can select a different supernet_type and subnet by passing additional arguments.
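The Base subnet uses the same width in every layer, so it can be generated from a few scalars instead of writing out the per-layer lists. The helper below is a hypothetical sketch, not part of the lighthubert package; it assumes a per-head dimension of 64 (640 / 10 in the Base setting) and an FFN ratio of 4 (2560 / 640). How these values are exposed as fairseq arguments is left to your setup.

```python
def make_uniform_subnet(layer_num, embed_dim, heads_num,
                        ffn_ratio=4, head_dim=64, slide_wsz="global"):
    """Build a LightHuBERT-style subnet dict with identical layers (hypothetical helper).

    Keys follow the subnet dict passed to model.set_sample_config() in the patch above.
    """
    return {
        "layer_num": layer_num,
        "embed_dim": embed_dim,
        "heads_num": [heads_num] * layer_num,
        "atten_dim": [heads_num * head_dim] * layer_num,  # assumes 64-dim heads by default
        "ffn_embed": [embed_dim * ffn_ratio] * layer_num,
        "slide_wsz": [slide_wsz] * layer_num,
    }


# Reproduces the Base subnet from the patch above.
base_subnet = make_uniform_subnet(layer_num=12, embed_dim=640, heads_num=10)
print(base_subnet["atten_dim"][0], base_subnet["ffn_embed"][0])  # prints "640 2560"
```

Inside HubertEncoder's __init__, subnet = make_uniform_subnet(12, 640, 10) would replace the literal dict; subnets whose layers differ in width still need the explicit per-layer lists.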