facebookresearch/moolib

Transmission fails for Python class instances with torch.Tensor members of certain shapes.

xiaomengy opened this issue · 1 comment

When moolib transmits a class instance that contains a torch.Tensor member of certain shapes, the call may fail with the error below, whereas transmitting the tensor directly works. This may block GNN use cases, because graphs are usually library-defined classes that hold such tensors. A reproduction script (run on devfair) is listed below. This needs to be fixed.

RuntimeError: Remote exception during RPC call (print_data): ValueError: read() returned non-bytes object (<class 'memoryview'>)
import asyncio
import traceback

from typing import Any

import torch
import moolib


class DataWrapper:
    """Tiny user-defined container that stores an arbitrary payload.

    The repro nests a tensor inside an instance of this class. Per the issue
    report, the wrapped form is what fails to transmit, while a bare tensor
    does not.
    """

    def __init__(self, data: Any) -> None:
        # The wrapped payload (a torch.Tensor in this script).
        self.data = data

    def __str__(self) -> str:
        # Render the payload's own formatting surrounded by square brackets.
        return f"[{self.data}]"

    def __repr__(self) -> str:
        # Use the same text as str() so wrappers print identically inside
        # containers and in debug output.
        return str(self)


async def process(que, callback):
    """Serve requests from an awaitable queue by dispatching them to ``callback``.

    Each ``await que`` yields a ``(ret_cb, args, kwargs)`` triple. The callback
    is invoked with those arguments and its return value is passed to
    ``ret_cb`` (presumably the queue's reply hook that sends the result back
    to the RPC caller).

    Args:
        que: Awaitable that produces one ``(ret_cb, args, kwargs)`` request
            per await (here, the object returned by ``Rpc.define_queue``).
        callback: Handler invoked once per request.

    Raises:
        asyncio.CancelledError: re-raised after logging, so the task is
            reported as cancelled instead of finishing normally.
        Exception: any error from the queue or callback is printed with its
            full traceback and then re-raised, which stops the loop.
    """
    try:
        while True:
            ret_cb, args, kwargs = await que
            # Empty `*()` / `**{}` expand to nothing, so a single call covers
            # the positional-only, keyword-only, both and neither cases that
            # the old if/elif ladder spelled out. `or` also tolerates None.
            ret_cb(callback(*(args or ()), **(kwargs or {})))
    except asyncio.CancelledError:
        print("[Server] process cancelled")
        # Propagate cancellation as recommended by asyncio; swallowing it
        # makes the task look like it completed successfully.
        raise
    except Exception:
        # Keep the full traceback; printing only the message hides where
        # the failure happened.
        traceback.print_exc()
        raise


async def main():
    """Reproduce the failure when sending a tensor wrapped in a Python object.

    Starts a "server" ``moolib.Rpc`` that serves ``print_data`` (which prints
    and returns its argument) and connects a "client" ``moolib.Rpc`` to it on
    ``addr``. It then issues ``num`` calls back to back, each passing a
    ``DataWrapper`` around a 422x95 tensor, and awaits all of them. Per the
    comment in the loop, sending the bare tensor instead works.
    """
    addr = "127.0.0.1:4411"
    timeout = 60

    loop = asyncio.get_running_loop()

    server = moolib.Rpc()
    server.set_name("server")
    server.set_timeout(timeout)

    def print_data(x: Any) -> Any:
        # The return value is handed to the reply callback by process().
        print(f"[print_data] graph = {x}")
        return x

    # Keep a strong reference to the worker task. The event loop holds only
    # weak references to tasks, so an unreferenced task can be
    # garbage-collected before it finishes.
    server_task = loop.create_task(
        process(server.define_queue("print_data"), print_data)
    )

    try:
        server.listen(addr)

        client = moolib.Rpc()
        client.set_name("client")
        client.set_timeout(timeout)
        client.connect(addr)

        x = torch.randn(422, 95)
        x_wrapped = DataWrapper(x)

        num = 20
        futs = []
        for _ in range(num):
            # fut = client.async_("server", "print_data", x)  # This works
            fut = client.async_("server", "print_data", x_wrapped)
            futs.append(fut)
        for fut in futs:
            await fut
    finally:
        # Stop the worker explicitly, whether the calls succeeded or raised.
        # asyncio.run() waits for cancelled tasks when it shuts the loop down.
        server_task.cancel()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception:
        # Print the full traceback (including any remote RPC error), then
        # report failure to the shell. The old bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit and always exited with status 0.
        traceback.print_exc()
        raise SystemExit(1)

Thanks, I've identified the issue and am working on a solution.