Variable batchsize decrease factor
ChielWH opened this issue · 0 comments
ChielWH commented
First of all, many thanks for this handy utility package! My use case is to detect the largest batch size possible for translation during inference, which turns out to be rather low. I see that the batch-size decrease factor is fixed at 2:
```python
@dataclass
class Batchsize:
    ...

    def decrease_batchsize(self):
        self.value //= 2
        assert self.value
```
For my use case, this is quite an aggressive decrease factor. Would it be possible to make it configurable, so that it can be used in such a way:
```python
@toma.batch(initial_batchsize=32, cache_type=GlobalBatchsizeCache, decrease_factor=1.2)
def infer(batchsize, *args, **kwargs):
    ...
```
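For illustration, here is a minimal sketch of how `decrease_batchsize` could honour such a factor. The `decrease_factor` field and the `min(...)` guard are my assumptions about how it might be wired in, not toma's actual code; the guard just ensures the batch size still shrinks by at least 1 for factors close to 1:

```python
from dataclasses import dataclass


@dataclass
class Batchsize:
    value: int
    decrease_factor: float = 2.0  # hypothetical field; defaults to current behaviour

    def decrease_batchsize(self):
        new_value = int(self.value / self.decrease_factor)
        # Guarantee progress: even if the factor is barely above 1,
        # the batch size must strictly decrease on every call.
        self.value = min(new_value, self.value - 1)
        assert self.value > 0
```

With `decrease_factor=1.2`, a batch size of 32 would step down to 26 instead of 16, giving a much finer search near the memory limit.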
I'm happy to help, so if you would like me to make a PR, please let me know.