Функция обучения PyTorch

Функция обучения train

# Compute device for the whole module: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_acc_class(model_res, batch_y):
	"""
	Count the correctly classified samples in one batch (classifier accuracy numerator).

	Predictions and targets are both reduced with ``argmax`` over dim 1, so
	``batch_y`` must carry a class axis at dim 1 like ``model_res`` does
	(presumably one-hot targets — confirm against the dataset).

	Returns the number of rows whose predicted and target classes agree, as a
	Python int. This is a count, not a ratio: divide by the number of samples
	to obtain an accuracy value.
	"""
	predicted_class = model_res.argmax(dim=1)
	target_class = batch_y.argmax(dim=1)
	matches = predicted_class == target_class
	return matches.sum().item()


def _safe_div(numerator, denominator):
	"""Return ``numerator / denominator``, or NaN when ``denominator`` is zero."""
	return numerator / denominator if denominator else float("nan")


def _run_epoch(model, loader, loss_fn, acc_fn, optimizer=None, on_batch=None):
	"""
	Run one pass over ``loader`` and return ``(loss_sum, correct, seen)``.

	If ``optimizer`` is given, the pass trains: the model is put in train mode
	and every batch does zero_grad / backward / step. Otherwise the pass only
	evaluates, in eval mode and under ``torch.no_grad()``, so no autograd graph
	is built.

	``loss_sum`` accumulates each batch's loss weighted by the batch size, so
	``loss_sum / seen`` is a per-sample mean that does not depend on the number
	of batches. ``correct`` is the sum of ``acc_fn`` over all batches.
	``on_batch(n)``, if provided, is called after every batch with its size
	(used for progress reporting).
	"""
	training = optimizer is not None
	model.train(training)

	loss_sum = 0.0
	correct = 0
	seen = 0

	grad_mode = torch.enable_grad() if training else torch.no_grad()
	with grad_mode:
		for batch_x, batch_y in loader:
			batch_x = batch_x.to(device)
			batch_y = batch_y.to(device)

			# Model output and its error against the reference answers
			model_res = model(batch_x)
			loss_value = loss_fn(model_res, batch_y)

			if training:
				optimizer.zero_grad()
				loss_value.backward()
				optimizer.step()

			batch_len = len(batch_x)
			loss_sum += loss_value.item() * batch_len
			correct += acc_fn(model_res, batch_y)
			seen += batch_len

			if on_batch is not None:
				on_batch(batch_len)

	return loss_sum, correct, seen


def train(dataset, model, epochs=5, debug=True, batch_size=64, lr=1e-3, val_split=0.1, acc_fn=None):
	"""
	Train a PyTorch model on ``dataset`` with a random hold-out validation split.

	The dataset is split at random into a training part and a validation part.
	The validation part holds ``round(len(dataset) * val_split)`` samples and
	the training part holds the rest. Each epoch the model is optimized with
	Adam on ``nn.MSELoss`` over the training part, then evaluated on the
	validation part without gradients.

	Args:
		dataset: dataset yielding ``(x, y)`` pairs. ``y`` must be compatible with
			the model output under ``nn.MSELoss``.
		model: the module to train; it is moved to the module-level ``device``.
		epochs: number of passes over the training part.
		debug: print per-batch progress and a per-epoch summary.
		batch_size: mini-batch size for both loaders.
		lr: Adam learning rate.
		val_split: fraction of samples held out for validation, in [0, 1).
		acc_fn: ``acc_fn(model_res, batch_y)`` returning the number of correct
			samples in a batch. Defaults to ``get_acc_class``.

	Returns:
		A list with one dict per epoch. Each dict holds the per-sample mean
		losses ``loss`` / ``loss_val`` and the accuracy ratios ``acc`` /
		``acc_val``. A value is NaN when its split is empty.

	Raises:
		ValueError: if ``val_split`` is outside [0, 1).
	"""
	if not 0 <= val_split < 1:
		raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")
	if acc_fn is None:
		acc_fn = get_acc_class

	# Move the model first so the optimizer is built over the on-device parameters.
	model = model.to(device)
	optimizer = torch.optim.Adam(model.parameters(), lr=lr)
	loss_fn = nn.MSELoss()

	# Split into train/validation parts. The train size is the remainder, so
	# the lengths always sum to len(dataset). Rounding both parts separately
	# can miss by one, which makes random_split raise.
	total = len(dataset)
	val_count = round(total * val_split)
	train_count = total - val_count
	train_dataset, val_dataset = torch.utils.data.random_split(
		dataset, [train_count, val_count]
	)

	train_loader = DataLoader(
		train_dataset,
		batch_size=batch_size,
		drop_last=False,
		shuffle=True
	)
	val_loader = DataLoader(
		val_dataset,
		batch_size=batch_size,
		drop_last=False,
		shuffle=False
	)

	history = []
	for step_index in range(epochs):
		step = step_index + 1
		processed = 0

		def on_batch(batch_len, step=step):
			# Progress spans train + validation, so it ends at exactly 100%.
			nonlocal processed
			processed += batch_len
			if debug:
				print(f"\rStep {step}, {round(processed / total * 100)}%", end='')

		train_loss, train_correct, train_seen = _run_epoch(
			model, train_loader, loss_fn, acc_fn, optimizer, on_batch
		)
		val_loss, val_correct, val_seen = _run_epoch(
			model, val_loader, loss_fn, acc_fn, None, on_batch
		)

		metrics = {
			"loss": _safe_div(train_loss, train_seen),
			"loss_val": _safe_div(val_loss, val_seen),
			"acc": _safe_div(train_correct, train_seen),
			"acc_val": _safe_div(val_correct, val_seen),
		}
		history.append(metrics)

		# Release cached CUDA blocks once per epoch. Emptying the cache after
		# every batch forces costly re-allocation on the next step.
		if torch.cuda.is_available():
			torch.cuda.empty_cache()

		if debug:
			print("\r", end='')
			print(
				f"Step {step}, "
				f"acc: {metrics['acc']:.4f}, acc_val: {metrics['acc_val']:.4f}, "
				f"loss: {metrics['loss']:.3e}, loss_val: {metrics['loss_val']:.3e}"
			)

	return history