This commit is contained in:
2026-06-11 11:06:59 +08:00
parent 4574798d86
commit 0570d41439
4 changed files with 19 additions and 19 deletions

View File

@@ -318,11 +318,7 @@ class ExP():
train_pred = torch.max(outputs, 1)[1]
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
algo_log('Epoch:', e,
' Train loss: %.6f' % loss.detach().cpu().numpy(),
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
' Train accuracy %.6f' % train_acc,
' Test accuracy is %.6f' % acc, level="debug")
algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
self.log_write.write(str(e) + " " + str(acc) + "\n")
num = num + 1
@@ -335,8 +331,8 @@ class ExP():
torch.save(self.model, model_path)
averAcc = averAcc / num
algo_log('The average accuracy is:', averAcc, level="debug")
algo_log('The best accuracy is:', bestAcc, level="debug")
algo_log(f"The average accuracy is: {averAcc}", level="debug")
algo_log(f"The best accuracy is: {bestAcc}", level="debug")
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
@@ -366,12 +362,13 @@ def onlineTrain(data_queue,result_queue):
data = data_queue.get(timeout=30)
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
exp = ExP(n_chan)
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug")
algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
endtime = datetime.datetime.now()
algo_log('train duration: ',str(endtime - starttime), level="debug")
algo_log(f"train duration: {endtime - starttime}", level="debug")
# 将模型或参数传回
result_queue.put({
@@ -387,7 +384,7 @@ def offlineTrain(all_data,all_label,modelPath):
# seed_n = np.random.randint(2025)
seed_n = 1877
algo_log('seed is ' + str(seed_n), level="debug")
algo_log(f"seed is {seed_n}", level="debug")
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
@@ -400,7 +397,7 @@ def offlineTrain(all_data,all_label,modelPath):
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
endtime = datetime.datetime.now()
algo_log('train duration: ',str(endtime - starttime), level="debug")
algo_log(f"train duration: {endtime - starttime}", level="debug")