365 lines
14 KiB
Python
365 lines
14 KiB
Python
import time
|
||
|
||
from psychopy import visual, core, logging # import some libraries from PsychoPy
|
||
import random
|
||
from datetime import datetime
|
||
|
||
# Lab Streaming Layer (LSL)
|
||
from pylsl import StreamInfo, StreamOutlet
|
||
from psychopy import event
|
||
import numpy as np
|
||
from DecoderDW.Server import TCPServer
|
||
from DecoderDW.Client import TCPClient
|
||
# import subprocess
|
||
|
||
# ----------------------
# Display constants
WINWIDTH, WINHEIGHT = 1920, 1080  # window size (width, height) passed to visual.Window
REFRESH_RATE = 144  # screen refresh rate in Hz; stimulus stages count frames at this rate
|
||
|
||
|
||
|
||
def get_keypress():
    """Return the first key name reported by PsychoPy's ``event.getKeys()``.

    Returns:
        The first entry of the key list, or ``None`` when no key was reported.
    """
    pressed = event.getKeys()
    return pressed[0] if pressed else None
|
||
|
||
|
||
def shutdown(win, client):
    """Notify the decoder, close the window and quit PsychoPy.

    Sends ``('saveData', 0)`` and then ``('predict', 2)`` to the decoder (the
    main script uses ``'saveData', 1`` to enable saving and ``'predict', 2``
    to shut the system down), then closes ``win`` and calls ``core.quit()``.

    Args:
        win: PsychoPy window to close.
        client: connected decoder client exposing ``send_data(command, value)``.
    """
    for command, value in (('saveData', 0), ('predict', 2)):
        client.send_data(command, value)
    win.close()
    core.quit()
|
||
|
||
|
||
# end of configuration
|
||
# ----------------------
|
||
|
||
def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
    """Generate a frame-by-frame on/off square-wave sequence.

    Element ``k`` describes frame ``k``, shown at time ``k / sampling_rate``.
    A frame is 1 ("on") during the first half of every stimulus period and 0
    ("off") during the second half, i.e. it is 1 exactly when the fractional
    part of ``frequency * k / sampling_rate`` is below 0.5.

    Args:
        frequency (float): stimulus frequency in Hz.
        sampling_rate (int): samples per second (Hz); should equal the screen
            refresh rate so that one element corresponds to one frame.
        duration (float): length of the sequence in seconds.

    Returns:
        list[int]: ``int(duration * sampling_rate)`` values, each 0 or 1.
    """
    n_points = int(duration * sampling_rate)

    # Frame k is displayed at k / sampling_rate seconds. Using the frame index
    # directly (rather than linspace over `duration`) keeps the spacing at one
    # frame even when duration * sampling_rate is not a whole number.
    frame_idx = np.arange(n_points)

    # Phase of each frame in cycles, reduced to [0, 1). Thresholding the phase
    # instead of np.sin() avoids rounding noise at exact half/full periods
    # (np.sin(pi) > 0 but np.sin(2*pi) < 0), which used to flip those frames
    # between on and off arbitrarily.
    phase = np.mod(frequency * frame_idx / sampling_rate, 1.0)

    square_wave = (phase < 0.5).astype(int)
    return square_wave.tolist()
|
||
|
||
|
||
# Start a process without waiting for it to finish (note: no process is launched anywhere in the current script).
|
||
import os
|
||
if __name__ == "__main__":
    import json

    # ------------------------------------------------------------------
    # Main window settings
    main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False,
                             gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
    print('starting 1')

    # Lab Streaming Layer marker stream: 'S 1' is pushed at stimulus onset and
    # 'S 2' at feedback onset of every trial.
    info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string',
                      source_id='psychopy_stimuli_001')
    outlet = StreamOutlet(info)  # Broadcast the stream.

    # Stimulus layout: target i is drawn at stim_positions[i] and flickers at
    # frequencies[i]; its feedback marker is drawn feedback_offset_y px below it.
    stim_positions = [(-600, 300), (0, 300), (600, 300),
                      (-600, -200), (0, -200), (600, -200)]
    feedback_offset_y = 270

    image_stims = [
        visual.ImageStim(main_win, size=(300, 300), pos=pos, units='pix', image='UI/figures/xy.jpg')
        for pos in stim_positions
    ]
    # Red variants, drawn in place of the target during the cue stage.
    red_image_stims = [
        visual.ImageStim(main_win, size=(300, 300), pos=pos, units='pix', image='UI/figures/xy_red.jpg')
        for pos in stim_positions
    ]
    txt_stims = [
        visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix',
                        bold=True, italic=False, pos=(x, y - feedback_offset_y))
        for x, y in stim_positions
    ]

    # Alternative frequency sets used previously: [9..14] and [30..35].
    frequencies = [25, 26, 27, 28, 29, 30]
    if len(frequencies) != len(stim_positions):
        raise ValueError('frequencies must have one entry per stimulus position')

    # One 5-second on/off sequence per target, sampled at the refresh rate.
    # The 6-second stimulus stage wraps around it (frame % len(wave)); the wrap
    # is phase-continuous only when frequency * 5 is a whole number of cycles,
    # which holds for the integer frequencies above.
    square_waves = [generate_square_wave(freq, REFRESH_RATE, 5) for freq in frequencies]

    time.sleep(2)

    # Start the server whose ChoosenNum attribute carries the decoding result,
    # then connect to the decoder.
    server = TCPServer()
    server.start()
    client = TCPClient('127.0.0.1', 8099)
    client.connect()
    print('Connected decoder_main')

    # Optional impedance check (disabled):
    # client.send_data('impedance', 1)
    # time.sleep(20)
    # client.send_data('impedance', 2)

    # Send `frequencies` so the decoder's targets always match the stimuli.
    client.send_data('targetFreqs', frequencies)
    time.sleep(1)

    # Enable saving of the whole session's data to EEGFiles.
    client.send_data('saveData', 1)

    # Experiment parameters: every frequency is cued `repeats` times, in random order.
    repeats = 3
    seq_freq = np.random.permutation(frequencies * repeats).tolist()
    num_trials = len(seq_freq)  # len(frequencies) * repeats, i.e. 6 * 3 = 18 trials

    online_results = []  # one result dict per trial
    correct_predictions = 0

    # Save the trial sequence. Use 'w' rather than 'a': appending a second JSON
    # document to an existing file would make it unparsable. Create EEGFiles
    # if it does not exist yet.
    seq_info = {
        'total_trials': num_trials,
        'frequencies': frequencies,
        'sequence': seq_freq,
        'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }
    os.makedirs('EEGFiles', exist_ok=True)
    seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
    with open(seq_file_path, 'w', encoding='utf-8') as f:
        json.dump(seq_info, f, ensure_ascii=False, indent=2)

    def _quit_if_requested():
        """Shut the experiment down if the 'q' key has been pressed."""
        if get_keypress() in ['q']:
            shutdown(main_win, client)

    # ======================== Trials started ======================== #
    for trial_count, target_freq in enumerate(seq_freq):
        target_freq_index = frequencies.index(target_freq)
        trial_no = trial_count + 1  # 1-based trial number used in logs and trial_info
        print(f'Trials {trial_no}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')

        # Stage 1: cue (1 s) - the target is shown in red, nothing flickers.
        client.send_data('setLabelAndTrialInfo', {
            'label': 0,
            'trial_info': {'trial': trial_no, 'phase': 'cue', 'target_freq': target_freq},
        })
        for _frame in range(int(1 * REFRESH_RATE)):
            _quit_if_requested()
            for i, stim in enumerate(image_stims):
                (red_image_stims[i] if i == target_freq_index else stim).draw()
            main_win.flip()

        # Stage 2: all targets flicker (up to 6 s) until the decoder reports a
        # result. Clear any prediction that arrived before this stage (e.g. a
        # late one from the previous trial) so it cannot end the stage early.
        server.ChoosenNum = -1
        client.send_data('predict', 1)
        client.send_data('setLabelAndTrialInfo', {
            'label': target_freq_index + 1,  # +1 because label 0 means "do not record data"
            'trial_info': {'trial': trial_no, 'phase': 'stimulus', 'target_freq': target_freq},
        })
        outlet.push_sample(['S 1'])
        for frame in range(6 * REFRESH_RATE):
            _quit_if_requested()
            for stim, wave in zip(image_stims, square_waves):
                if wave[frame % len(wave)] == 1:
                    stim.draw()
            main_win.flip()
            if server.ChoosenNum != -1:
                break

        # Read the decoding result exactly once: the server may overwrite
        # ChoosenNum at any time, so the recorded and the displayed result must
        # come from the same snapshot. int() also keeps the value JSON-serialisable
        # in case the server stores a NumPy integer.
        predicted_freq_index = int(server.ChoosenNum)
        # Only the index of an on-screen target counts as a decode. -1 (no result)
        # and out-of-range values are failures; a negative index such as -2 would
        # otherwise silently select a wrong frequency.
        decoded = 0 <= predicted_freq_index < len(frequencies)
        predicted_freq = frequencies[predicted_freq_index] if decoded else -1
        is_correct = decoded and predicted_freq_index == target_freq_index
        if is_correct:
            correct_predictions += 1

        online_results.append({
            'trial': trial_no,
            'target_freq': target_freq,
            'target_freq_index': target_freq_index,
            'predicted_freq': predicted_freq,
            'predicted_freq_index': predicted_freq_index,
            'is_correct': is_correct,
            'status': 'Success' if decoded else 'Failed',
        })

        status_symbol = "✓" if is_correct else "✗"
        if decoded:
            print(f'Trial {trial_no}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')
        else:
            print(f'Trial {trial_no}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')

        # Stage 3: feedback (1 s) - static targets plus a marker under the decoded one.
        outlet.push_sample(['S 2'])
        client.send_data('setLabelAndTrialInfo', {
            'label': 0,  # the feedback stage is labelled 0
            'trial_info': {'trial': trial_no, 'phase': 'feedback', 'target_freq': target_freq},
        })
        for _frame in range(1 * REFRESH_RATE):
            _quit_if_requested()
            for stim in image_stims:
                stim.draw()
            if decoded:
                txt_stims[predicted_freq_index].draw()
            main_win.flip()

        server.ChoosenNum = -1

    # ------------------------------------------------------------------
    # Overall online decoding accuracy
    total_trials = len(online_results)
    successful_trials = sum(r['status'] == 'Success' for r in online_results)
    failed_trials = sum(r['status'] == 'Failed' for r in online_results)
    overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0

    print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")

    # Accuracy per target frequency
    print("\n=== 按频率分析准确率 ===")
    freq_accuracy = {}
    for result in online_results:
        stats = freq_accuracy.setdefault(result['target_freq'], {'correct': 0, 'total': 0, 'failed': 0})
        stats['total'] += 1
        if result['status'] == 'Failed':
            stats['failed'] += 1
        elif result['is_correct']:
            stats['correct'] += 1

    print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
    print("-" * 40)
    for freq in sorted(freq_accuracy):
        stats = freq_accuracy[freq]
        accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
        print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")

    # Save the online decoding results.
    online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
    online_summary = {
        'total_trials': total_trials,
        'successful_trials': successful_trials,
        'failed_trials': failed_trials,
        'correct_predictions': correct_predictions,
        'overall_accuracy': overall_accuracy,
        # 'freq_accuracy': freq_accuracy,
        'trial_results': online_results,
        # 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    with open(online_results_file, 'w', encoding='utf-8') as f:
        json.dump(online_summary, f, ensure_ascii=False, indent=2)

    # Normal end of session: send the same commands as shutdown() (disable data
    # saving first, then 'predict', 2 to shut the decoder system down), but
    # without core.quit() - only the window is closed here.
    client.send_data('saveData', 0)
    client.send_data('predict', 2)
    main_win.close()
|