bghira/SimpleTuner

class-preservation target loss for LoRA / LyCORIS

Closed this issue · 7 comments

the idea is based on this pastebin entry: https://pastebin.com/3eRwcAJD

snippet:

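```python
import torch  # needed for torch.no_grad()

# Excerpt from a OneTrainer training step: self, batch and train_progress
# come from the surrounding trainer method.
if batch['prompt'][0] == "woman":  # class prompt hardcoded for this experiment
    # Regularisation batch: predict with the base model by unhooking the LoRA.
    # no_grad() keeps this prediction out of the graph, so it is a fixed target.
    with torch.no_grad():
        self.model.transformer_lora.remove_hook_from_module()
        regmodel_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
        self.model.transformer_lora.hook_to_module()

    # Normal prediction with the LoRA hooked back in.
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
    # Swap the usual target for the base model's prediction.
    model_output_data['target'] = regmodel_output_data['predicted']
    loss = self.model_setup.calculate_loss(self.model, batch, model_output_data, self.config)
    loss *= 1.0  # weight for the regularisation loss (1.0 = unweighted)
    print("\nregmodel loss:", loss)
else:
    # Ordinary training batch: unchanged behaviour.
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
    loss = self.model_setup.calculate_loss(self.model, batch, model_output_data, self.config)
```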

the idea is that we can set a flag inside the multidatabackend.json for a dataset that contains our regularisation data.

instead of training on this data as we currently do, we will:

  • temporarily disable the lora/lycoris adapter
  • run a prediction using the regularisation data on the (probably quantised) base model network
  • re-enable the lora/lycoris adapter
  • run the prediction on the adapter
  • replace the usual loss target (derived from the clean latent and the noise) with the base model's prediction
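The five steps above can be sketched in plain Python with a toy model. `ToyAdapterModel`, `adapter_disabled` and `training_loss` are hypothetical names, a single scalar weight stands in for the network, and a real trainer would instead use its framework's adapter hooks (like `remove_hook_from_module()` / `hook_to_module()` in the snippet) and run the base pass under `torch.no_grad()`.

```python
from contextlib import contextmanager


class ToyAdapterModel:
    """Hypothetical stand-in for a frozen base weight plus a LoRA-style delta.

    The adapter contributes a * b on top of the base weight w while
    `adapter_enabled` is True.
    """

    def __init__(self, w, a, b):
        self.w = w
        self.a = a
        self.b = b
        self.adapter_enabled = True

    def __call__(self, xs):
        weight = self.w + (self.a * self.b if self.adapter_enabled else 0.0)
        return [weight * x for x in xs]


@contextmanager
def adapter_disabled(model):
    # Steps 1 and 3: turn the adapter off, and always turn it back on,
    # even if the base-model forward pass raises.
    model.adapter_enabled = False
    try:
        yield model
    finally:
        model.adapter_enabled = True


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def training_loss(model, inputs, usual_target, is_regularisation_batch):
    if is_regularisation_batch:
        # Step 2: base-model prediction becomes the target (a real trainer
        # would also wrap this in torch.no_grad()).
        with adapter_disabled(model):
            target = model(inputs)
    else:
        target = usual_target
    # Step 4: normal prediction with the adapter active.
    pred = model(inputs)
    # Step 5: loss against whichever target was chosen.
    return mse(pred, target)


model = ToyAdapterModel(w=2.0, a=0.5, b=0.2)  # adapter adds 0.1 to the weight
reg_loss = training_loss(model, [1.0, 2.0], [0.0, 0.0], is_regularisation_batch=True)
```

The `try`/`finally` in `adapter_disabled` is the main design choice: it guarantees the adapter is re-enabled even if the base-model pass fails, so later batches never silently train without it.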

instead of checking whether the first caption in the batch is "woman", the batch will carry a flag enabling this behaviour, set somehow from the dataset's entry in multidatabackend.json.
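One way the flag could reach the batch, sketched in plain Python: the dataset entry in multidatabackend.json carries a boolean, and each batch is tagged with it. The key `is_regularisation_data`, the batch field `data_backend_id` and both helper names are hypothetical, chosen for illustration rather than taken from SimpleTuner.

```python
import json

# Hypothetical multidatabackend.json excerpt; "is_regularisation_data" is an
# illustrative key name, not a confirmed SimpleTuner option.
CONFIG = """
[
  {"id": "subject-photos", "type": "local"},
  {"id": "woman-reg", "type": "local", "is_regularisation_data": true}
]
"""


def regularisation_ids(config_text):
    """Return the ids of datasets flagged for base-model-target training."""
    backends = json.loads(config_text)
    return {b["id"] for b in backends if b.get("is_regularisation_data", False)}


def tag_batch(batch, reg_ids):
    # Attach the flag so the training step no longer has to inspect captions
    # (e.g. batch["prompt"][0] == "woman"). "data_backend_id" is hypothetical.
    batch["use_base_model_target"] = batch["data_backend_id"] in reg_ids
    return batch


reg_ids = regularisation_ids(CONFIG)
batch = tag_batch({"data_backend_id": "woman-reg", "prompt": ["woman"]}, reg_ids)
```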

this will run more slowly, since every batch from the regularisation dataset needs two forward passes, but it has the intended effect of preserving the original model's outputs for those inputs, which substantially helps prevent subject bleed.
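Assuming `calculate_loss` is a plain MSE (the snippet does not say), the two kinds of batch optimise roughly the following. Here f_θ is the base model, f_{θ+Δ} the model with the adapter hooked in, (x_t, t, c) the noisy latent, timestep and text conditioning, y the usual target, and sg[·] a stop-gradient (the `torch.no_grad()` in the snippet):

```latex
\mathcal{L}_{\text{train}} = \big\lVert f_{\theta+\Delta}(x_t, t, c) - y \big\rVert_2^2
\qquad
\mathcal{L}_{\text{reg}} = \big\lVert f_{\theta+\Delta}(x_t, t, c) - \operatorname{sg}\!\big[f_{\theta}(x_t, t, c)\big] \big\rVert_2^2
```

L_reg is zero exactly when the adapter leaves the output on those inputs unchanged, which is the "maintain the original model's outputs" effect. The extra cost is one forward pass of f_θ per regularisation batch, with no backward pass through it.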

note: i don't know who wrote the code snippet, but i would love to give credit to whoever created it.

example that came with the snippet:

[example image not included]

requested by a user on the terminus research discord.

I'm the author of this. I'm not yet entirely convinced myself that this is a useful feature. It seems to somewhat limit training of the concept you do want to change ("ohwx woman" in this sample), because it insists that the concept "woman" stay exactly the same during training.

this was an experiment I first ran yesterday, so I have limited test data myself. Training TE or training additional embeddings might overcome the issue mentioned above by separating the concepts in TE space? I am currently trying embeddings.

Happy to help with your implementation of this!

TIPO, run with random seeds and temperatures, can generate random prompts for related concepts. It can turn tags into a natural-language prompt, or expand a short prompt into a long one.

https://huggingface.co/KBlueLeaf/TIPO-500M

[screenshot from 2024-10-06 not included]

> Training TE or training additional embeddings might overcome the issue mentioned above by separating the concepts in TE space?

There is no need to train the text encoder for flux models, as the model is partially a large text encoder aligned to image space.

> as the model is partially a large text encoder aligned to image space.

source, more info?

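mm-dit is this: the text tokens are processed by the transformer's own blocks alongside the image tokens, so part of the network effectively acts as a text encoder aligned to image space.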
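A simplified, single-head sketch of the joint attention in an SD3-style MM-DiT block, ignoring normalisation and timestep modulation: text hidden states h_txt and image hidden states h_img get their own projection weights, but attend over one concatenated sequence (tokens stacked as rows):

```latex
Q = \begin{bmatrix} h_{\text{txt}} W_Q^{\text{txt}} \\ h_{\text{img}} W_Q^{\text{img}} \end{bmatrix},\quad
K = \begin{bmatrix} h_{\text{txt}} W_K^{\text{txt}} \\ h_{\text{img}} W_K^{\text{img}} \end{bmatrix},\quad
V = \begin{bmatrix} h_{\text{txt}} W_V^{\text{txt}} \\ h_{\text{img}} W_V^{\text{img}} \end{bmatrix}

\begin{bmatrix} o_{\text{txt}} \\ o_{\text{img}} \end{bmatrix}
= \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V
```

o_txt then passes through text-specific output and MLP weights into the next block, so every block rewrites the text representation while conditioning on the image tokens. Flux's double-stream blocks follow this pattern; its later single-stream blocks process the concatenated sequence with shared weights.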

After running some more tests, I now think this is worth implementing.
It even works well with an empty prompt and no external regularisation image set: just reuse the training dataset and change the check to:

```python
if batch['prompt'][0] == "":
```

This makes it a feature that needs no extra data, captions or configuration. Since no prompt is provided, it can potentially preserve multiple classes at once, covering whatever you train on.
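The snippet decides per batch from `batch['prompt'][0]` alone, so a batch that mixes empty and non-empty captions is routed by its first sample. A per-sample variant is sketched below as an assumption, not the author's code; `needs_base_pass` and `select_targets` are hypothetical helpers, and plain lists stand in for tensors.

```python
def needs_base_pass(prompts):
    # Only pay for the extra base-model forward pass if at least one
    # sample in the batch has an empty caption.
    return any(prompt == "" for prompt in prompts)


def select_targets(prompts, usual_targets, base_predictions):
    """Per-sample target choice.

    Samples with an empty caption are trained towards the base model's
    prediction; all other samples keep their usual target.
    """
    return [
        base if prompt == "" else usual
        for prompt, usual, base in zip(prompts, usual_targets, base_predictions)
    ]


prompts = ["", "ohwx woman"]
targets = select_targets(prompts, ["usual_0", "usual_1"], ["base_0", "base_1"])
```

In a tensor framework the same choice would be a boolean mask over the batch dimension, e.g. with `torch.where`, applied before the loss is computed.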

branch here for anyone who wants to try: https://github.com/dxqbYD/OneTrainer/tree/prior_reg
but it's the same code as above