BlackHC/toma

Variable batchsize decrease factor

ChielWH opened this issue · 0 comments

First of all, many thanks for this handy utility package! My use case is to find the largest possible batchsize for translation during inference, and that batchsize turns out to be rather low. I see that the batchsize decrease factor is fixed at 2:

@dataclass
class Batchsize:
    ...
    def decrease_batchsize(self):
        self.value //= 2
        assert self.value

For my use case, this is quite an aggressive decrease factor. Would it be possible to make it configurable, so that it could be used like this:

@toma.batch(initial_batchsize=32, cache_type=GlobalBatchsizeCache, decrease_factor=1.2)
def infer(batchsize, *args, **kwargs):
    ...
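
Roughly, I imagine the change inside the dataclass could look something like the sketch below (just an illustration, not the library's actual code; the parameter name decrease_factor, its default of 2.0, and the flooring behaviour are open to discussion):

from dataclasses import dataclass


@dataclass
class Batchsize:
    value: int
    decrease_factor: float = 2.0  # assumed default, keeps the current halving behaviour

    def decrease_batchsize(self):
        # Keep the batchsize an integer and make sure it shrinks by at least 1,
        # so that small factors such as 1.2 cannot get stuck at the same value.
        new_value = min(int(self.value / self.decrease_factor), self.value - 1)
        assert new_value > 0
        self.value = new_value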

I'm happy to help, so if you would like me to make a PR, please let me know.