TorchRWKV is a pure PyTorch inference framework for the RWKV large language model. It aims to provide a flexible, easily extensible PyTorch implementation of the RWKV x060 model, with support for batch inference, parallel inference, ONNX export, and standalone training.
- Native PyTorch implementation
- Batch inference support
- Parallel inference to fully leverage RWKV advantages
- Clean, readable, and easily extendable codebase
- ONNX format model export and inference
- Simple standalone training
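The difference between token-by-token inference and parallel (chunked) inference can be illustrated with a toy linear recurrence. The sketch below is not the RWKV6 kernel: it assumes a scalar state with a per-token decay `w_t` and a per-token input `kv_t`, and only shows that folding a whole chunk into the state at once gives the same result as stepping one token at a time. The function names are hypothetical.

```python
from math import prod

def recurrent(ws, kvs, s0=0.0):
    """Token-by-token update: s_t = w_t * s_{t-1} + kv_t."""
    s = s0
    for w, kv in zip(ws, kvs):
        s = w * s + kv
    return s

def chunked(ws, kvs, chunk=4, s0=0.0):
    """Same recurrence, but each chunk is folded into the state in one step."""
    s = s0
    for i in range(0, len(ws), chunk):
        w_c, kv_c = ws[i:i + chunk], kvs[i:i + chunk]
        # Each input is decayed by every later decay factor inside the chunk.
        inner = sum(kv * prod(w_c[j + 1:]) for j, kv in enumerate(kv_c))
        # The incoming state is decayed by every decay factor in the chunk.
        s = prod(w_c) * s + inner
    return s

ws = [0.9, 0.5, 0.8, 0.7, 0.95, 0.6, 0.85]   # toy decay factors
kvs = [1.0, -2.0, 0.5, 3.0, 1.5, -1.0, 2.0]  # toy inputs
print(recurrent(ws, kvs), chunked(ws, kvs, chunk=3))
```

Within a chunk the terms of `inner` do not depend on each other, which is what lets a real kernel compute them in parallel during prefill.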
We support various hardware devices, including but not limited to:
- NVIDIA GPUs
- Intel GPUs
- AMD GPUs
- Moore Thread MUSA GPUs
- Huawei Ascend NPUs
Contributions for additional device support are welcome.
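As a rough illustration of how a script might choose among these backends, the sketch below probes PyTorch for an available accelerator and falls back to the CPU. It is an assumption about how one could do this, not TorchRWKV's own device logic. `pick_device` is a hypothetical helper. The `musa` and `npu` namespaces are only present once the vendor extensions (`torch_musa`, `torch_npu`) have been imported, and ROCm builds of PyTorch expose AMD GPUs through the `cuda` namespace.

```python
def pick_device():
    """Return the name of the first available accelerator backend, else 'cpu'."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # AMD GPUs on ROCm builds report through "cuda"; Intel GPUs through "xpu".
    # "musa" and "npu" exist only after importing torch_musa / torch_npu.
    for name in ("cuda", "xpu", "musa", "npu"):
        backend = getattr(torch, name, None)
        try:
            if backend is not None and backend.is_available():
                return name
        except Exception:
            # A backend that fails to initialise is treated as unavailable.
            continue
    return "cpu"

print(pick_device())
```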
- Clone the repository:
  git clone -b dev https://github.com/uniartisan/TorchRWKV.git
- Install dependencies:
  cd TorchRWKV
  pip install -r requirements.txt
- Download the RWKV6 model from BlinkDL/rwkv-6-world and place the weights in the weight folder.
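Before running anything, it can help to confirm that the download landed where the examples expect it. The sketch below lists checkpoint files in the weight folder. It assumes checkpoints use the `.pth` extension (as the state file in the benchmark does), and `list_checkpoints` is a hypothetical helper, not part of TorchRWKV.

```python
from pathlib import Path

def list_checkpoints(folder="weight"):
    """Return the sorted names of .pth files in the folder, or [] if it is missing."""
    root = Path(folder)
    if not root.is_dir():
        return []
    return sorted(p.name for p in root.glob("*.pth"))

found = list_checkpoints()
print(found if found else "No .pth files found in ./weight")
```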
Benchmark (autoregressive decoding with native PyTorch):
import time
import os
import torch
from torchrwkv.rwkv6 import RWKV6
from torchrwkv.model_utils import RWKVConfig
from torchrwkv.sampler import sample_logits
from torchrwkv.rwkv_tokenizer import RWKV_TOKENIZER
config = RWKVConfig(
    model_path='weight/RWKV-x060-World-1B6-v2.1-20240328-ctx4096',
    state_path='weight/rwkv-x060-chn_single_round_qa-1B6-20240516-ctx2048.pth',
    prefill_kernel="triton-chunk",
)
model = RWKV6(config=config)
# Do not wrap the model in torch.compile: JIT is already enabled by default,
# and compiling also reduces model accuracy for reasons not yet understood.
# model = torch.compile(model)
initial_string = """hello"""
batch_size = 128
TEMPERATURE = 1.0
TOP_P = 0.0
LENGTH_PER_TRIAL = 100
state = model.init_state(batch_size)
encoded_input = model.tokenizer.encode([initial_string] * batch_size)
token = torch.tensor(encoded_input).long().to(config.device)
t1 = time.time()
state = None
out, state = model.forward(token, state)
t2 = time.time()
print(f"Time: {t2 - t1}")
start_time = time.time()
for step in range(LENGTH_PER_TRIAL):
    token_sampled = sample_logits(out, TEMPERATURE, TOP_P)
    out, state = model.forward(token_sampled, state)
end_time = time.time()
total_time = end_time - start_time
tokens_generated = LENGTH_PER_TRIAL * batch_size
speed = tokens_generated / total_time
print(f"\nTotal time: {total_time:.2f} seconds")
print(f"Tokens generated: {tokens_generated}")
print(f"Token generation speed: {speed:.2f} tokens/second")
Method | Batch Size | Token Length | Prefill Time (ms) | Token Generation Speed (tokens/second) | Notes |
---|---|---|---|---|---|
triton-chunk | 1 | 1024 | 132.50 | 42.83 | Suitable for inference and state fine-tuning; faster on long sequences |
triton | 1 | 1024 | 105.49 | - | Suitable for inference and training, high accuracy |
torch | 1 | 1024 | 595.22 | - | Suitable for inference on devices where Triton is unavailable |
manual-torch | 1 | 1024 | 2468.00 | - | Suitable for training on devices where Triton is unavailable, high accuracy |
- | 1 | - | - | 48.42 | Excluding prefill |
- | 64 | - | - | 1266.77 | Excluding prefill |
- | 128 | - | - | 1875.03 | Excluding prefill |
Notes:
- "-" indicates data not provided or not applicable.
- For batch size 1, only the triton-chunk row reports a token generation speed, and that figure includes prefill time.
- For the rows without a method name, token generation speeds exclude prefill time.
- Tested on WSL2 with PyTorch 2.5 and an Intel Arc A770, using the 1.6B model.
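The figures in the table can be put in perspective with a little arithmetic. The sketch below uses the same throughput formula as the benchmark script (batch size times steps, divided by elapsed time) and then derives per-sequence speed, aggregate speedup over batch size 1, and prefill throughput from the numbers reported above. The helper name `tokens_per_second` is hypothetical.

```python
def tokens_per_second(batch_size, steps, elapsed_seconds):
    """Aggregate decode throughput: each step emits one token per sequence."""
    return batch_size * steps / elapsed_seconds

# Decode speeds from the table (tokens/second, excluding prefill).
decode_speed = {1: 48.42, 64: 1266.77, 128: 1875.03}

for batch, speed in decode_speed.items():
    per_sequence = speed / batch        # speed seen by one sequence in the batch
    speedup = speed / decode_speed[1]   # aggregate gain over batch size 1
    print(f"batch={batch:>3}  per-sequence={per_sequence:6.2f} tok/s  speedup={speedup:5.2f}x")

# Prefill throughput for the 1024-token prompt, from the prefill times in ms.
prefill_ms = {"triton-chunk": 132.50, "triton": 105.49, "torch": 595.22, "manual-torch": 2468.00}
prefill_speed = {name: 1024 / (ms / 1000) for name, ms in prefill_ms.items()}
for name, speed in prefill_speed.items():
    print(f"{name:<13} {speed:9.1f} prompt tokens/second")
```

Larger batches lower the speed each individual sequence sees while raising total throughput, which is the trade-off batch inference makes.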
For normal use:
# Prompt (Chinese): "User: Hello! What model are you?"
initial_string = """User: 你好! 请问你是什么模型?"""
batch_size = 2
state = None
TEMPERATURE = 1.0
TOP_P = 0.0
LENGTH_PER_TRIAL = 100
# Reuse the tokenizer attached to the model, as in the benchmark above
tokenizer = model.tokenizer
encoded_input = tokenizer.encode([initial_string] * batch_size)
token = torch.tensor(encoded_input).long().to(config.device)
token_all = torch.tensor(encoded_input).long().to(config.device)
for step in range(LENGTH_PER_TRIAL):
    out, state = model.forward(token, state)
    token = sample_logits(out, TEMPERATURE, TOP_P)
    token_all = torch.cat((token_all, token.unsqueeze(1)), 1)
    # Clear the screen and print everything generated so far
    os.system('cls' if os.name == 'nt' else 'clear')
    decoded_sequences = tokenizer.decode(token_all.cpu().tolist())
    for i, seq in enumerate(decoded_sequences):
        print(f"Batch {i+1}: {seq}")
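The loop above relies on `sample_logits` with a temperature and a top-p value. As a hedged illustration of what those two parameters usually mean, here is a minimal pure-Python sketch of temperature plus nucleus (top-p) sampling. It is not the `torchrwkv.sampler` implementation, which may differ; `sample_top_p` is a hypothetical name, and the temperature must be greater than zero.

```python
import math
import random

def sample_top_p(logits, temperature=1.0, top_p=0.0, rng=random):
    """Sample one token id from raw logits using temperature and top-p filtering."""
    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Walk probabilities from largest to smallest until their sum reaches top_p;
    # the last probability visited becomes the cutoff.
    cumulative, cutoff = 0.0, 0.0
    for p in sorted(probs, reverse=True):
        cumulative += p
        cutoff = p
        if cumulative >= top_p:
            break

    # Drop everything below the cutoff, then draw from what is left.
    kept = [p if p >= cutoff else 0.0 for p in probs]
    return rng.choices(range(len(kept)), weights=kept, k=1)[0]

print(sample_top_p([0.1, 2.0, -1.0], temperature=1.0, top_p=0.0))
```

In this sketch, `top_p=0.0` keeps only the most likely token, so sampling becomes greedy decoding regardless of temperature.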
You can also try:
print(model.generate(initial_string, LENGTH_PER_TRIAL, TEMPERATURE, TOP_P, include_prompt=True))
# "你是什么模型?" means "What model are you?"
print(model.chat([{"role": "user", "content": "你是什么模型?"}], 500, TEMPERATURE, TOP_P))
# With stream=True, chat() yields the reply piece by piece ("你好呀" means "Hi there")
for i in model.chat([{"role": "user", "content": "你好呀"}], 500, TEMPERATURE, TOP_P, stream=True):
    print(i, end="", flush=True)
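The `chat` calls take a list of role/content messages, whereas the raw prompt used earlier starts with a `User:` prefix. The sketch below shows one plausible way such messages could be flattened into a single prompt string. The `Assistant:` label, the blank-line separator, and the `build_prompt` helper are assumptions for illustration and are not confirmed by TorchRWKV's source.

```python
def build_prompt(messages):
    """Flatten chat messages into a plain 'User: ... / Assistant: ...' prompt."""
    speaker = {"user": "User", "assistant": "Assistant", "system": "System"}
    turns = [f"{speaker[m['role']]}: {m['content'].strip()}" for m in messages]
    # End with an open Assistant turn so the model continues as the assistant.
    return "\n\n".join(turns) + "\n\nAssistant:"

history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "What model are you?"},
]
print(build_prompt(history))
```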
- Modify the parameters in onnx_export.py for your desired model.
- Run:
  python onnx_export.py
- (Optional) Create a directory for simplified models:
  mkdir ONNX_Simplified
- (Optional) Simplify the model:
  python simplify_large_onnx.py -m onnx/{model name}.onnx -o ONNX_Simplified/{model name}.onnx
- (Optional) Modify the model path in onnx_infer.py and run:
  python onnx_infer.py
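When exporting several models, the optional simplification step can be scripted. The sketch below only builds the command line shown in the steps above from a model name. The name `demo` is a placeholder, `simplify_command` is a hypothetical helper, and actually running the command (for example with `subprocess.run`) is left to the reader.

```python
import shlex
from pathlib import Path

def simplify_command(model_name, src_dir="onnx", dst_dir="ONNX_Simplified"):
    """Build the simplify_large_onnx.py invocation for one exported model."""
    src = Path(src_dir) / f"{model_name}.onnx"
    dst = Path(dst_dir) / f"{model_name}.onnx"
    return ["python", "simplify_large_onnx.py", "-m", src.as_posix(), "-o", dst.as_posix()]

command = simplify_command("demo")  # "demo" is a placeholder model name
print(shlex.join(command))
```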
To start an OpenAI-compatible API server (the --state argument is optional):
python -m torchrwkv.openai_server --model model_path --state state_path --host 0.0.0.0 --port 8848
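Once the server is running, a client can talk to it over HTTP. The sketch below builds a chat request using only the standard library. The `/v1/chat/completions` path and the payload fields follow the usual OpenAI API convention and are assumptions here, since the routes exposed by `torchrwkv.openai_server` are not listed in this document. The `model` value and the `build_chat_request` helper are likewise placeholders.

```python
import json
import urllib.request

def build_chat_request(prompt, host="127.0.0.1", port=8848, model="rwkv"):
    """Build (but do not send) an OpenAI-style chat completion request."""
    payload = {
        "model": model,  # placeholder name; the server may ignore or require it
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 100,
        "stream": False,
    }
    return urllib.request.Request(
        url=f"http://{host}:{port}/v1/chat/completions",  # assumed endpoint path
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

request = build_chat_request("Hello")
print(request.full_url)
# With the server running, the request could be sent like this:
#     with urllib.request.urlopen(request) as response:
#         print(json.load(response))
```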
This framework currently supports only RWKV v6 models, specifically version x060.
We plan to adapt this project to the Orange Pi AI Pro development board from Xunlong, enabling RWKV inference on the Ascend ecosystem.
Special thanks to:
- Yang, S., & Zhang, Y. (2024). FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism (Version 0.0.1) [Computer software]. https://github.com/sustcsonglin/flash-linear-attention
Several of our kernels are modified implementations based on their work.
We welcome contributions from the community. Please feel free to submit PRs and raise Issues. Your input is valuable and helps improve the project for everyone.
Repositories used to optimize the model:
Zhiyuan Li |
Yuunnn_w |
WuTianyi |
Null |
Dejiao Zeng |
We warmly invite everyone to contribute to the project by submitting PRs and raising Issues! Your input and contributions are highly valued and play a vital role in improving the project for the entire community. Let's collaborate and make this project even better together!