C调用RUST, ONNX预测
实现: rust 将 onnx 加载/预测封装成 C 接口, C 内部测试调用.
为什么这么拧巴? 需要中间套一层 rust 么?
一方面, rust 相比 C 更现代, 实现很多功能更快速, 因此在为一个 C 引擎添加一些功能模块的时候, 使用 rust 写功能模块, 然后 C 通过 api 调用,这样的架构未尝不可;
另一方面, 也是为了学习 rust, 在练习中学习。
1. 生成 onnx 模型
这里我们定义一个简单的模型 ( 假设是最傻的自回归模型, 输入为 8 个字,输出是下 1 个字, 当然这里是用了字对应的 id), 并且导出为 onnx;
pytorch 模型定义
参考 genOnnx.py 代码
import torch
import torch.onnx
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.dense = nn.Linear(8, 1)
def forward(self, x):
float_tensor = x.to(dtype=torch.float32)
float_tensor = self.dense(float_tensor)
int_tensor = float_tensor.to(dtype=torch.int32)
return int_tensor
输入为 [batch, 8] 的 int 向量, 输出为 [batch, 1] 的 int 向量
训练与验证
# 创建torch模型
model = SimpleModel()
dummy_input = torch.tensor(inputData, dtype=torch.int32)
output = model(dummy_input)
print("torch Input:", dummy_input)
print("torch Output:", output)
我们跳过训练过程, 直接使用初始化的权重作为训练之后的权重;
假设我们输入是 [[79, 30, 73, 65, 69, 51, 57, 67]], 输出是[[33]].
由于是随机初始化的权重, 不一定能复现。
导出 onnx 模型
# 导出onnx模型
onnx_filename = "simple_model.onnx"
torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, input_names=['input'], output_names=['output'])
# 测试onnx模型
import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession(onnx_filename)
output = ort_session.run(['output'], {'input': np.array(inputData,dtype=np.int32)})
print("onnx Output:", output)
我们将 torch 模型导出之后, 使用 onnxruntime 加载并测试, 发现输出是和 torch 模型输出一致的;
2.rust 包装成 C 接口
参考 rust-wrapper 代码.
rust 调用 onnx 测试
加载模型
pub fn try_load_onnx_model(model_path: &str) -> Result<ort::Session, ort::Error> {
let model: Session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(1)?
.with_model_from_file(model_path)?;
println!("rust inited model.");
Ok(model)
}
释放模型
pub fn try_free_model(session: Session) {
drop(session);
println!("rust freed model.");
}
预测
这里我们为了方便, 没有使用传入的 content 进行 tokenize, 而是直接使用 fake 的 token id;
pub fn try_infer_sentence(session: &Session, content: &str) -> Result<String, ort::Error> {
println!("rust model input: {}", content);
let input: Array2<i32> = array![[79, 30, 73, 65, 69, 51, 57, 67]];
let outputs = session.run(ort::inputs!["input" => input.view()]?)?;
let output_0: Tensor<i32> = outputs["output"].extract_tensor()?;
let output_0 = output_0.view();
let output_0 = output_0.iter().clone().collect::<Vec<_>>();
let output_0 = output_0.get(0).unwrap();
let result = format!("{:?}", output_0);
println!("rust model result: {}", result);
Ok(result)
}
main 测试
fn main() {
let session = try_load_onnx_model("../simple_model.onnx").expect("load model error");
let sentence = "假设这个是测试文本";
let result: String = try_infer_sentence(&session,sentence).expect("模型预测出错");
println!("{}",result);
try_free_model(session);
}
为了能先在 rust 下进行测试, 需要先注释 cargo.toml 中如下部分, 并且将 main.bak.rs 更改为 main.rs.
#[lib]
#name = "rust_wrapper"
#path = "src/lib.rs"
#crate-type = ["cdylib"]
使用 cargo run 运行 main,可以看到, 这里的预测结果和 python torch 是一致的;
rust 调用包装成 C 接口
我们为了方便, 使用 json 来传递输入输出, 也方便后期对接口的更改; 当然,如果你有更高的性能要求, 请自定义接口。
C 加载模型
传入带模型地址的 json char*, 传出 onnx 模型的 session 指针.
#[no_mangle]
pub unsafe extern "C" fn rust_try_load_onnx_model(ptr: *const c_char) -> *mut Session {
let c_str = CStr::from_ptr(ptr);
let rust_str = c_str.to_str().expect("Bad encoding in c_str").to_owned();
let json_in: JsonValue = serde_json::from_str(&rust_str).unwrap();
let model_path = json_in["model_path"].as_str().unwrap();
let session = try_load_onnx_model(model_path).unwrap();
Box::into_raw(Box::new(session)) // Move ownership to C
}
C 释放模型
#[no_mangle]
pub unsafe extern "C" fn rust_try_free_model(p_session: *mut Session) {
let session = unsafe {
assert!(!p_session.is_null());
*Box::from_raw(p_session) // Move ownership to rust
};
try_free_model(session);
}
C 预测
这里输入为 session 地址, 带测试文本的 json char*, 输出为 结果 json char*.
注意, 这里的输出 char*是 rust 申明的空间, 因此外部拿到结果后,应该释放掉该地址;
当然, 更常见的方式是, 输出地址由外部调用申明, 将该地址当作入参传入即可。
#[no_mangle]
pub unsafe extern "C" fn rust_try_infer_sentence(
p_session: *mut Session,
ptr: *const c_char,
) -> *const c_char {
let c_str = CStr::from_ptr(ptr);
let rust_str = c_str.to_str().expect("Bad encoding in c_str").to_owned();
let json_in: JsonValue = serde_json::from_str(&rust_str).unwrap();
let sentence = json_in["sentence"].as_str().unwrap();
let session = unsafe {
assert!(!p_session.is_null());
&*(p_session) // not Move ownership
};
let result = try_infer_sentence(session, sentence).expect("模型预测出错");
let data = json!({"result":result}).to_string();
let c_string: CString = CString::new(data).expect("CString::new failed");
c_string.into_raw() // Move ownership to C
}
C 释放 string
#[no_mangle]
pub unsafe extern "C" fn rust_free_string(ptr: *const c_char) {
let _ = CString::from_raw(ptr as *mut _); // Move ownership to rust
}
3.C 调用 rust
参考 c-call 代码.
C 简单调用 rust api 库, 注意, 需要 rust build 之后将生成的 librust_wrapper.so 拷贝到 lib 中; 由于 rust 中 ort 包实际为 onnxruntime 的 C 封装, 因此也需要 copy 对应的 onnx 库到 libs 中。
C 中调用 rust 库和调用 C 库一样简单, 经过测试发现, 预测符合预期。
typedef struct OnnxModel_S OnnxModel_t;
extern OnnxModel_t *rust_try_load_onnx_model(char *);
extern const char *rust_try_infer_sentence(OnnxModel_t *, char *);
extern void rust_try_free_model(OnnxModel_t *);
extern void rust_free_string(const char *);
int main()
{
// init-model
cJSON *root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "model_path", "../simple_model.onnx");
OnnxModel_t *onnxModel = rust_try_load_onnx_model(cJSON_PrintUnformatted(root));
cJSON_Delete(root);
// infer-one sentence
cJSON *infer_data = cJSON_CreateObject();
cJSON_AddStringToObject(infer_data, "sentence", "假设这个是测试文本");
printf("send to rust:%s\n", cJSON_PrintUnformatted(infer_data));
const char *rust_result = rust_try_infer_sentence(onnxModel, cJSON_PrintUnformatted(infer_data));
printf("get from rust:%s\n", rust_result);
rust_free_string(rust_result);
cJSON_Delete(infer_data);
// free-model
rust_try_free_model(onnxModel);
return 0;
}