fix 提取数据不成功

This commit is contained in:
2026-06-14 10:25:56 +08:00
parent c27e250fad
commit 7f7760c1b6
4 changed files with 69 additions and 50 deletions

View File

@@ -276,22 +276,31 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
if self.zmqServer.pack_contain_event:
with self.zmqServer.paradigmBufferLock:
self.zmqServer.paradigmBuffer.resetAllPara()
self.zmqServer.pack_contain_event = False
if self.zmqServer.epoch_finished:
data_length = self.zmqServer.paradigmBuffer.GetDataLenCount()
if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else:
algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG")
self.zmqServer.epoch_finished = False
else:
time.sleep(0.0001)
time.sleep(0.001)
return
elif self.zmqServer.state_mode == 'predict': # 测试状态
@@ -395,7 +404,7 @@ class Decoder_main(threading.Thread):
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
else:
time.sleep(0.0001)
time.sleep(0.001)
return
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
@@ -408,7 +417,7 @@ class Decoder_main(threading.Thread):
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx:
time.sleep(0.0001)
time.sleep(0.001)
return
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")