实现: rust 将 onnx 加载/预测封装成 C 接口, C 内部测试调用.

github 代码地址

为什么这么拧巴? 需要中间套一层 rust 么?

一方面, rust 相比 C 更现代, 实现很多功能更快速, 因此在为一个 C 引擎添加一些功能模块的时候, 使用 rust 写功能模块, 然后 C 通过 api 调用,这样的架构未尝不可;

另一方面, 也是为了学习 rust, 在练习中学习。

1. 生成 onnx 模型

这里我们定义一个简单的模型 ( 假设是最傻的自回归模型, 输入为 8 个字,输出是下 1 个字, 当然这里是用了字对应的 id), 并且导出为 onnx;

pytorch 模型定义

参考 genOnnx.py 代码

import torch
import torch.onnx
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.dense = nn.Linear(8, 1)

    def forward(self, x):
        float_tensor = x.to(dtype=torch.float32)
        float_tensor = self.dense(float_tensor)
        int_tensor = float_tensor.to(dtype=torch.int32)
        return int_tensor

输入为 [batch, 8] 的 int 向量, 输出为 [batch, 1] 的 int 向量

训练与验证

# 创建torch模型
model = SimpleModel()
dummy_input = torch.tensor(inputData, dtype=torch.int32)
output = model(dummy_input)
print("torch Input:", dummy_input)
print("torch Output:", output)

我们跳过训练过程, 直接使用初始化的权重作为训练之后的权重;

假设我们输入是 [[79, 30, 73, 65, 69, 51, 57, 67]], 输出是[[33]].

由于是随机初始化的权重, 不一定能复现。

导出 onnx 模型

# 导出onnx模型
onnx_filename = "simple_model.onnx"
torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, input_names=['input'], output_names=['output'])

# 测试onnx模型
import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession(onnx_filename)
output = ort_session.run(['output'], {'input': np.array(inputData,dtype=np.int32)})
print("onnx Output:", output)

我们将 torch 模型导出之后, 使用 onnxruntime 加载并测试, 发现输出是和 torch 模型输出一致的;

2.rust 包装成 C 接口

参考 rust-wrapper 代码.

rust 调用 onnx 测试

加载模型

pub fn try_load_onnx_model(model_path: &str) -> Result<ort::Session, ort::Error> {
    let model: Session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_intra_threads(1)?
        .with_model_from_file(model_path)?;
    println!("rust inited model.");
    Ok(model)
}

释放模型

pub fn try_free_model(session: Session) {
    drop(session);
    println!("rust freed model.");
}

预测

这里我们为了方便, 没有使用传入的 content 进行 tokenize, 而是直接使用 fake 的 token id;

pub fn try_infer_sentence(session: &Session, content: &str) -> Result<String, ort::Error> {
    println!("rust model input: {}", content);
    let input: Array2<i32> = array![[79, 30, 73, 65, 69, 51, 57, 67]];
    let outputs = session.run(ort::inputs!["input" => input.view()]?)?;
    let output_0: Tensor<i32> = outputs["output"].extract_tensor()?;
    let output_0 = output_0.view();
    let output_0 = output_0.iter().clone().collect::<Vec<_>>();
    let output_0 = output_0.get(0).unwrap();
    let result = format!("{:?}", output_0);
    println!("rust model result: {}", result);
    Ok(result)
}

main 测试

fn main() {
    let session = try_load_onnx_model("../simple_model.onnx").expect("load model error");
    let sentence = "假设这个是测试文本";
    let result: String = try_infer_sentence(&session,sentence).expect("模型预测出错");
    println!("{}",result);
    try_free_model(session);
}

为了能先在 rust 下进行测试, 需要先注释 cargo.toml 中如下部分, 并且将 main.bak.rs 更改为 main.rs.

  #[lib]
  #name = "rust_wrapper"
  #path = "src/lib.rs"
  #crate-type = ["cdylib"]

使用 cargo run 运行 main,可以看到, 这里的预测结果和 python torch 是一致的;

rust 调用包装成 C 接口

我们为了方便, 使用 json 来传递输入输出, 也方便后期对接口的更改; 当然,如果你有更高的性能要求, 请自定义接口。

C 加载模型

传入带模型地址的 json char*, 传出 onnx 模型的 session 指针.

#[no_mangle]
pub unsafe extern "C" fn rust_try_load_onnx_model(ptr: *const c_char) -> *mut Session {
    let c_str = CStr::from_ptr(ptr);
    let rust_str = c_str.to_str().expect("Bad encoding in c_str").to_owned();
    let json_in: JsonValue = serde_json::from_str(&rust_str).unwrap();
    let model_path = json_in["model_path"].as_str().unwrap();
    let session = try_load_onnx_model(model_path).unwrap();
    Box::into_raw(Box::new(session)) // Move ownership to C
}

C 释放模型

#[no_mangle]
pub unsafe extern "C" fn rust_try_free_model(p_session: *mut Session) {
    let session = unsafe {
        assert!(!p_session.is_null());
        *Box::from_raw(p_session) // Move ownership to rust
    };
    try_free_model(session);
}

C 预测

这里输入为 session 地址, 带测试文本的 json char*, 输出为 结果 json char*.

注意, 这里的输出 char*是 rust 申明的空间, 因此外部拿到结果后,应该释放掉该地址;

当然, 更常见的方式是, 输出地址由外部调用申明, 将该地址当作入参传入即可。

#[no_mangle]
pub unsafe extern "C" fn rust_try_infer_sentence(
    p_session: *mut Session,
    ptr: *const c_char,
) -> *const c_char {
    let c_str = CStr::from_ptr(ptr);
    let rust_str = c_str.to_str().expect("Bad encoding in c_str").to_owned();
    let json_in: JsonValue = serde_json::from_str(&rust_str).unwrap();
    let sentence = json_in["sentence"].as_str().unwrap();

    let session = unsafe {
        assert!(!p_session.is_null());
        &*(p_session) // not Move ownership
    };
    let result = try_infer_sentence(session, sentence).expect("模型预测出错");
    let data = json!({"result":result}).to_string();
    let c_string: CString = CString::new(data).expect("CString::new failed");
    c_string.into_raw() // Move ownership to C
}

C 释放 string

#[no_mangle]
pub unsafe extern "C" fn rust_free_string(ptr: *const c_char) {
    let _ = CString::from_raw(ptr as *mut _); // Move ownership to rust
}

3.C 调用 rust

参考 c-call 代码.

C 简单调用 rust api 库, 注意, 需要 rust build 之后将生成的 librust_wrapper.so 拷贝到 lib 中; 由于 rust 中 ort 包实际为 onnxruntime 的 C 封装, 因此也需要 copy 对应的 onnx 库到 libs 中。

C 中调用 rust 库和调用 C 库一样简单, 经过测试发现, 预测符合预期。

typedef struct OnnxModel_S OnnxModel_t;
extern OnnxModel_t *rust_try_load_onnx_model(char *);
extern const char *rust_try_infer_sentence(OnnxModel_t *, char *);
extern void rust_try_free_model(OnnxModel_t *);
extern void rust_free_string(const char *);

int main()
{
    // init-model
    cJSON *root = cJSON_CreateObject();
    cJSON_AddStringToObject(root, "model_path", "../simple_model.onnx");
    OnnxModel_t *onnxModel = rust_try_load_onnx_model(cJSON_PrintUnformatted(root));
    cJSON_Delete(root);


    // infer-one sentence
    cJSON *infer_data = cJSON_CreateObject();
    cJSON_AddStringToObject(infer_data, "sentence", "假设这个是测试文本");
    printf("send to rust:%s\n", cJSON_PrintUnformatted(infer_data));
    const char *rust_result = rust_try_infer_sentence(onnxModel, cJSON_PrintUnformatted(infer_data));
    printf("get from rust:%s\n", rust_result);
    rust_free_string(rust_result);
    cJSON_Delete(infer_data);

    // free-model
    rust_try_free_model(onnxModel);

    return 0;
}