ml-explore/mlx-swift-examples

Avoid exceeding maximum allowed buffer size

DePasqualeOrg opened this issue · 5 comments

I'm already protecting against loading models that won't fit in memory (#70) and checking if the prompt exceeds the model's context window, but long prompts that fit in the context window can still exceed the maximum allowed buffer size:

libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 9948865536 bytes which is greater than the maximum allowed buffer size of 8589934592 bytes.
  if (size > device_->maxBufferLength()) {
    std::ostringstream msg;
    msg << "Attempting to allocate " << size << " bytes which is greater than"
        << " the maximum allowed buffer size of " << device_->maxBufferLength()
        << " bytes.";
    throw std::runtime_error(msg.str());
  }

Is there a way to estimate the buffer length required by a given prompt to guard against this fatal error?

I'm not sure how to estimate it other than measuring it yourself with some token lengths and seeing how much memory it uses.

#93 has some things that, once built, should help with memory use -- potentially capping the amount of memory used.

What exactly do I need to measure? Is it GPU.activeMemory? This seems to fluctuate quite a bit during generation, and the peak values can even be significantly above the recommendedMaxWorkingSetSize of the Metal device without the app crashing, so I'm not sure how to approach this.

awni commented

In your case the largest buffer is the scores for self-attention. That happens during the SDPA computation, which will use seq_len^2 * n_heads * 2 bytes (assuming quantized / fp16 / bf16). For example, if n_heads = 32 and the sequence length is 5000, then SDPA will need about 1.5 GB. So you can estimate that.
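
For what it's worth, here is a rough Swift sketch of that estimate (the function name and parameters are made up for illustration and are not part of MLX Swift):

// Rough upper-bound estimate of the SDPA scores buffer, following the formula
// above: seq_len^2 * n_heads * 2 bytes (fp16/bf16 scores, batch size 1).
func estimatedSDPAScoresBytes(promptTokens: Int, attentionHeads: Int) -> Int {
    promptTokens * promptTokens * attentionHeads * MemoryLayout<Float16>.size
}

// Example from above: 32 heads, 5000-token prompt -> 1,600,000,000 bytes (~1.5 GB).
let scoresBytes = estimatedSDPAScoresBytes(promptTokens: 5_000, attentionHeads: 32)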

You can get the maximum allowed buffer size from the deviceInfo struct.
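
As a sketch of that check, one option is to query the underlying Metal property MTLDevice.maxBufferLength directly, which is the same limit the allocator check above reads (this is an assumption on my part rather than the deviceInfo route mentioned above):

import Metal

// Pre-flight check: compare the estimated SDPA scores buffer against the
// device's maximum buffer size before running generation.
let attentionHeads = 32                 // num_attention_heads from config.json
let promptTokens = 5_000
let estimatedBytes = promptTokens * promptTokens * attentionHeads * 2  // fp16/bf16

if let device = MTLCreateSystemDefaultDevice(),
   estimatedBytes > device.maxBufferLength {
    // Reject or split the prompt instead of hitting the fatal allocation error.
    print("Prompt needs ~\(estimatedBytes) bytes for scores; limit is \(device.maxBufferLength) bytes.")
}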

However just because you can allocate the buffer doesn't mean it's a good idea to run with such a large sequence length. It will probably still use up a lot of memory and run very slowly.

I think a better medium-term course of action is what @davidkoski suggested (pulling in the prompt splitting from #93, which will substantially reduce the memory needed for SDPA).

Thanks, @awni, that's very helpful. Is only num_attention_heads from the model's config.json relevant here, or does num_key_value_heads also need to be taken into account?

awni commented

num_attention_heads is the relevant one, because it dictates the maximum size of the scores. The largest array you make for long prompts (in the current setup; this will probably change in the future) is the scores array, which has shape [batch, num_heads, sequence, sequence] (batch is 1 here).
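
For illustration, a minimal sketch that reads num_attention_heads from config.json and applies that shape (the Decodable struct and file path here are assumptions, not the examples' actual model-loading code):

import Foundation

// Minimal slice of the model's config.json with only the field needed here.
struct ModelConfig: Decodable {
    let numAttentionHeads: Int
}

do {
    // Hypothetical path; real code would point at the downloaded model directory.
    let data = try Data(contentsOf: URL(fileURLWithPath: "config.json"))
    let decoder = JSONDecoder()
    decoder.keyDecodingStrategy = .convertFromSnakeCase
    let config = try decoder.decode(ModelConfig.self, from: data)

    // scores shape is [batch, num_heads, sequence, sequence] with batch = 1,
    // so the fp16/bf16 byte count is num_heads * sequence^2 * 2.
    let promptTokens = 5_000
    let scoresBytes = config.numAttentionHeads * promptTokens * promptTokens * 2
    print("Estimated scores buffer for \(promptTokens) tokens: \(scoresBytes) bytes")
} catch {
    print("Could not read config.json: \(error)")
}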