GUI 代码 onnx模型输入 不匹配
Closed this issue · 16 comments
从Release获得的onnx模型,从MixTexUI.py获取的GUI代码。
运行发现报错:
'past_key_values.2.value', 'use_cache_branch']) are missing from input feed (['input_ids', 'encoder_hidden_states'])
用0填充输入,发现比打包的exe慢了将近一倍,cpu占用也更高
decoder_outputs = decoder_session.run(None, {
"input_ids": decoder_input_ids,
"encoder_hidden_states": encoder_outputs,
"use_cache_branch": np.array([False]),
'past_key_values.0.key': np.zeros((1, 12, 224, 64), dtype=np.float32),
'past_key_values.0.value': np.zeros((1, 12, 224, 64), dtype=np.float32),
'past_key_values.1.key': np.zeros((1, 12, 224, 64), dtype=np.float32),
'past_key_values.1.value': np.zeros((1, 12, 224, 64), dtype=np.float32),
'past_key_values.2.key': np.zeros((1, 12, 224, 64), dtype=np.float32),
'past_key_values.2.value': np.zeros((1, 12, 224, 64), dtype=np.float32)
})[0]
请问Release里面打包的代码和模型是怎么样的呢?可以分享吗?
感谢你的支持,哥们,我不被允许开源运算加速部分,抱歉了。
好的,Onnx模型推理的方法也是吗,我好像通过模型输入名字猜出来怎么推理了,可以自己写吗?
没问题,只要不是我放出来的就行
谢谢, 还请问一下对放出来的模型有什么许可限制吗,可以进行修改或者量化之类吗
感谢你的贡献!!可以修改,需要署名,不可以商用
decoder_outputs = decoder_session.run(None, { "input_ids": decoder_input_ids, "encoder_hidden_states": encoder_outputs, "use_cache_branch": np.array([False]), 'past_key_values.0.key': np.zeros((1, 12, 224, 64), dtype=np.float32), 'past_key_values.0.value': np.zeros((1, 12, 224, 64), dtype=np.float32), 'past_key_values.1.key': np.zeros((1, 12, 224, 64), dtype=np.float32), 'past_key_values.1.value': np.zeros((1, 12, 224, 64), dtype=np.float32), 'past_key_values.2.key': np.zeros((1, 12, 224, 64), dtype=np.float32), 'past_key_values.2.value': np.zeros((1, 12, 224, 64), dtype=np.float32) })[0]
通过你的这一段代码,main 分支里面的 mixtex_ui.py 总算是跑起来了,但是感觉识别的准确率会略低于打包好的 EXE 文件。
逻辑是这样:
impl MixTexOnnx {
pub fn build() -> Result<Self, Box<dyn std::error::Error>> {
let encoder_builder = Session::builder()?;
let decoder_builder = Session::builder()?;
let encoder_session = encoder_builder
.with_execution_providers(
[
// encoder_cuda.build(),
// dm.build()
]
)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(8)?
.with_inter_threads(8)?
.commit_from_memory(ENCODER_BYTES)?;
let decoder_session = decoder_builder
.with_execution_providers(
[
// decoder_cuda.build(),
// decoder_dm.build(),
// dm.build()
]
)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
// .with_parallel_execution(true)?
.with_intra_threads(12)?
.with_inter_threads(12)?
.commit_from_memory(DECODER_BYTES)?;
Ok(MixTexOnnx {
encoder_session,
decoder_session,
tokenizer: Tokenizer::from_str(TOKENIZER_STR).expect("Fail to load tokenizer"),
})
}
fn init_decode(&self, img: &[f32])-> std::result::Result<(usize,ArrayViewD<f32>,SessionOutputs), Box<dyn std::error::Error>>{
let encoder_result = self.encoder_session.run(ort::inputs! {"pixel_values" => ([1,3,448,448],img)}?)?;
let hidden_state = encoder_result["last_hidden_state"].try_extract_tensor::<f32>()?;
let mut decode_input_ids = array![[0,0,30000_i64]];
let mut result_idx = [0_u32; MAX_LENGTH];
let k_0 = Array::<f32, _>::zeros((1, 12, 0, 64).f()).into_dyn();
let k_1 = Array::<f32, _>::zeros((1, 12, 0, 64).f()).into_dyn();
let k_2 = Array::<f32, _>::zeros((1, 12, 0, 64).f()).into_dyn();
let v_0 = Array::<f32, _>::zeros((1, 12, 0, 64).f()).into_dyn();
let v_1 = Array::<f32, _>::zeros((1, 12, 0, 64).f()).into_dyn();
let v_2 = Array::<f32, _>::zeros((1, 12, 0, 64).f()).into_dyn();
let decoder_result = self.decoder_session.run(ort::inputs! {
"encoder_hidden_states" => hidden_state.view(),
"input_ids"=> decode_input_ids.view(),
"use_cache_branch"=>array![true],
"past_key_values.0.key"=>k_0.view(),
"past_key_values.0.value"=>v_0.view(),
"past_key_values.1.key"=>k_1.view(),
"past_key_values.1.value"=>v_1.view(),
"past_key_values.2.key"=>k_2.view(),
"past_key_values.2.value"=>v_2.view(),
}?)?;
let mut logits = decoder_result["logits"].try_extract_tensor::<f32>()?;
let mut next_token_id = logits.slice(s![0,-1,..])
.iter()
.enumerate()
.max_by(|&(_, x), &(_, y)| {
x.partial_cmp(&y).unwrap()
})
.unwrap()
.0;
Ok((next_token_id,hidden_state,decoder_result))
}
fn decode_once(&self,state:(usize,ArrayViewD<f32>,SessionOutputs))-> std::result::Result<(usize,ArrayViewD<f32>,SessionOutputs), Box<dyn std::error::Error>>{
let (mut next_token_id,mut hidden_state,mut decoder_result) = state;
decoder_result = self.decoder_session.run(ort::inputs! {
"encoder_hidden_states" => hidden_state.view(),
"input_ids"=> array![[next_token_id as i64]],
"use_cache_branch"=>array![true],
"past_key_values.0.key"=>decoder_result["present.0.key"].try_extract_tensor::<f32>()?,
"past_key_values.0.value"=>decoder_result["present.0.value"].try_extract_tensor::<f32>()?,
"past_key_values.1.key"=>decoder_result["present.1.key"].try_extract_tensor::<f32>()?,
"past_key_values.1.value"=>decoder_result["present.1.value"].try_extract_tensor::<f32>()?,
"past_key_values.2.key"=>decoder_result["present.2.key"].try_extract_tensor::<f32>()?,
"past_key_values.2.value"=>decoder_result["present.2.value"].try_extract_tensor::<f32>()?,
}?)?;
// println!("---->loop {i} {:?} ",start_loop.elapsed());
let logits = decoder_result["logits"].try_extract_tensor::<f32>()?;
next_token_id = logits.slice(s![0,-1,..])
.iter()
.enumerate()
.max_by(|&(_, x), &(_, y)| {
x.partial_cmp(&y).unwrap()
})
.unwrap()
.0;
Ok((next_token_id, hidden_state, decoder_result))
}
pub fn inference(&self, img: &[f32]) -> std::result::Result<String, Box<dyn std::error::Error>> {
// eprintln!("Start inference!");
let start = std::time::Instant::now();
let check_rate = MAX_LENGTH / 64;
let mut result_idx = [0_u32; MAX_LENGTH];
let (mut next_token_id,mut hidden_state,mut decoder_result) = self.init_decode(img)?;
result_idx[0] = next_token_id as u32;
for i in 1..MAX_LENGTH {
// let start_loop = std::time::Instant::now();
(next_token_id,hidden_state,decoder_result) = self.decode_once((next_token_id,hidden_state,decoder_result))?;
result_idx[i] = next_token_id as u32;
// stop token 的id,这里硬编码
if next_token_id == 30000 {
break;
}
// decode_input_ids = concatenate![Axis(1),decode_input_ids,array![[next_token_id as i64]]];
if ((i + 1) % check_rate == 0) && check_repeat(&result_idx[..=i]) {
break;
}
}
eprintln!("\x1b[31mTime cost:\x1b[32m{:?}\x1b[0m", start.elapsed());
Ok(self.tokenizer.decode(&result_idx, true).unwrap())
}
}
应该和精度没关系吧,main分支的那个他没做KVCache每次都要Prefill嗯推就会更慢
我所做的只是将 main 分支下 .py 中的 mixtex_inference
改掉了,加上了
"past_key_values.0.key": np.zeros(
(1, 12, 224, 64), dtype=np.float32
),
"past_key_values.0.value": np.zeros(
(1, 12, 224, 64), dtype=np.float32
),
"past_key_values.1.key": np.zeros(
(1, 12, 224, 64), dtype=np.float32
),
"past_key_values.1.value": np.zeros(
(1, 12, 224, 64), dtype=np.float32
),
"past_key_values.2.key": np.zeros(
(1, 12, 224, 64), dtype=np.float32
),
"past_key_values.2.value": np.zeros(
(1, 12, 224, 64), dtype=np.float32
),
推理部分的其它代码我是没改的。相较于未开源的打包结果,抛开性能不谈,修改后的真开源代码的确表现出一些字母被混淆的概率变高,也不知道问题出在哪,这方面我实在是不太懂。
模型没变,问题应该就是代码问题?我觉得他那个结构和命名就是三层Transformer,他有个use_cache_branch是个分支代码,设置true就是另外一边有KVCache的,可能模型变了精度也更高?你试试:
- use_cache_branch 改成 np.array([true])
- 第一次所有past_key_value都是(1, 12, 0, 64)
- 之后所有推理输入的past_key_value 改成输出的present_key_value的值(序号对应)
我测的这样同一个图片快了一倍,精度没测我也没数据集
把它改成这样后不但没有改善,反而还变得更奇怪了,之前只是错一些比较容易混淆的,改了之后识别出来的和图片中的都快没啥关系了😢
# 初始化 past_key_values
past_key_values = [
np.zeros((1, 12, 0, 64), dtype=np.float32) for _ in range(6)
]
for _ in range(max_length):
decoder_outputs = decoder_session.run(
None,
{
"input_ids": decoder_input_ids,
"encoder_hidden_states": encoder_outputs,
"use_cache_branch": np.array([True]),
"past_key_values.0.key": past_key_values[0],
"past_key_values.0.value": past_key_values[1],
"past_key_values.1.key": past_key_values[2],
"past_key_values.1.value": past_key_values[3],
"past_key_values.2.key": past_key_values[4],
"past_key_values.2.value": past_key_values[5],
},
)
# 更新past_key_values
past_key_values = decoder_outputs[1:]
next_token_logits = decoder_outputs[0][:, -1, :]
next_token_id = np.argmax(next_token_logits, axis=-1)
decoder_input_ids = np.concatenate(
[decoder_input_ids, next_token_id[:, None]], axis=-1
)
generated_text += tokenizer.decode(
next_token_id, skip_special_tokens=True
)
self.log(
tokenizer.decode(next_token_id, skip_special_tokens=True), end=""
)
if self.check_repetition(generated_text, 12):
self.log("\n===?!重复重复重复!?===\n")
break
if next_token_id == tokenizer.eos_token_id:
self.log("\n===成功复制到剪切板===\n")
break
return generated_text
要传上一次的推理结果,不是一直是0
encoder_outputs = encoder_session.run(None, {"pixel_values": inputs})[0]
print(f"Encoder Time cost: {time.perf_counter() - start:.6f}s")
# decoder_input_ids = tokenizer("<s>", return_tensors="np").input_ids.astype(np.int64)
decoder_outputs = [
np.array([[0, 0, 30000]], dtype=np.int64),
np.zeros((1, 12, 0, 64), dtype=np.float32),
np.zeros((1, 12, 0, 64), dtype=np.float32),
np.zeros((1, 12, 0, 64), dtype=np.float32),
np.zeros((1, 12, 0, 64), dtype=np.float32),
np.zeros((1, 12, 0, 64), dtype=np.float32),
np.zeros((1, 12, 0, 64), dtype=np.float32)
]
print(decoder_outputs)
for i in range(max_length):
start_loop = time.perf_counter()
decoder_outputs = decoder_session.run([
"logits",
"present.0.key",
"present.0.value",
"present.1.key",
"present.1.value",
"present.2.key",
"present.2.value"
], {
"input_ids": decoder_outputs[0],
"encoder_hidden_states": encoder_outputs,
"use_cache_branch": np.array([true]),
'past_key_values.0.key': decoder_outputs[1],
'past_key_values.0.value': decoder_outputs[2],
'past_key_values.1.key': decoder_outputs[3],
'past_key_values.1.value': decoder_outputs[4],
'past_key_values.2.key': decoder_outputs[5],
'past_key_values.2.value': decoder_outputs[6]
})
# print(f"loop {i} Inference time cost:{time.perf_counter() - start_loop:.6f}s")
# print("decoder_outputs", decoder_outputs.shape)
next_token_id = int(np.argmax(decoder_outputs[0][:, -1, :], axis=-1))
decoder_outputs[0] = np.array([[next_token_id]],dtype=np.int64)
# print("next_token_id", next_token_id.shape)
# print("decoder_input_ids", decoder_input_ids.shape)
generated_text += tokenizer.decode(next_token_id, skip_special_tokens=True)
# print(generated_text)
if check_repeat(generated_text, 12):
break
# tokenizer.eos_token_id =30000
if next_token_id == 30000:
break
# print(f"Decoder loop {i} Time cost: {time.perf_counter() - start_loop:.6f}s")
这是一小部分按你的代码修改后的识别结果,和不使用缓存好像没有太大区别,我这边也没有可靠的数据集来具体测试,只能人工反复截图测试了😢
\begin{align*}
f(x) &= \cos(u t + \phi)
\end{align*}
\begin{align*}
f(x) &= \cos(u t + \phi)
\end{align*}
\begin{align*}
\int_0^1 \chi^2 \, d\chi &= \frac{1}{3} \\
\sum_{n=1}^{\infty} \frac{1}{n^2} &= \frac{n^2}{6} \\
\mathbf{e}^{i\pi} + 1 &= 0 \\
\nabla \cdot \mathbf{E} &= \frac{\rho}{E_0} \\
\bar{f}(\chi) &= \mathbf{C} \mathbf{C} (\omega t + \phi)
\end{align*}
如果要测精度,你可以开两个方式同时复制一个图片测一下?我更关心速度和占用。如果只是公式的话用SimpleTex的接口感觉体验更好,MixTex主要是可以文字公式一把梭更方便
我感觉我得先去补补课(