Request: Support sharing mutable memory between host and guest
AaronFriel opened this issue · 14 comments
It looks like the WASI specification with the component model / interface types specification as written today does not support a shared memory type. Mutable shared memory is an important capability for many kinds of software that have tight bounds on latency and/or would benefit from zero-copy IO, ranging from:
- Virtual machines and emulators
- High performance computing, especially networking and message passing interfaces
- Database software, especially with `mmap`-backed buffers from the host
- Video and audio streaming, as zero-copy IO reduces latency, jitter, and memory subsystem overhead
In fact, the current design of the WASI specification makes it more difficult to share memory between host and guest as compared to pre-component model specification, by constraining input & output codegen to specific named types.
There are many ways to address this gap, e.g.:
- If certain types, like `list<T>`, supported being declared as a `borrow<list<T>>`. In Rust, this would be received as a `&mut [T]`, in JavaScript a `TypedArray`, and so on.
- A dedicated shared buffer type created instead
- #49
- #304
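To make the first option above concrete, here is a minimal Rust sketch of what a loaned, mutable list could look like. `borrow<list<T>>` is not something WIT offers today, so everything here is an assumption: `HostBuffer`, `with_borrowed`, and `guest_fill` are hypothetical names, and the "guest" is an ordinary function rather than a wasm export.

```rust
/// Hypothetical host-owned buffer; stands in for memory the host would loan out.
struct HostBuffer {
    data: Vec<f32>,
}

impl HostBuffer {
    fn new(len: usize) -> Self {
        HostBuffer { data: vec![0.0; len] }
    }

    /// Loans the buffer to `guest` for the duration of the call only.
    /// The borrow checker ensures the guest cannot keep the slice afterwards.
    fn with_borrowed<R>(&mut self, guest: impl FnOnce(&mut [f32]) -> R) -> R {
        guest(&mut self.data)
    }
}

/// Stand-in for a guest export that mutates the loaned memory in place.
fn guest_fill(buf: &mut [f32], value: f32) -> usize {
    for x in buf.iter_mut() {
        *x = value;
    }
    buf.len()
}
```

The point of the shape is that the loan ends when the call returns, which is roughly the lifetime a `borrow` handle implies.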
For a concrete use case, the Microsoft AI Controller Interface project uses shared memory to reduce latency on mutating moderately large arrays of floating point numbers representing the logits of a large language model.
CC @mmoskal & @squillace, re:
Even though the high-level semantics looks like "copies" are being made when passing component-level values into and out of components from the host, in a real host implementation, the `i32` offsets into linear memory defined by the Canonical ABI can be used directly by the host to access linear memory for the duration of a host import call. With 0.3 and the addition of streams, the Canonical ABI will further be extended to allow registering subranges of linear memory to be asynchronously read-from or written-into during streaming I/O (as described here). This io_uring-style of streams ends up handling a bunch of cases that folks often otherwise think require `mmap`, often much better than mmap. Avoiding raw `mmap`-style memory sharing also avoids a ton of portability and security hazards that otherwise appear when considering the wide range of hardware and OSes that wasm gets used on.
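As a rough illustration of the first point, the sketch below shows a host resolving an (offset, length) pair to a borrowed view of linear memory without copying. It is a simplification under stated assumptions: linear memory is modeled as a plain byte slice rather than a real engine's memory API, and `view_list_u8` is a hypothetical helper name.

```rust
/// Resolves a Canonical-ABI-style (offset, length) pair to a borrowed view
/// of the guest's linear memory, without copying. Returns None if the range
/// falls outside the memory.
/// Assumption: linear memory is modeled as a plain byte slice.
fn view_list_u8(linear_memory: &[u8], offset: u32, len: u32) -> Option<&[u8]> {
    let start = offset as usize;
    // checked_add guards against overflow on targets with a 32-bit usize.
    let end = start.checked_add(len as usize)?;
    linear_memory.get(start..end)
}
```

The returned slice borrows from `linear_memory`, so it is only valid while the host holds that borrow, which mirrors the "for the duration of a host import call" restriction.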
That being said, maybe there are other use cases not addressed by these above two points, so I'd be happy to dig into other concrete use cases.
@lukewagner for the concrete use case in AICI, I think @mmoskal can speak to the latency constraints of AICI better than I could. That said, to add specificity to this GitHub issue, I'll describe some of the workings of guided LLM generation in the context of AICI.
There are three processes involved:
- The LLM inference server (Python/PyTorch/CUDA/Cutlass), which itself heavily relies on shared memory
- The AICI runtime (Rust WASM host)
- The AICI controller (WASM guest)
The LLM inference server can serve multiple requests concurrently in a batch. Each request is called a sequence, and there is a 1:1 correspondence between sequences produced by the LLM and AICI controllers.
The LLM inference server processes all sequences at once in a batch, producing one token per sequence per decoding cycle. Decoding cycles, for our purposes, consist of two phases: logit generation and sampling. Generation is the more compute-intensive step, and depending on the language model and inference server, this step takes from 100ms to under 2ms (e.g.: groq).
Each AICI controller is tasked with, given the result of the previous step's sampling, producing a set of logit biases as an array. These logit biases are of a fixed size for a given language model, and are on the order of 200-500 KiB (between 32k and 100k 32 bit floating point numbers.) After generation and before or during sampling, these biases are applied to affect the token chosen for the sequence during sampling.
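For readers less familiar with this step, here is a small sketch of what "applying biases before sampling" can look like. It assumes the biases are additive and uses greedy (argmax) sampling purely for illustration; real inference servers may combine biases differently and typically sample stochastically. `apply_biases` and `argmax` are hypothetical names.

```rust
/// Adds each bias to the matching logit in place.
/// Assumption: biases are additive, one per vocabulary entry.
fn apply_biases(logits: &mut [f32], biases: &[f32]) {
    assert_eq!(logits.len(), biases.len(), "one bias per vocabulary entry");
    for (logit, bias) in logits.iter_mut().zip(biases) {
        *logit += *bias;
    }
}

/// Greedy sampling stand-in: index of the largest logit (None for empty input).
fn argmax(logits: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in logits.iter().enumerate() {
        if best.map_or(true, |(_, b)| v > b) {
            best = Some((i, v));
        }
    }
    best.map(|(i, _)| i)
}
```

A bias of negative infinity rules a token out entirely, which is how a controller can constrain the output without touching the model.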
This situation lends itself extremely well to mutable shared memory. The buffers are fixed size, multiple processes or threads are involved, and the memory shared could be mapped into coherent memory on an accelerator. This is important because the memory subsystem here is "slow", and round tripping through main system memory is a real cause of poor performance in LLMs. We don't want to stall the inference server for any reason, as a stall either delays the entire batch or wastes compute as the sequences that miss a deadline must backtrack.
While I could imagine a mechanism that uses an IO-uring like buffer, managing a pair of buffers across these processes seems much more complex - and error prone.
And while I'm familiar with the Are You Sure You Want to Use MMAP in Your Database Management System? paper, I think it's hard to overstate that a great many projects (database or otherwise) have bootstrapped themselves by using memory mapped IO to let operating system kernels do the heavy lifting to build a proof of concept prior to building their own primitives.
Another salient, concrete example comes to mind from my day job at Pulumi. This was another use case where memory mapped IO is valuable for reducing latency and resident memory.
Pulumi is an infrastructure as code system that uses a plugin architecture to support managing various cloud providers with various languages.
Cloud provider plugins implement a schema describing the resources they manage, and schemas range in size from kilobytes to over 100MiB for major clouds. Language plugins support writing Pulumi code in six languages or interpreter runtimes (Node.js, Python, Go, .NET, Java, and YAML). The Pulumi engine and CLI manage launching these plugin processes and each kind of plugin implements a gRPC protocol.
Most of our languages use a generated per-provider SDK to help ensure the correctness of programs written in Pulumi. The `nodejs` language plugin, for example, relies on TypeScript SDKs. The SDKs are generated from the schemas declared by cloud provider plugins.
The YAML language runtime, however, does not have any SDK. Instead, at runtime, it relies on a provider plugin's schema to perform type-checking. For one provider in particular, the schema is on the order of 100MiB, and if naively implemented, the YAML language plugin would request the schema from the engine, and the engine would intermediate and request the schema from the provider plugin. This would result in an excessive number of copies of the schema present in memory, which can result in OOMs in memory-constrained CI/CD systems. The solution implemented was to support caching the schema to a file and use memory mapped IO.
(Speaking as an individual who likes to hack on these things.) If the Pulumi engine were to support provider plugins implemented in WASM, I expect we would still want to use memory mapped IO for the `GetSchema` function exported by provider plugins. Though, unlike the LLM use case above, this GetSchema RPC would return an immutable buffer mapped from either a section of constant data in the WASM plugin (an analogue to the `.rodata` section) or from memory mapped IO of a file or files packaged alongside the plugin.
Edit: I should add here that while it may seem "obvious" to simply ship the schema as a file alongside the plugin, some cloud provider plugins rely on the schema as an artifact to determine the cloud APIs to call or how to encode/decode RPCs. This plugin in particular that had a 100MiB schema is one of those. And of course, it's much nicer to be able to ship a single binary with embedded data than a set of files, and it leaves less room for error. There are more design constraints here than I can go into in one aside though. :)
@lukewagner IIUC, both io_uring-style streaming and directly usable i32 offsets share the same strong assumption, which is that wasm code plays the producer (allocating resources and managing them). In those cases, the question becomes how to let the host consume data in linear memory efficiently. But on the other side, wasm modules are plugins and actors in a bigger system. The host manages resources and holds data, like incoming messages, raw bytes and so on. Because of the security of the sandbox, a host-wasm call needs a copy-in to pass arguments and a copy-out to return results. Even worse, that data doesn't only come from MMU-managed memory.
I suggest this issue should be moved to https://github.com/WebAssembly/ComponentModel - any concerns for how Wasm primitives can be shared between guests, or host and guest, are in the Component Model's domain, and WASI just builds on top of what the CM provides.
Happy to move discussion over to the Component Model repo (also, the URL has a hyphen, so it's: https://github.com/WebAssembly/component-model/), but just to try to reply with the context already given above:
@AaronFriel Thanks for providing all that background. From what I was able to understand, it sounds like what you are describing is very much a streaming use case. In this context, I think the io_uring-style of ABI (in which pointers are submitted for asynchronous reading-out-of and writing-into) should have great performance (potentially even better than the mmap approach, given that memory accesses that result in I/O are blocking). To be clear, I'm not saying that the idea is to literally standardize io_uring (in particular, there wouldn't be request/response ring buffers managed by wasm) -- I'm just describing the rough shape of the ABI (in which pointers are submitted when initiating async I/O operations). In a Windows setting, we could say this was "Overlapped I/O"-style.
IIUC, both io_uring-style streaming and directly usable i32 offsets share the same strong assumption, which is that wasm code plays the producer (allocating resources and managing them).
Yes, but if we are talking about data that is to be read or written by wasm compiled from linear-memory languages (like C, C++ and Rust), that wasm expects the data to be in the default linear memory and thus, one way or another, we have to get the data into that default linear memory. We can imagine trying to pass in a pointer to external host memory (via `externref`, `memoryref`, `(ref (array u8))`, multi-memory), but we inevitably end up at the same problem, which is that the core wasm producer toolchain has no natural way to directly access this non-default-linear-memory from normal code (which always loads and stores from the default linear memory); instead you need to introduce a whole new set of type annotations (like LLVM's `addrspace`) to tell the code generator to emit loads/stores to something other than the default linear memory, and that tends to become extremely infective, requiring lots of code to be rewritten. That's why folks who have pursued this idea in the past have stopped.
Instead, with the io_uring-style of ABI we can acknowledge that, independent of the above issue, there is inevitably a copy from the kernel (or DMA engine or smart NIC) into user-space anyways, so let's have this one copy copy directly into or out of the default linear memory (by supplying the `i32` offset into linear memory when we initiate the operation). This approach neatly addresses several issues at once.
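The shape of that ABI can be sketched roughly as follows. This is not the actual Canonical ABI or any proposed WIT interface: `PendingRead`, `submit_read`, and `complete_read` are hypothetical names, linear memory is modeled as a byte slice, and completion is shown synchronously for brevity.

```rust
/// A pending read: the guest has told the host where in linear memory
/// the incoming bytes should land. (Hypothetical type for illustration.)
struct PendingRead {
    offset: u32,
    len: u32,
}

/// Guest side: initiate the operation by handing over an offset and length.
fn submit_read(offset: u32, len: u32) -> PendingRead {
    PendingRead { offset, len }
}

/// Host side: when data arrives, perform the single copy straight into the
/// registered range of linear memory. Returns the number of bytes written,
/// or None if the registered range is out of bounds.
fn complete_read(linear_memory: &mut [u8], op: &PendingRead, incoming: &[u8]) -> Option<usize> {
    let start = op.offset as usize;
    let end = start.checked_add(op.len as usize)?;
    let dest = linear_memory.get_mut(start..end)?;
    // Write at most as many bytes as the guest registered space for.
    let n = dest.len().min(incoming.len());
    dest[..n].copy_from_slice(&incoming[..n]);
    Some(n)
}
```

The only copy in this sketch is the one in `complete_read`, landing directly in the range the guest chose; there is no intermediate host-side buffer handed back to the guest.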
From what I was able to understand, it sounds like what you are describing is very much a streaming use case.
I'm not sure I agree.
For the description of the LLM problem the data is of a fixed, well defined size and shared by multiple processes, and synchronization over main memory can even be considered "slow". We could imagine the data backed by the logit biases being mapped into coherent memory in an accelerator, or shared via IPC with multiple other processes written in other languages.
For the description of plugin provider schemas and large RPC calls, streaming might be more appropriate, but in practice what we saw was that the duplication of pages in memory on both sides of the RPC interface doubled or tripled peak memory usage. In the scenario described, three processes are involved: a provider plugin, the engine, and a language host plugin. If the engine is intermediating, it should not need to materialize a copy of the schema in memory. However, the provider plugin would read all of its `.rodata` section or embedded file, increasing its resident set by that amount, and the language host would, as it pulls data from the io-uring buffer, copy it out of its read stream.
I believe that this will result in memory pressure, though I think the kernel will be fine evicting .rodata pages or equivalent, but it still results in at least one extra copy in memory.
That is, I think streaming IO has overheads that aren't recognized here and while the latencies might be slight, they are non-zero, and it seems weird to contort the system such that the memory must be owned by a singular WASM guest.
For guest to guest memory sharing, it sounds like this naturally requires multiple copies, or in the LLM use case, an entirely different process on the host could own the memory, which could be mapped by the IOMMU into a different device!
For the description of the LLM problem the data is of a fixed, well defined size and shared by multiple processes
Ah, is the data shared read-only, and is the goal here to have the same physical pages mapped into the virtual address space of each process?
For the LLM use case, it is mutable.
To try to understand which part of the whole architecture we're talking about here: for this mutable data, are we talking about the big arrays of logit biases you mentioned above:
Each AICI controller is tasked with, given the result of the previous step's sampling, producing a set of logit biases as an array. These logit biases are of a fixed size for a given language model, and are on the order of 200-500 KiB (between 32k and 100k 32 bit floating point numbers.)
? If so, then it sounded like while, at the low-level, there is memory mutation, at a high-level, what's happening is that data is being passed into the guest, which computes some new data that is then passed out of the guest. I'm also guessing (but let me know if that's not right) that a single wasm instance is reused many times in a row (or even concurrently). Thus, at a high level, it seems like one could think of the wasm as implementing a big function from inputs to output (logit biases).
Given that, the point I was also just making above is that wasm has a hard time working with anything outside of its default linear memory. Thus, we somehow need to get the input data into linear memory for the wasm to be able to access it. Similarly, the output logit arrays will also be written into the same linear memory because that's all the guest can write to directly from regular C/C++/Rust/... code. But let me know though if you've built something with wasm that works differently using multiple memories or references or something else; I can imagine various hypothetical alternative approaches, but so far they all have appeared difficult to get working in practice (e.g., with existing C/C++ code), but I'm always interested to hear if folks have built something that works.
Otherwise, it seems like the basic problem statement is how do we efficiently bulk-move data into and out of wasm's linear memory as part of making this big function call. Now, "linear memory" is just a spec-level concept that doesn't have to be implemented with plain anonymous vmem in the runtime: it is also possible to (and some engines already do) map files or other things into linear memory (e.g., using `mmap(MAP_FIXED)`). But this can just be an implementation detail of what is, at the spec-level, just a "copy" of data into or out of the linear memory. So the question is whether the low-level wasm ABI allows hosts to do this sort of optimization in various concrete scenarios where they want to.
It's hard to discuss this question in the abstract because how precisely mapping works varies greatly by API and OS and more; this is also why it's hard to spec mapping portably at the wasm layer without simply ruling out whole classes of embeddings. But if you have a particular host API that you'd like to discuss in the context of this LLM scenario, I'd be happy to dig into that.
? If so, then it sounded like while, at the low-level, there is memory mutation, at a high-level, what's happening is that data is being passed into the guest, which computes some new data that is then passed out of the guest. I'm also guessing (but let me know if that's not right) that a single wasm instance is reused many times in a row (or even concurrently). Thus, at a high level, it seems like one could think of the wasm as implementing a big function from inputs to output (logit biases).
I think the current (pre-Component) implementation in AICI is helpful to look at:
- Module implementing mmap https://github.com/microsoft/aici/blob/main/aicirt/src/shm.rs
- In the host, prior to calling the imported function `aici_mid_process`, we set up (zero or initialize) memory in `Shm`, which is long-lived: https://github.com/microsoft/aici/blob/main/aicirt/src/moduleinstance.rs#L261-L274
- In the host, an exported method for "returning" logits by directly reading caller memory and writing it into shared logit bias memory shared with the LLM: https://github.com/microsoft/aici/blob/main/aicirt/src/hostimpl.rs#L458-L478
Due to the reasons you've described above, it's not easy today to share the memory; instead, and I believe to further reduce overheads, a bitmask is used. But it would be great, I think, to fully obviate the need for the bitmask, without compromise, by enabling the host to export functions which loan memory to a guest (a la `externref`, etc.) to work on data without copy overhead.
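To show why a bitmask is cheaper to move across the boundary than full logit biases, here is a sketch of a host expanding a mask into a bias array. The encoding is an assumption made for illustration and is not taken from the AICI source: bit `i` set means token `i` is allowed, words are `u32`, and disallowed tokens get a bias of negative infinity. `expand_mask` is a hypothetical name.

```rust
/// Expands a token bitmask into a full bias array.
/// Assumptions for this sketch: bit `i` set means token `i` is allowed;
/// allowed tokens get a bias of 0.0, all others negative infinity.
fn expand_mask(mask: &[u32], vocab_size: usize) -> Vec<f32> {
    (0..vocab_size)
        .map(|i| {
            // Tokens beyond the end of the mask are treated as disallowed.
            let word = mask.get(i / 32).copied().unwrap_or(0);
            if (word >> (i % 32)) & 1 == 1 {
                0.0
            } else {
                f32::NEG_INFINITY
            }
        })
        .collect()
}
```

At one bit per token, the mask is 32 times smaller than one `f32` per token, but it can only express allow or deny, not the full logit maps mentioned above.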
Even better if this works with this shared memory, because I think it would allow the AICI runtime host to be almost entirely "hands off" on the shared memory and the inner loop as it acts as an intermediary between the LLM and the AICI controller guests. This would give the guests the ability to write full logit maps, and even participate in the sampling step of the LLM.
Thanks for all the links, that definitely helps paint a better picture of the system we're talking about.
I didn't read all the code, so I may be missing the bigger picture, but from my brief reading, it looks like the shmem is being mmap'ed into a location outside of wasm's default linear memory (pointed to by `logit_ptr`) and then `aici_host_return_logit_bias` is copying out of wasm's linear memory (via `read_caller_mem`) into `logit_ptr`. Thus, if you were to transition `aici_host_return_logit_bias` to WIT using a `list<u8>`, the host implementation would be passed an `i32` offset into linear memory (just like it is now) and the same copy-into-shmem could happen in the host and there would be no regression.
It is worth asking, of course, how we could further improve upon this situation with even fancier mapping techniques that further reduce copying. But at least, if I'm not missing something, we'd not be regressing things as-is, so we could open this up as a future optimization discussion.