[CQL] How to speed up?
nissymori opened this issue · 0 comments
nissymori commented
@partial(jax.jit, static_argnames=("self", "bc", "batch_size", "n"))
def train_n_step(self, dataset, batch_size, n, bc=False):
for _ in range(n):
batch = batch_to_jax(subsample_batch(dataset, batch_size))
metrics = self.train(batch, bc)
return metrics
を追加して,forjit
やってみた(どこのrepoにも残していない.).以下,1000 stepにかかる時間の比較.
method | time | jit time |
---|---|---|
forloop(original) | 3.8s | - |
forjit(10) | 5.6s | 74s |
forjit(100) | 3.3s | 740s(approx.) |
forjit(100)は早いけど,jitの時間長すぎてoriginalと比べて早くはない.originalが相当いい実装なのかも.とりあえずこコピーは置いておいて,一旦IQLを片付けよう.