Запуск модели PyTorch в браузере

Инструкция как запустить PyTorch модель в браузере.

Демо версияКод на гитхабе.

ONNX - это библиотека, которая позволяет запускать модели с одного языка на другой. Она позволит запустить нейронную сеть, сделанную на Python, в веб JavaScript. ONNX входит в состав PyTorch.

Для web js, скачайте библиотеки и поместите их в папку с остальными js файлами:

wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort-wasm-simd.wasm
wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js.map
wget https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js

Сохранить модель в ONNX формат можно следующим образом:

tensor_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

onnx_model_path = "web/model.onnx"
input_shape = [1, 32, 32]

data_input = torch.zeros(input_shape).to(torch.float32)

model = model.to(tensor_device)
data_input = data_input.to(tensor_device)

torch.onnx.export(
	model,
	data_input,
	onnx_model_path,
	opset_version = 10,
	input_names = ['input'],
	output_names = ['output'],
	verbose=False
)

Очень важные детали:

  • opset_version = 10 Генерирует модель с опкодами 10й версии. Если у вас модель не запускается в браузере и выдает ошибку при загрузке, то попробуйте добавить эту строчку.

Подключите скрипт ./ort.min.js через тэг script

JS файл:

let input_shape = [1, 32, 32];

async function load_model()
{
	model = await ort.InferenceSession.create('./mnist3.onnx', {
		"executionProviders": ["webgl"]
	});
	return model;
}

async function predict(model, input)
{
	input = Float32Array.from(input);
	input = new ort.Tensor('float32', input, input_shape);
	let res = await model.run({ 'input': input });
	let output = res['output'].data;
	return output;
}

async function run()
{
	let input = [...Array(32 * 32).keys()];
	let model = await load_model();
	let output = await predict(model, input);
	console.log();
}

run();

Обратите внимание на строчку input_shape = [1, 32, 32]. Размерность должна быть такой же как и в python скрипте. И массив input в функцию predict должен передаваться такой же размерностью, только в одну линию из 1024 элементов.

Если у вас размерность вектора 32x32 и вам нужно добавить еще одно измерение, то нужно выполнить следующую комманду:

data_input = data_input[None,:]

Эта команда превращает тензор 32x32 в 1x32x32

Чтобы проверить результат, можно  перейти в папку с html файлом и запустить локальный веб сервер, через команду:

php -S 127.0.0.1:8080

И открыть браузер по адресу http://127.0.0.1:8080/

Материалы:

  1. ONNX Runtime
  2. ONNX Runtime Quick Start - Web
  3. Exporting a model from pytorch to onnx
  4. Run PyTorch models in the browser using ONNX.js