Запуск модели PyTorch в браузере
Инструкция как запустить PyTorch модель в браузере.
ONNX - это библиотека, которая позволяет запускать модели с одного языка на другой. Она позволит запустить нейронную сеть, сделанную на Python, в веб JavaScript. ONNX входит в состав PyTorch.
Для web js, скачайте библиотеки и поместите их в папку с остальными js файлами:
wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort-wasm-simd.wasm
wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js.map
wget https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js
Сохранить модель в ONNX формат можно следующим образом:
tensor_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
onnx_model_path = "web/model.onnx"
input_shape = [1, 32, 32]
data_input = torch.zeros(input_shape).to(torch.float32)
model = model.to(tensor_device)
data_input = data_input.to(tensor_device)
torch.onnx.export(
model,
data_input,
onnx_model_path,
opset_version = 10,
input_names = ['input'],
output_names = ['output'],
verbose=False
)
Очень важные детали:
- opset_version = 10 Генерирует модель с опкодами 10й версии. Если у вас модель не запускается в браузере и выдает ошибку при загрузке, то попробуйте добавить эту строчку.
Подключите скрипт ./ort.min.js через тэг script
JS файл:
let input_shape = [1, 32, 32];
async function load_model()
{
model = await ort.InferenceSession.create('./mnist3.onnx', {
"executionProviders": ["webgl"]
});
return model;
}
async function predict(model, input)
{
input = Float32Array.from(input);
input = new ort.Tensor('float32', input, input_shape);
let res = await model.run({ 'input': input });
let output = res['output'].data;
return output;
}
async function run()
{
let input = [...Array(32 * 32).keys()];
let model = await load_model();
let output = await predict(model, input);
console.log();
}
run();
Обратите внимание на строчку input_shape = [1, 32, 32]. Размерность должна быть такой же как и в python скрипте. И массив input в функцию predict должен передаваться такой же размерностью, только в одну линию из 1024 элементов.
Если у вас размерность вектора 32x32 и вам нужно добавить еще одно измерение, то нужно выполнить следующую комманду:
data_input = data_input[None,:]
Эта команда превращает тензор 32x32 в 1x32x32
Чтобы проверить результат, можно перейти в папку с html файлом и запустить локальный веб сервер, через команду:
php -S 127.0.0.1:8080
И открыть браузер по адресу http://127.0.0.1:8080/
Материалы:
- ONNX Runtime
- ONNX Runtime Quick Start - Web
- Exporting a model from pytorch to onnx
- Run PyTorch models in the browser using ONNX.js