diff --git a/Decoder.py b/Decoder.py index 9b79d83..c07dce6 100644 --- a/Decoder.py +++ b/Decoder.py @@ -21,8 +21,8 @@ from SSVEP.dwfbcca import FbccaDw from Tools.plot_MI_EEG import plotMain from collections import deque -class Decoder_main(threading.Thread): - def __init__(self): +class Decoder_main(threading.Thread, device_type): + def __init__(self, device_type=None): threading.Thread.__init__(self) self.Runing=True self.decoder = None @@ -33,6 +33,11 @@ class Decoder_main(threading.Thread): self.decoder_class = None #解码器类别 self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 + self.device_info = { + 'device_type': None, + 'sample_rate': None, + 'channel_num': None, + } def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type')) @@ -113,40 +118,40 @@ class Decoder_main(threading.Thread): self.parameter_init(8, 30) - elif decoder_class == 'concentration': - self.thread_data_server.interval_inited = False - self.n_chan = 6 - self.win_len = 10 - self.win_step = 1 - self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue')) - self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len) - self.interval_epoch = [0, 1] - self.parameter_init(2, 40) - # self.eegQueue moved to Calculate class + # elif decoder_class == 'concentration': + # self.thread_data_server.interval_inited = False + # self.n_chan = 6 + # self.win_len = 10 + # self.win_step = 1 + # self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue')) + # self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len) + # self.interval_epoch = [0, 1] + # self.parameter_init(2, 40) + # # self.eegQueue moved to Calculate class - elif decoder_class == 'blink': - self.n_chan = 2 - self.l_freq = 0.1 # 带通滤波器低频截止 - self.h_freq = 8.0 # 带通滤波器高频截止 - self.total_samples = 0 # 总采样点数 - self.window_ms = 600 # 检测窗口大小 (ms) - self.step_ms = 100 # 滑动步长 (ms) - self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点 - self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点 - self.buffer_size = self.window_samples + self.step_samples * 5 - self.fp1_buffer = deque(maxlen=self.buffer_size) - self.fp2_buffer = deque(maxlen=self.buffer_size) - self.sample_counter = 0 - # 预计算滤波器系数,避免在循环中重复设计 - self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink')) - self.blink_count = 0 # 单次眨眼的次数 - self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引) - self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳 - self.double_blink_count = 0 # 连续两次眨眼的次数 - self.double_blink_events = [] # 连续眨眼事件记录 - self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳 - self.blink_events = [] - self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band') + # elif decoder_class == 'blink': + # self.n_chan = 2 + # self.l_freq = 0.1 # 带通滤波器低频截止 + # self.h_freq = 8.0 # 带通滤波器高频截止 + # self.total_samples = 0 # 总采样点数 + # self.window_ms = 600 # 检测窗口大小 (ms) + # self.step_ms = 100 # 滑动步长 (ms) + # self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点 + # self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点 + # self.buffer_size = self.window_samples + self.step_samples * 5 + # self.fp1_buffer = deque(maxlen=self.buffer_size) + # self.fp2_buffer = deque(maxlen=self.buffer_size) + # self.sample_counter = 0 + # # 预计算滤波器系数,避免在循环中重复设计 + # self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink')) + # self.blink_count = 0 # 单次眨眼的次数 + # self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引) + # self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳 + # self.double_blink_count = 0 # 连续两次眨眼的次数 + # self.double_blink_events = [] # 连续眨眼事件记录 + # self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳 + # self.blink_events = [] + # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band') def parameter_init(self,bandPass_low,bandPass_high): self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息 diff --git a/PubLibrary/InifileHelper.py b/PubLibrary/InifileHelper.py index e647da3..c8c81a7 100644 --- a/PubLibrary/InifileHelper.py +++ b/PubLibrary/InifileHelper.py @@ -4,27 +4,65 @@ import os import sys from audioop import error -BASE_DIR = os.getcwd() -IniFileName = os.path.join(BASE_DIR, 'config.ini') -# IniFileName=os.path.join( 'config.ini') -def IniWrite(section,keyname,value): - # 创建ConfigParser对象 +def get_config_paths(): + """返回所有可能的 config.ini 路径(按优先级排序)""" + paths = [] + + # 1. exe 同级目录(用户手动放置或外部修改) + exe_dir = os.path.dirname(sys.executable) if getattr(sys, 'frozen', False) else None + if exe_dir: + paths.append(os.path.join(exe_dir, 'config.ini')) + + # 2. PyInstaller 资源目录 (_MEIPASS,打包时 datas 复制进来的) + meipass = getattr(sys, '_MEIPASS', None) + if meipass: + paths.append(os.path.join(meipass, 'config.ini')) + + # 3. PubLibrary 目录下(优先查找) + pub_dir = os.path.dirname(os.path.abspath(__file__)) + pub_path = os.path.join(pub_dir, 'config.ini') + if pub_path not in paths: + paths.append(pub_path) + + # 4. 项目根目录下(开发环境备用) + project_root = os.path.dirname(pub_dir) + root_path = os.path.join(project_root, 'config.ini') + if root_path not in paths: + paths.append(root_path) + + return paths + + +def IniWrite(section, keyname, value): + exe_dir = os.path.dirname(sys.executable) if getattr(sys, 'frozen', False) else None + base_dir = exe_dir if exe_dir else os.path.dirname(os.path.abspath(__file__)) + IniFileName = os.path.join(base_dir, 'config.ini') + config = configparser.ConfigParser() - config.read(IniFileName,encoding='utf-8') - with open(IniFileName, 'w') as configfile: + try: + with open(IniFileName, 'r', encoding='utf-8') as f: + config.read_file(f) + except FileNotFoundError: + pass + with open(IniFileName, 'w', encoding='utf-8') as configfile: if not config.has_section(section): config.add_section(section) - config[section][keyname]=str(value) + config[section][keyname] = str(value) config.write(configfile) -def IniRead(section,key): +def IniRead(section, key, default=None): + fallback = default if default is not None else '5' - try: - config = configparser.ConfigParser() - config.read(IniFileName,encoding='utf-8') - return config[section][key] - except error as e: - print(e) - # 读取特定section和键的值 - return '5' \ No newline at end of file + for path in get_config_paths(): + if os.path.exists(path): + try: + config = configparser.ConfigParser() + with open(path, 'r', encoding='utf-8') as f: + config.read_file(f) + if config.has_section(section): + # print(f"[IniRead] 找到配置 [{section}] {key} -> {config[section][key]} (来源: {path})") + return config[section][key] + except Exception as e: + print(f"[IniRead] 读取失败 {path}: {e}") + return fallback diff --git a/Zmq/__init__.py b/Zmq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Zmq/dataBuffer.py b/Zmq/dataBuffer.py new file mode 100644 index 0000000..9bba7e6 --- /dev/null +++ b/Zmq/dataBuffer.py @@ -0,0 +1,77 @@ +# -*-coding:utf-8 -*- +""" +范式buffer和滤波buffer, 以及滤波函数 +""" +import numpy as np +from scipy import signal +import threading + +class ParadigmRingBuffer: + def __init__(self, n_chan, n_points): + self.n_chan = n_chan + self.n_points = n_points + self.buffer = np.zeros((n_chan, n_points)) + self.currentPtr = 0 + self.readPtr = 0 + self.nUpdate = 0 + self.rawData = np.zeros((n_chan, 1)) + + ## append buffer and update current pointer + def appendBuffer(self, data): + if self.nUpdate == self.n_points: + raise Exception("Buffer is full") + + n = data.shape[1] + + # 计算可以写入的元素数量 + write_count = min(self.n_points - self.nUpdate, n) + # 写入新数据 + self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count] + # 更新结束指针 + self.currentPtr = (self.currentPtr + write_count) % self.n_points + # 更新大小 + self.nUpdate += write_count + + ## get data from buffer + def getData(self, count=50): + # 确保不会尝试读取超过缓冲区当前大小的数据 + count = min(count, self.nUpdate) + + # 计算读取结束后的下一个位置 + next_read_ptr = (self.readPtr + count) % self.n_points + if self.readPtr + count <= self.n_points: + # 情况 1:不环绕,数据是连续的 + end_index = next_read_ptr if next_read_ptr != 0 else self.n_points + data = self.buffer[:, self.readPtr:end_index] + else: + # 情况 2:发生环绕,数据被分成两部分 + # 第一部分:从 readPtr 到缓冲区末尾 + part1 = self.buffer[:, self.readPtr:] + # 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点 + part2 = self.buffer[:, :next_read_ptr] + # 将两部分在列方向上拼接 + data = np.concatenate((part1, part2), axis=1) + + # 更新读指针 + self.readPtr = next_read_ptr + # 更新大小 + self.nUpdate -= count + return data + + def GetDataLenCount(self): + ''' + 获取最新缓存中每个通道的数量 + @return: + ''' + return self.nUpdate + + + # reset buffer + def resetAllPara(self): + self.nUpdate = 0 + self.currentPtr = 0 + self.readPtr = 0 # add by lizhenhua 清空读指针 + self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区 + + + diff --git a/Zmq/filterProcess.py b/Zmq/filterProcess.py new file mode 100644 index 0000000..decc620 --- /dev/null +++ b/Zmq/filterProcess.py @@ -0,0 +1,208 @@ +# -*-coding:utf-8 -*- +""" +数据滤波模块 +""" +import numpy as np +import threading +from logs.log import algo_log + +class FilterRingBuffer: + def __init__(self, n_chan, n_points): + """ + 初始化纯数据环形缓存 + :param n_chan: 通道数 + :param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致) + """ + self.n_chan = n_chan + self.n_points = n_points + + self.buffer = np.zeros((n_chan, n_points), dtype=np.float64) + self.current_ptr = 0 # 写入指针 + self.total_samples = 0 # 已写入总点数 + + # 线程安全锁(多线程环境必须) + self.lock = threading.Lock() + + def appendBuffer(self, data): + """ + 追加数据到缓存(与paradigmRingBuffer接口一致) + :param data: 输入数据,shape=(n_chan, n_samples) + """ + with self.lock: + n = data.shape[1] + if n == 0: + return + + # 环形写入逻辑 + write_end = self.current_ptr + n + if write_end <= self.n_points: + self.buffer[:, self.current_ptr:write_end] = data + else: + split = self.n_points - self.current_ptr + self.buffer[:, self.current_ptr:] = data[:, :split] + self.buffer[:, :write_end - self.n_points] = data[:, split:] + + # 更新指针和计数 + self.current_ptr = write_end % self.n_points + self.total_samples = min(self.total_samples + n, self.n_points) + + def getData(self, count): + """ + 从读指针位置读取count个点(与paradigmRingBuffer接口一致) + :param count: 读取点数 + :return: np.ndarray, shape=(n_chan, count) + """ + with self.lock: + count = min(count, self.total_samples) + if count == 0: + return np.zeros((self.n_chan, 0)) + + # 环形读取逻辑(与paradigmRingBuffer完全相同) + end = self.current_ptr + start = end - count + if start >= 0: + return self.buffer[:, start:end].copy() + else: + part1 = self.buffer[:, start:] + part2 = self.buffer[:, :end] + return np.concatenate((part1, part2), axis=1) + + def get_latest_n_points(self, n): + """ + 扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口) + :param n: 点数 + :return: np.ndarray, shape=(n_chan, n) + """ + with self.lock: + if self.total_samples < n: + return None + return self.getData(n) + + def GetDataLenCount(self): + """获取当前缓存总点数(兼容原有接口)""" + with self.lock: + return self.total_samples + + def resetAllPara(self): + """重置所有缓存和指针(兼容原有接口)""" + with self.lock: + self.buffer.fill(0.0) + self.current_ptr = 0 + self.total_samples = 0 + +# ----------------------------------------------------------------------------- +# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现) +# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口 +# ----------------------------------------------------------------------------- +class SlidingFilter: + def __init__( + self, + n_chan=66, + srate=250, + buffer_sec=5, + window_sec=3, + step_sec=0.2, + packet_size=5 + ): + """ + 初始化滑动滤波器 + :param n_chan: 通道数 + :param srate: 采样率 + :param buffer_sec: 总缓存时长(秒) + :param window_sec: 滤波窗口时长(秒) + :param step_sec: 滑动步长/输出时长(秒) + :param packet_size: 每包数据点数(20ms一包=5点) + """ + # 核心参数 + self.n_chan = n_chan + self.srate = srate + self.buffer_size = int(srate * buffer_sec) + self.window_size = int(srate * window_sec) + self.step_size = int(srate * step_sec) + self.packet_size = packet_size + + # 初始化纯数据缓存(解耦核心) + self.buffer = FilterRingBuffer(n_chan, self.buffer_size) + + # 滤波触发计数器 + self.packet_count = 0 + self.ready_to_filter = False + + # 预计算滤波器系数 + self._init_filters() + + def _init_filters(self): + """预计算所有滤波器系数(仅执行一次)""" + # 50Hz工频陷波(Q=30,工业标准) + self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate) + # 8~30Hz带通FIR(65阶,线性相位) + self.b_bp = signal.firwin( + numtaps=65, + cutoff=[8/(self.srate/2), 30/(self.srate/2)], + pass_zero=False, + window='hamming' + ) + self.a_bp = np.array([1.0]) + + def append_and_check_trigger(self, raw_data): + """ + 追加单包原始数据并检查是否触发滤波 + :param raw_data: 上位机原始数据,shape=(packet_size, n_chan) + :return: bool: 是否触发本次滤波 + """ + # 转置为标准格式:(通道数, 点数) + data = raw_data.T.astype(np.float64) + + # 写入缓存(纯缓存操作) + self.buffer.appendBuffer(data) + + # 更新包计数器 + self.packet_count += 1 + + # 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数 + packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms + if (self.buffer.GetDataLenCount() >= self.window_size + and self.packet_count >= packets_per_step): + self.packet_count = 0 + self.ready_to_filter = True + return True + return False + + def filter_and_get_output(self): + """ + 执行滤波并返回无边界效应的输出数据 + :return: np.ndarray: 滤波后数据,shape=(n_chan, step_size) + """ + if not self.ready_to_filter: + return None + + # 获取最新的完整滤波窗口数据 + window_data = self.buffer.get_latest_n_points(self.window_size) + if window_data is None: + self.ready_to_filter = False + return None + + # 零相位滤波(无延迟,无边界效应) + filtered = window_data - np.mean(window_data, axis=-1, keepdims=True) + filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1) + filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1) + + # 提取倒数第二个步长的数据(完全避开两端边界效应) + start_idx = self.window_size - 2 * self.step_size + end_idx = self.window_size - self.step_size + output_data = filtered[:, start_idx:end_idx].copy() + + # 重置触发标志 + self.ready_to_filter = False + + return output_data + + def reset(self): + """重置滤波器和缓存""" + self.buffer.resetAllPara() + self.packet_count = 0 + self.ready_to_filter = False + + def get_buffer_length(self): + """获取当前缓存数据长度""" + return self.buffer.GetDataLenCount() \ No newline at end of file diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index 41425d4..c071bd1 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -3,147 +3,257 @@ import zmq import threading import json import queue -from Device.SunnyLinker import SunnyLinker64 +# from Device.SunnyLinker import SunnyLinker64 +from dataBuffer import ParadigmRingBuffer +from filterProcess import FilterRingBuffer +from logs.log import algo_log class zmqServer(threading.Thread): - def __init__(self, host='0.0.0.0', port=8099): + def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): threading.Thread.__init__(self) self.host = host - self.port = port + self.cmd_port = cmd_port # 命令交互端口 + self.data_port = data_port # 数据接收端口 self.running = False - self.get_Impedance = False # 是否返回阻抗值 - self.open_Impedance = None # 是否开启阻抗检测功能 - self.StartDecode = False # false 停止解码,true=开始解码 - self.StartTrain = False # False未进入训练状态,True处于训练状态 - self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态 - self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签 - self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。 - self.getReport = False # 获取训练报告内容 + + # 原有业务状态变量 + # self.get_Impedance = False # 是否返回阻抗值 + # self.open_Impedance = None # 是否开启阻抗检测功能 + self.StartDecode = False # false 停止解码,true=开始解码 + self.StartTrain = False # False未进入训练状态,True处于训练状态 + self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态 + self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签 + self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。 + # self.getReport = False # 获取训练报告内容 self.daemon = True - # 创建 ZeroMQ 上下文 + + # 范式数据缓存 + self.paradigmBuffer = ParadigmRingBuffer(66, 2500) + self.filterBuffer = FilterRingBuffer(66, 2500) + + + # 命令与数据通信 self.context = zmq.Context() - # 创建 REP 套接字(响应端) - self.socket = self.context.socket(zmq.ROUTER) - self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099 + # 指令通道 (8099) - ROUTER:短JSON命令,低频率 + self.cmd_socket = self.context.socket(zmq.ROUTER) + self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存,100条足够 + self.cmd_socket.setsockopt(zmq.SNDHWM, 100) + self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,降低指令延迟 + self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") + + # 数据通道 (8100) - ROUTER:高频脑电二进制流 + self.data_socket = self.context.socket(zmq.ROUTER) + self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存,足够应对短时卡顿 + self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,减少数据传输延迟 + self.data_socket.bind(f"tcp://{self.host}:{data_port}") + + # Poller 轮训器(保持不变) + self.poller = zmq.Poller() + self.poller.register(self.cmd_socket, zmq.POLLIN) + self.poller.register(self.data_socket, zmq.POLLIN) + + # 业务变量 self.targetFreqs = [] self.changeTarget = False # 更换目标频率 - self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化 + # self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化 self.labels = [0x01, 0x02,0x03] - self.decoder_switch = False #更换解码器 self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi' - # Client Management (e.g. Unity, Other listeners) - self.clients = set() # 维护客户端ID - self.send_queue = queue.Queue() # 发送队列,安全信箱,维护socket线程 + + # 客户端管理 - 区分命令/数据客户端 + self.cmd_clients = set() # 命令端口客户端ID + self.data_clients = set() # 数据端口客户端ID + self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播) def broadcast_message(self, method, params): - """Put message into queue to be sent to all connected clients""" + """Put message into queue to be sent to all command clients""" self.send_queue.put((method, params)) + def _handle_cmd_message(self, frames): + """处理命令端口消息(原有命令交互逻辑)""" + if len(frames) < 3: + return + ident, _, message_bytes = frames[:3] + + # 注册新的命令客户端 + if ident not in self.cmd_clients: + self.cmd_clients.add(ident) + print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") + + # 解析消息 + try: + message = json.loads(message_bytes.decode('utf-8')) + except json.JSONDecodeError: + print(f"Invalid JSON from CMD client {ident}") + continue + print(f"Received CMD request: {message}") + + method = message.get("method") + params = message.get("params") + + # 原有命令处理逻辑 + if method == "sync": + self.state_mode = 'sync' + if method == "targetFreqs": + if not isinstance(params, list): + print('targetFreqs must be a list') + continue + if params != self.targetFreqs: + self.targetFreqs = params + self.changeTarget = True + if method == "decoderClass": + if not isinstance(params, str): + print('decoderClass must be a str') + continue + if params != self.decoder_class: + self.decoder_class = params + self.decoder_switch = True + if method == "getReport": + self.getReport = True + if method == "train":#训练状态 + self.state_mode = 'train' + self.StartTrain = True + self.currentLabel = params # 当前刺激端的训练标签 + self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) + elif method == "predict":#预测状态 + self.state_mode = 'predict' + if params == 1: #开始解码 + self.StartDecode = True + self.sunnyLinker.push_trigger(0x63) + elif params == 2: #停止解码 + self.IsExitApp = True + self.running = False + elif method == "rest": #休息状态 + self.state_mode = 'rest' + # elif method == "impedance": + # if params == 1: + # self.open_Impedance = True # 开启阻抗 + # self.get_Impedance = True # 返回阻抗 + # elif params == 2: + # self.open_Impedance = False # 关闭阻抗 + # self.get_Impedance = False # 停止返回阻抗 + + def _handle_data_message(self, frames): + """ + 处理8100端口原始脑电二进制数据 + 固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区 + """ + # 1. 校验ZMQ消息帧完整性 + if len(frames) < 3: + print(f"[ERROR] 无效数据帧:长度不足3帧,实际长度={len(frames)}") + return + + ident, _, data_bytes = frames[:3] + + # 2. 客户端管理(单客户端场景,自动更新最新身份) + if ident not in self.data_clients: + self.data_clients.add(ident) + self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果 + print(f"[INFO] 新数据客户端连接成功:{ident}") + + try: + # 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同) + EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节 + if len(data_bytes) != EXPECTED_BYTES: + print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节") + return + + # 4. 零拷贝二进制解析 + 维度转换 + # 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式 + data_np = np.frombuffer(data_bytes, dtype=np.float32) + # 重塑为上位机原始维度 + data_np = data_np.reshape(5, 66) + # 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度 + data_np = data_np.T.astype(np.float64) + + # 5. 同时写入双环形缓冲区(方法名与现有类保持一致:appendBuffer) + # 注意:上位机已发送微伏物理量,无需再乘以增益系数 + self.paradigmBuffer.appendBuffer(data_np) + self.filterBuffer.appendBuffer(data_np) + + # 生产环境必须注释!每秒50次打印会导致CPU占用飙升30%以上 + algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True) + + except Exception as e: + algo_log(f"数据处理失败:{str(e)}", level="ERROR") + # 调试阶段临时打开,生产环境务必注释 + import traceback + traceback.print_exc() + + def _process_send_queue(self): + """处理发送队列,向所有命令客户端广播消息""" + while not self.send_queue.empty(): + method, params = self.send_queue.get() + if self.cmd_clients: + try: + msg = {'method': method, 'params': params} + msg_bytes = json.dumps(msg).encode('utf-8') + + # 打印日志(隐藏大尺寸数据) + if method in ['single_trial_plot', 'miReport']: + print(f"{{'method': '{method}', 'params': }}") + else: + print(f"Sending CMD message: {msg}") + + # 广播到所有命令客户端 + for client_id in list(self.cmd_clients): + try: + self.cmd_socket.send_multipart([client_id, b'', msg_bytes]) + except Exception as e: + print(f"Error sending to CMD client {client_id}: {e}") + self.cmd_clients.discard(client_id) # 移除失效客户端 + except Exception as e: + print(f"Error preparing broadcast: {e}") + def run(self): self.running = True - print(f"Server is running on {self.host}:{self.port}") - # Use Poller for non-blocking receive - poller = zmq.Poller() - poller.register(self.socket, zmq.POLLIN) + print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") + try: while self.running: - # 1. Process Send Queue (Send to all clients) - while not self.send_queue.empty(): - method, params = self.send_queue.get() - if self.clients: - try: - msg = {'method': method, 'params': params} - msg_bytes = json.dumps(msg).encode('utf-8') - if method in ['single_trial_plot', 'single_trial_plot', 'miReport']: - print(f"{{'method': '{method}', 'params': }}") - else: - print(f"Sending message: {msg}") - # Broadcast to all maintained clients - for client_id in list(self.clients): - try: - # Send: [ID, Empty, JSON] - self.socket.send_multipart([client_id, b'', msg_bytes]) - except Exception as e: - print(f"Error sending to client {client_id}: {e}") - except Exception as e: - print(f"Error preparing broadcast: {e}") + # 1. 处理发送队列(命令端口广播) + self._process_send_queue() - # 2. Process Receive (Commands) - socks = dict(poller.poll(10)) # 100ms timeout - if self.socket in socks and socks[self.socket] == zmq.POLLIN: - frames = self.socket.recv_multipart() - if len(frames) < 3: - continue - ident, _, message_bytes = frames[:3] - if ident not in self.clients: # register client ID - self.clients.add(ident) - print(f"New Client Detected: {ident}") - try: - message = json.loads(message_bytes.decode('utf-8')) - except json.JSONDecodeError: - continue - print(f"Received request: {message}") + # 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞) + socks = dict(self.poller.poll(10)) - method = message.get("method") # process request - params = message.get("params") + # 处理命令端口消息 + if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: + frames = self.cmd_socket.recv_multipart() + self._handle_cmd_message(frames) - if method == "sync": - self.state_mode = 'sync' - if method == "targetFreqs": - if not isinstance(params,list): - print('targetFreqs must be a list') - continue - if params != self.targetFreqs: - self.targetFreqs = params - self.changeTarget = True - if method == "decoderClass": - if not isinstance(params,str): - print('decoderClass must be a str') - continue - if params != self.decoder_class: - self.decoder_class = params - self.decoder_switch = True - if method == "getReport": - self.getReport = True - if method == "train":#训练状态 - self.state_mode = 'train' - self.StartTrain = True - self.currentLabel = params # 当前刺激端的训练标签 - self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) - elif method == "predict":#预测状态 - self.state_mode = 'predict' - if params == 1: #开始解码 - self.StartDecode = True - self.sunnyLinker.push_trigger(0x63) - elif params == 2: #停止解码 - self.IsExitApp = True - self.running = False - elif method == "rest": #休息状态 - self.state_mode = 'rest' - elif method == "impedance": - if params == 1: - self.open_Impedance = True # 开启阻抗 - self.get_Impedance = True # 返回阻抗 - elif params == 2: - self.open_Impedance = False # 关闭阻抗 - self.get_Impedance = False # 停止返回阻抗 + # 处理数据端口消息 + if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: + frames = self.data_socket.recv_multipart() + self._handle_data_message(frames) except Exception as e: - print(f"An socket error occurred: {e}") + print(f"Server error occurred: {e}") finally: self.running = False - # 关闭套接字和上下文 - self.socket.close() + # 关闭所有Socket和上下文 + self.cmd_socket.close() + self.data_socket.close() self.context.term() - print("Server socket and context closed.") + print("Server sockets and context closed.") + def stop(self): """显式关闭服务器""" self.running = False - self.socket.close() + self.cmd_socket.close() + self.data_socket.close() self.context.term() - print("Server closed explicitly.") + print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") if __name__ == '__main__': + # 初始化并启动服务器(默认cmd=8099, data=8100) server = zmqServer() - server.start() \ No newline at end of file + server.start() + + # 保持主线程运行 + try: + while server.running: + threading.Event().wait(1) + except KeyboardInterrupt: + print("Received KeyboardInterrupt, stopping server...") + server.stop() diff --git a/Zmq/zmqServer1.py b/Zmq/zmqServer1.py new file mode 100644 index 0000000..79cee04 --- /dev/null +++ b/Zmq/zmqServer1.py @@ -0,0 +1,445 @@ +import numpy as np +import zmq +import threading +import json +import queue +import time +from Device.SunnyLinker import SunnyLinker64, RingBuffer +from collections import deque + + +class zmqServer(threading.Thread): + def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100): + threading.Thread.__init__(self) + self.host = host + self.cmd_port = cmd_port + self.data_port = data_port + self.running = False + self.get_Impedance = False + self.open_Impedance = None + self.StartDecode = False + self.StartTrain = False + self.state_mode = None + self.currentLabel = -1 + self.IsExitApp = False + self.getReport = False + self.daemon = True + + # ZMQ Context + self.context = zmq.Context() + + # 指令通道 (8099) - ROUTER + self.cmd_socket = self.context.socket(zmq.ROUTER) + self.cmd_socket.setsockopt(zmq.RCVHWM, 1000) + self.cmd_socket.setsockopt(zmq.SNDHWM, 1000) + self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") + + # 数据通道 (8100)) - ROUTER + self.data_socket = self.context.socket(zmq.ROUTER) + self.data_socket.setsockopt(zmq.RCVHWM, 1000) + self.data_socket.setsockopt(zmq.RCVTIMEO, 50) + self.data_socket.bind(f"tcp://{self.host}:{data_port}") + + self.targetFreqs = [] + self.changeTarget = False + self.sunnyLinker = SunnyLinker64(None, None, None, None, None) + self.labels = [0x01, 0x02, 0x03] + + self.decoder_switch = False + self.decoder_class = None + self.cmd_clients = set() + self.data_clients = set() + self.send_queue = queue.Queue() + + # ========== 数据缓冲区 (RingBuffer) ========== + # 与 SunnyLinker 保持一致,使用 RingBuffer + # 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66) + # 缓存约 10 秒数据 (250Hz * 10s = 2500 点) + self.n_chan = 66 + self.t_buffer = 10.0 # 缓冲区时长(秒) + self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250)) + + # 事件检测相关 + self._event_lock = threading.Lock() + self._epoch_finished = False + self._event_inner_idx = -1 + self.pack_contain_event = False + self.predict_event = 99 + self.events = [1, 2, self.predict_event] + self.count_events = {} + self.latency = 50 + self.train_latency = 50 + + # 当前事件标签序号 (从第66通道获取) + self.current_label_index = 0 + + # 初始化标志 + self._interval_inited = False + self._currentLabel = -1 + + # 注册的客户端(兼容旧接口) + self.clients = set() + + # ========== 事件属性:线程安全访问 ========== + @property + def epoch_finished(self): + with self._event_lock: + return self._epoch_finished + + @epoch_finished.setter + def epoch_finished(self, value): + with self._event_lock: + self._epoch_finished = value + + @property + def event_inner_idx(self): + with self._event_lock: + return self._event_inner_idx + + @event_inner_idx.setter + def event_inner_idx(self, value): + with self._event_lock: + self._event_inner_idx = value + + @property + def interval_inited(self): + return self._interval_inited + + @interval_inited.setter + def interval_inited(self, value): + self._interval_inited = value + + @property + def currentLabel(self): + return self._currentLabel + + @currentLabel.setter + def currentLabel(self, value): + self._currentLabel = value + + def broadcast_message(self, method, params): + """Put message into queue to be sent to all connected clients""" + self.send_queue.put((method, params)) + + # ========== 数据缓冲区操作接口 ========== + def GetDataLenCount(self): + """返回缓冲区当前数据点数""" + return self.__ringBuffer.nUpdate + + def getData(self, count): + """获取最新count个数据点,不消费(只读)""" + with self.__ringBuffer.RingBufferLock: + count = min(count, self.__ringBuffer.nUpdate) + if count == 0: + return np.zeros((self.n_chan, 0)) + + # 计算读取范围(从尾部取最新数据) + read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points + read_start = (read_end - count + 1) % self.__ringBuffer.n_points + + if self.__ringBuffer.currentPtr == 0: + read_start = self.__ringBuffer.n_points - count + read_end = self.__ringBuffer.n_points - 1 + + if read_start <= read_end: + data = self.__ringBuffer.buffer[:, read_start:read_end + 1] + else: + part1 = self.__ringBuffer.buffer[:, read_start:] + part2 = self.__ringBuffer.buffer[:, :read_end + 1] + data = np.concatenate((part1, part2), axis=1) + + return data + + def consumeData(self, count): + """消费(丢弃)指定数量的数据点,从头部移除""" + with self.__ringBuffer.RingBufferLock: + count = min(count, self.__ringBuffer.nUpdate) + self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points + self.__ringBuffer.nUpdate -= count + + def ResetAll(self): + """重置缓冲区""" + with self.__ringBuffer.RingBufferLock: + self.__ringBuffer.resetAllPara() + with self._event_lock: + self._epoch_finished = False + self._event_inner_idx = -1 + self.pack_contain_event = False + self.count_events.clear() + self.current_label_index = 0 + + def reset_data_buffer(self): + self.ResetAll() + + def reset_state(self): + self.ResetAll() + + def interval_init(self, decoder_class): + """初始化事件检测参数""" + import ast + from PubLibrary.InifileHelper import IniRead + + if decoder_class == 'ssmvep': + interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.interval_epoch = [int(i * 250) for i in interval_epoch] + self.train_epoch = [int(self.interval_epoch[0]), + int(self.interval_epoch[1] + 0.1 * 250)] + self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5 + self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5 + + elif decoder_class == 'mi': + interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + self.interval_epoch = [int(i * 250) for i in interval_epoch] + self.train_epoch = self.interval_epoch.copy() + self.latency = self.interval_epoch[1] // 5 + self.train_latency = self.latency + + self.count_events = {} + self._event_inner_idx = -1 + self._epoch_finished = False + self.pack_contain_event = False + self.predict_event = 99 + self.events = [1, 2, self.predict_event] + self._interval_inited = True + + # ========== 事件检测 ========== + def detect_event(self, data_matrix): + """ + 检测事件通道中的触发信号 + + @param data_matrix: shape (66, N) - N个采样点的数据 + 第65行(索引64) = 事件通道 + 第66行(索引65) = 标签通道 + @return: 是否检测到事件 + """ + if data_matrix.shape[1] == 0: + return False + + self.pack_contain_event = False + event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值) + label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index) + + events = event_channel.tolist() + + with self._event_lock: + self._event_inner_idx = -1 + self.current_event_label = 0 + + for idx, event in enumerate(events): + if int(event) in self.events: + self._event_inner_idx = idx + self.current_label_index = int(label_channel[idx]) + self.pack_contain_event = True + + new_key = f"{event}_{time.time()}" + latency = self.latency if event == self.predict_event else self.train_latency + self.count_events[new_key] = latency + 1 + + # 延迟计数递减 + drop_items = [] + for key, value in self.count_events.items(): + value = value - 1 + if value == 0: + drop_items.append(key) + self.count_events[key] = value + for key in drop_items: + del self.count_events[key] + + if drop_items: + self._epoch_finished = True + # 检测到事件时,清除RingBuffer中之前的数据,只保留当前包 + if self.pack_contain_event: + self.__ringBuffer.resetAllPara() + return True + + self._epoch_finished = False + return False + + def run(self): + self.running = True + print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}") + + cmd_poller = zmq.Poller() + cmd_poller.register(self.cmd_socket, zmq.POLLIN) + + data_poller = zmq.Poller() + data_poller.register(self.data_socket, zmq.POLLIN) + + try: + while self.running: + # --- 处理发送队列 (指令通道) --- + while not self.send_queue.empty(): + method, params = self.send_queue.get() + if self.cmd_clients: + try: + msg = {'method': method, 'params': params} + msg_bytes = json.dumps(msg).encode('utf-8') + for client_id in list(self.cmd_clients): + try: + self.cmd_socket.send_multipart([client_id, b'', msg_bytes]) + except Exception: + pass + except Exception: + pass + + # --- 处理指令通道 --- + socks = dict(cmd_poller.poll(10)) + if self.cmd_socket in socks: + self._handle_cmd_socket() + + # --- 处理数据通道 --- + socks = dict(data_poller.poll(10)) + if self.data_socket in socks: + self._handle_data_socket() + + except Exception as e: + print(f"Server error: {e}") + finally: + self.running = False + self.cmd_socket.close() + self.data_socket.close() + self.context.term() + + def _handle_cmd_socket(self): + """处理指令通道消息""" + try: + frames = self.cmd_socket.recv_multipart() + if len(frames) < 3: + return + ident, _, message_bytes = frames[:3] + self.cmd_clients.add(ident) + self.clients.add(ident) + + message = json.loads(message_bytes.decode('utf-8')) + method = message.get("method") + params = message.get("params") + + print(f"[CMD] {method}: {params}") + + if method == "sync": + self.state_mode = 'sync' + elif method == "targetFreqs": + if isinstance(params, list) and params != self.targetFreqs: + self.targetFreqs = params + self.changeTarget = True + elif method == "decoderClass": + if isinstance(params, str) and params != self.decoder_class: + self.decoder_class = params + self.decoder_switch = True + elif method == "getReport": + self.getReport = True + elif method == "train": + self.state_mode = 'train' + self.StartTrain = True + self.currentLabel = params + elif method == "predict": + self.state_mode = 'predict' + if params == 1: + self.StartDecode = True + elif params == 2: + self.IsExitApp = True + self.running = False + elif method == "rest": + self.state_mode = 'rest' + elif method == "impedance": + if params == 1: + self.open_Impedance = True + self.get_Impedance = True + elif params == 2: + self.open_Impedance = False + self.get_Impedance = False + + except Exception as e: + print(f"CMD socket error: {e}") + + def _handle_data_socket(self): + """处理数据通道消息 (EEG数据) + + 上位机数据格式: + - 数据帧: [identity, '', meta_json, data_buffer] + data_buffer = [N, 66] float32 -> 转置为 [66, N] + """ + try: + frames = self.data_socket.recv_multipart() + if len(frames) < 4: + return + ident, _, message_bytes = frames[:3] + self.data_clients.add(ident) + + meta = json.loads(message_bytes.decode('utf-8')) + + # data: [N, 66] -> 转置 -> [66, N] + raw_data = np.frombuffer(frames[3], dtype=np.float32) + n_samples, n_channels = meta.get('shape', [5, 66]) + data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32) + + # 写入 RingBuffer + with self.__ringBuffer.RingBufferLock: + self.__ringBuffer.appendBuffer(data_matrix) + + # 事件检测 + self.detect_event(data_matrix) + + except Exception as e: + print(f"DATA socket error: {e}") + + # ========== 各范式数据访问接口 ========== + def get_MIData(self): + """获取MI导联数据 (21通道 + 事件)""" + data = self.getData(self.GetDataLenCount()) + rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def get_SSMVEPData(self): + """获取SSMVEP导联数据 (8通道 + 事件)""" + data = self.getData(self.GetDataLenCount()) + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def getDataViaSSVEP(self, count): + """获取SSVEP数据 (8通道 + 事件)""" + data = self.getData(count) + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def get_concentrateData(self, count): + """获取专注力数据 (2通道)""" + data = self.getData(count) + rows_to_extract = [0, 1] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def get_blinkData(self, count): + """获取眨眼数据 (2通道)""" + data = self.getData(count) + rows_to_extract = [0, 1] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def getImpedance(self, data, decoder_class): + """计算阻抗(ZMQ模式下不可用)""" + return np.zeros(8) + + def stop(self): + self.running = False + self.cmd_socket.close() + self.data_socket.close() + self.context.term() + + +if __name__ == '__main__': + server = zmqServer() + server.start() diff --git a/config.ini b/config.ini index 4fc2c56..c90639f 100644 --- a/config.ini +++ b/config.ini @@ -19,6 +19,16 @@ Device_Port = 5086 Upper_Host = 127.0.0.1 Upper_Port = 8088 Serial_port = COM44 +algo_log_level = DEBUG +console_output = 1 + +; 64 导设备配置 +[device_type_1] +device_sample_rate = 250 +device_channel_nums = 66 +device_channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ'] +device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18] + [Layout] diff --git a/logs/__init__.py b/logs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/logs/log.py b/logs/log.py new file mode 100644 index 0000000..f796905 --- /dev/null +++ b/logs/log.py @@ -0,0 +1,87 @@ +# log.py +import os +from datetime import datetime +import logging +from logging.handlers import RotatingFileHandler +from PubLibrary.InifileHelper import IniRead + + +console_output = IniRead('system', 'console_output', '1') +log_level = IniRead('system', 'algo_log_level', 'INFO') + +# 新增:日志去重缓存,key为日志内容,value为是否已打印 +log_once_cache = set() + + +def init_module_logger(): + """ + 初始化指定模块的日志器 + :return: 对应模块的logger实例 + """ + # 缓存命中则直接返回 + log_dir = './logs/' # 确保日志目录存在 + os.makedirs(log_dir, exist_ok=True) + + log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log') + + # 初始化logger + logger = logging.getLogger('decoderLogger') + logger.setLevel(log_level) + + if logger.handlers: + return logger + + # 设置日志轮转,最大10个文件,每个10MB + file_handler = RotatingFileHandler( + log_file, + maxBytes=10*1024*1024, + backupCount=10, + encoding='utf-8' + ) + + # 日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(formatter) + logger.setLevel(log_level) + logger.addHandler(file_handler) + + if console_output: + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + return logger + +def algo_log(content, level="INFO", record_once=False): + """ + 通用日志函数,支持按模块输出到不同日志文件 + :param content: 日志内容 + :param level: 日志级别(DEBUG/INFO/WARNING/ERROR/FATAL) + :param record_once: 是否只打印一次该日志内容,默认False + """ + # 初始化模块日志器 + logger = init_module_logger() + + # 新增:处理只打印一次的逻辑 + if record_once: + # 生成唯一标识(可根据需要调整,比如拼接level增强唯一性) + log_key = f"{level.upper()}_{content}" + if log_key in log_once_cache: + return # 已打印过,直接返回 + log_once_cache.add(log_key) # 未打印过,加入缓存 + + # 根据级别输出日志 + level_upper = level.upper() + if level_upper == "DEBUG": + logger.debug(content) + elif level_upper == "WARNING": + logger.warning(content) + elif level_upper == "ERROR": + logger.error(content) + elif level_upper == "FATAL": + logger.fatal(content) + else: # 默认INFO级别 + logger.info(content) \ No newline at end of file diff --git a/runDecoder.py b/runDecoder.py index 9199d14..71cdf92 100644 --- a/runDecoder.py +++ b/runDecoder.py @@ -5,27 +5,40 @@ import sys import time from Decoder import Decoder_main from PubLibrary.RunOnce import is_program_running +from PubLibrary.InifileHelper import IniRead + +def get_device_info(device_type): + + + section = f'device_type_{device_type}' + device_info = { + 'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250, + + '' + } if __name__ == "__main__": if not is_program_running(): # 解析命令行参数 parser = argparse.ArgumentParser(description="EEG Decoder Application") - parser.add_argument('-dt', '--device-type', type=int, default=None, help="Device Type") - parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP") - parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port") - parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP") - parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port") + parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type") + # parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP") + # parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port") + # parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP") + # parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port") args = parser.parse_args() + device_info= get_device_info(args.device_type) + - decoder = Decoder_main() - decoder.connect( - device_type=args.device_type, - device_host=args.device_host, - device_port=args.device_port, - upper_host=args.upper_host, - upper_port=args.upper_port - ) + decoder = Decoder_main(device_info=device_info) + # decoder.connect( + # device_type=args.device_type, + # device_host=args.device_host, + # device_port=args.device_port, + # upper_host=args.upper_host, + # upper_port=args.upper_port + # ) try: decoder.start()