# Функция обучения PyTorch
# Функция обучения train
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_acc_class(model_res, batch_y):
    """Count the correct predictions of a classifier in one batch.

    Despite the name, this returns the *number* of correctly classified
    samples (a Python ``int``), not a ratio; divide the accumulated count by
    the number of samples to obtain accuracy.

    Args:
        model_res: model outputs with one row of class scores per sample;
            the predicted class is ``argmax`` along dim 1.
        batch_y: ground-truth targets, either rows laid out like
            ``model_res`` (e.g. one-hot; reduced with ``argmax`` along dim 1)
            or a 1-D tensor of class indices.

    Returns:
        int: how many samples have predicted class equal to the target class.
    """
    predicted = torch.argmax(model_res, dim=1)
    # 1-D targets are already class indices; anything else is reduced to
    # indices the same way as the model output.
    target = batch_y if batch_y.dim() == 1 else torch.argmax(batch_y, dim=1)
    return torch.sum(torch.eq(predicted, target)).item()
def _split_dataset(dataset, val_split):
    """Randomly split ``dataset`` into ``(train_subset, val_subset)``.

    The validation subset gets ``round(len(dataset) * val_split)`` samples and
    the training subset gets the remainder, so the two lengths always sum to
    ``len(dataset)`` as ``random_split`` requires (rounding both parts
    independently can over- or under-shoot by one and raise ValueError).
    """
    total = len(dataset)
    val_len = round(total * val_split)
    return torch.utils.data.random_split(dataset, [total - val_len, val_len])


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is 0."""
    return numerator / denominator if denominator else float("nan")


def _run_epoch(model, loader, loss_fn, optimizer=None, on_batch=None):
    """Make one pass over ``loader`` and return ``(loss_sum, correct, seen)``.

    With an ``optimizer`` the model is switched to training mode and updated
    after every batch; without one it is switched to eval mode and the pass
    runs with gradient tracking disabled, so no autograd graph is kept.

    ``loss_sum`` is each batch's loss multiplied by its batch size,
    ``correct`` is the summed ``get_acc_class`` count and ``seen`` is the
    number of samples processed. ``on_batch(batch_len)`` is called after
    every batch (used for progress reporting).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            # Model output and its error against the ground truth.
            model_res = model(batch_x)
            loss_value = loss_fn(model_res, batch_y)
            batch_len = len(batch_x)
            # Weight by batch size so a short final batch is not over-counted
            # (assumes loss_fn averages over the batch, as nn.MSELoss does by default).
            loss_sum += loss_value.item() * batch_len
            correct += get_acc_class(model_res, batch_y)
            seen += batch_len
            if training:
                optimizer.zero_grad()
                loss_value.backward()
                optimizer.step()
            if on_batch is not None:
                on_batch(batch_len)
            del batch_x, batch_y
    return loss_sum, correct, seen


def train(dataset, model, epochs=5, debug=True, *, batch_size=64, lr=1e-3, val_split=0.1):
    """Train a PyTorch model with Adam and MSE loss, validating after every epoch.

    ``dataset`` is randomly split into a training part and a validation part
    holding ``val_split`` of the samples. Each epoch runs one training pass
    (shuffled) followed by one validation pass (not shuffled).

    Args:
        dataset: map-style dataset yielding ``(x, y)`` pairs; ``y`` must work
            with both ``nn.MSELoss`` and ``get_acc_class`` (presumably one-hot
            targets — confirm against callers).
        model: the ``nn.Module`` to train; it is moved to ``device``.
        epochs: number of passes over the training split.
        debug: if true, print per-batch progress and per-epoch accuracy/loss.
        batch_size: mini-batch size for both loaders.
        lr: Adam learning rate.
        val_split: fraction of ``dataset`` held out for validation.

    Returns:
        None; the model's parameters are updated in place.
    """
    # Move the model before building the optimizer, as torch.optim recommends.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Split the dataset into training and validation parts.
    train_dataset, val_dataset = _split_dataset(dataset, val_split)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        drop_last=False,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        drop_last=False,
        shuffle=False,
    )
    samples_per_epoch = len(train_dataset) + len(val_dataset)

    for step_index in range(epochs):
        processed = 0

        def report_progress(batch_len):
            # One progress indicator covering both passes of this epoch,
            # so it ends at exactly 100%.
            nonlocal processed
            processed += batch_len
            if debug:
                percent = round(processed / samples_per_epoch * 100)
                print(f"\rStep {step_index+1}, {percent}%", end='')

        loss_train, correct_train, count_train = _run_epoch(
            model, train_loader, loss_fn, optimizer, report_progress
        )
        # Release cached CUDA memory before the validation pass.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        loss_val, correct_val, count_val = _run_epoch(
            model, val_loader, loss_fn, None, report_progress
        )

        # Per-epoch summary; ratios print as "nan" if a split is empty.
        if debug:
            print("\r", end='')
            print(
                f"Step {step_index+1}, "
                f"acc: {_safe_ratio(correct_train, count_train):.4f}, "
                f"acc_val: {_safe_ratio(correct_val, count_val):.4f}, "
                f"loss: {_safe_ratio(loss_train, count_train):.3e}, "
                f"loss_val: {_safe_ratio(loss_val, count_val):.3e}"
            )