original push
This commit is contained in:
379
algorithm_V0/datacollect/SunnyLinker.py
Normal file
379
algorithm_V0/datacollect/SunnyLinker.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# -*-coding:utf-8 -*-
|
||||
'''
|
||||
SunnyLinker的通讯驱动
|
||||
'''
|
||||
import ast
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import datetime
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from threading import Thread, Event
|
||||
import serial
|
||||
from scipy import signal
|
||||
from serial.serialutil import SerialException
|
||||
|
||||
from protocol import ProtocolFrame
|
||||
|
||||
class RingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
self.buffer = np.zeros((n_chan, n_points))
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.nUpdate = 0
|
||||
self.rawData = np.zeros((n_chan, 1))
|
||||
|
||||
## append buffer and update current pointer
|
||||
def appendBuffer(self, data):
|
||||
if self.nUpdate == self.n_points:
|
||||
raise Exception("Buffer is full")
|
||||
|
||||
n = data.shape[1]
|
||||
|
||||
# 计算可以写入的元素数量
|
||||
write_count = min(self.n_points - self.nUpdate, n)
|
||||
# 写入新数据
|
||||
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
|
||||
# 更新结束指针
|
||||
self.currentPtr = (self.currentPtr + write_count) % self.n_points
|
||||
# 更新大小
|
||||
self.nUpdate += write_count
|
||||
|
||||
## get data from buffer
|
||||
def getData(self, count=50):
|
||||
# 确保不会尝试读取超过缓冲区当前大小的数据
|
||||
count = min(count, self.nUpdate)
|
||||
|
||||
# 计算读取结束后的下一个位置
|
||||
next_read_ptr = (self.readPtr + count) % self.n_points
|
||||
if self.readPtr + count <= self.n_points:
|
||||
# 情况 1:不环绕,数据是连续的
|
||||
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
|
||||
data = self.buffer[:, self.readPtr:end_index]
|
||||
else:
|
||||
# 情况 2:发生环绕,数据被分成两部分
|
||||
# 第一部分:从 readPtr 到缓冲区末尾
|
||||
part1 = self.buffer[:, self.readPtr:]
|
||||
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
|
||||
part2 = self.buffer[:, :next_read_ptr]
|
||||
# 将两部分在列方向上拼接
|
||||
data = np.concatenate((part1, part2), axis=1)
|
||||
|
||||
# 更新读指针
|
||||
self.readPtr = next_read_ptr
|
||||
# 更新大小
|
||||
self.nUpdate -= count
|
||||
return data
|
||||
|
||||
# reset buffer
|
||||
def resetAllPara(self):
|
||||
self.nUpdate = 0
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||
|
||||
|
||||
class SunnyLinker64(Thread, ):
|
||||
t_buffer = 10
|
||||
n_chan = 64
|
||||
srate = 250
|
||||
receiveData = b''
|
||||
toUv=True#转为uV
|
||||
RingBufferLock = threading.Lock()
|
||||
|
||||
# 单例模式
|
||||
_instance = None
|
||||
_initialized = False # 检查是否已经初始化
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SunnyLinker64, cls).__new__(cls)
|
||||
return cls._instance
|
||||
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
|
||||
if SunnyLinker64._initialized:
|
||||
return
|
||||
Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.srate = srate
|
||||
self.n_chan = n_chan
|
||||
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||
int(np.round(self.t_buffer * self.srate)))
|
||||
self.energy = 0 # 电量
|
||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||
self.gain_value = 6 # 增益倍数
|
||||
|
||||
# 设置初始化标志为True,防止重复初始化
|
||||
SunnyLinker64._initialized = True
|
||||
|
||||
# --- 新增:用于心跳检测 ---
|
||||
self.last_called = 0 # 初始化为0
|
||||
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
|
||||
|
||||
|
||||
|
||||
def set_sampleRate(self,sampleRate_Code=0x00):
|
||||
'''
|
||||
设置采样率
|
||||
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
|
||||
'''
|
||||
function_code = 0x02
|
||||
gain_code = 0x06
|
||||
sampleRate_Code = [gain_code,sampleRate_Code]
|
||||
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
|
||||
if self.method == 'tcp':
|
||||
self.sock.send(packed_data)
|
||||
|
||||
def push_trigger(self,label):
|
||||
'''
|
||||
数据打标
|
||||
@param label:标签类别
|
||||
'''
|
||||
function_code = None
|
||||
label = [label]
|
||||
packed_data = ProtocolFrame.pack(function_code, label)
|
||||
if self.method == 'tcp' and hasattr(self,'serial'):
|
||||
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||
self.serial.write(packed_data)
|
||||
def Impedance(self, On):
|
||||
'''
|
||||
阻抗检测开关
|
||||
:param On:True为开启,False为关闭
|
||||
:return: 组好的协议帧
|
||||
'''
|
||||
function_code = 0x01
|
||||
if On:
|
||||
data = [0x1]
|
||||
self.gain_value = 6
|
||||
else:
|
||||
data = [0x0]
|
||||
self.gain_value = 6
|
||||
packed_data = ProtocolFrame.pack(function_code, data)
|
||||
if self.method == 'tcp':
|
||||
self.sock.send(packed_data)
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
if self.method == 'serial':
|
||||
# 开启com口,波特率115200,超时5
|
||||
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||
self.sock.flushInput() # 清空缓冲区
|
||||
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||
while not count:
|
||||
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||
# # 接收和存储数据
|
||||
data = (self.sock.read(count))
|
||||
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||
elif self.method == 'tcp':
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.connect((self.host, int(self.port)))
|
||||
self.set_sampleRate(0x00) #设置250Hz采样率
|
||||
except Exception as e:
|
||||
print("请打开头环")
|
||||
print(e)
|
||||
|
||||
print("connected")
|
||||
|
||||
def extract_packet(self, packet):
|
||||
# 存储一个点的八通道数据
|
||||
dataList = []
|
||||
# 存储116个点的八通道数据
|
||||
dataMatrix = []
|
||||
|
||||
for j in range(5):
|
||||
for i in range(self.n_chan):
|
||||
if not self.toUv:#原始数据直接输出
|
||||
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||
194 * j + 25 + 2 + i * 3]
|
||||
|
||||
else:#转为uV
|
||||
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||
194 * j + 25 + 2 + i * 3]
|
||||
if val < 8388608:
|
||||
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||
else:
|
||||
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||
dataList.append(val)
|
||||
#同步触发源
|
||||
val = packet[194 * j + 25 + (i+1) * 3]
|
||||
dataList.append(val)
|
||||
#同步触发序号
|
||||
val = packet[194 * j + 25 + (i+1) * 3+1]
|
||||
dataList.append(val)
|
||||
|
||||
|
||||
# 将数据矩阵进行拼接
|
||||
if len(dataMatrix) == 0:
|
||||
dataMatrix = np.asmatrix(dataList)
|
||||
else:
|
||||
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||
dataList.clear()
|
||||
return np.transpose(dataMatrix)
|
||||
|
||||
def run(self):
|
||||
self.connect()
|
||||
self.running = True
|
||||
self.PackageLength = 998
|
||||
# 启动心跳检测线程
|
||||
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
|
||||
while self.running:
|
||||
try:
|
||||
if self.method == 'serial':
|
||||
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||
if count:
|
||||
# 接收和存储数据
|
||||
data = (self.sock.read(count))
|
||||
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||
elif self.method == 'tcp':
|
||||
data = self.sock.recv(600)
|
||||
if not data:
|
||||
break
|
||||
self.receiveData += data
|
||||
with self.last_called_lock:
|
||||
self.last_called = time.time()
|
||||
self.status_code = 1 # 收到数据,标记为正常
|
||||
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||
b'\x55\x55') >= self.PackageLength - 2:
|
||||
|
||||
index = self.receiveData.index(b'\xaa')
|
||||
self.receiveData = self.receiveData[index:]
|
||||
if len(self.receiveData) >= self.PackageLength:
|
||||
onepackage = self.receiveData[:self.PackageLength]
|
||||
if onepackage[7] != 0:
|
||||
self.energy = onepackage[7] # 电量
|
||||
self.receiveData = self.receiveData[self.PackageLength:]
|
||||
dataMatrix = self.extract_packet(onepackage)
|
||||
try:
|
||||
with self.RingBufferLock:
|
||||
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||
except Exception as e:
|
||||
print("锁:写入异常",e)
|
||||
# self.RingBufferLock.release()
|
||||
except ConnectionResetError:
|
||||
self.status_code = 0 # 状态异常
|
||||
print("Connection was reset by the peer.")
|
||||
break
|
||||
self.sock.close()
|
||||
|
||||
# --- 新增:心跳检测线程 ---
|
||||
def heartbeat_checker(self):
|
||||
"""
|
||||
定期检查是否在最近2秒内收到 eegData
|
||||
如果超过2秒未收到,则设置 status_code = 0
|
||||
"""
|
||||
while self.running:
|
||||
time.sleep(0.5) # 每0.5秒检查一次
|
||||
with self.last_called_lock:
|
||||
now = time.time()
|
||||
# 只有收到过一次数据后才开始判断超时
|
||||
if self.last_called > 0 and (now - self.last_called) > 2:
|
||||
if self.status_code != 0:
|
||||
print("EEG data timeout: disconnected")
|
||||
self.status_code = 0
|
||||
def getImpedance(self, data,n_chan):
|
||||
'''
|
||||
获取阻抗值,已经放大100倍,单位是kΩ
|
||||
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||
@return: 返回各个通道的阻抗值
|
||||
'''
|
||||
impedanceList = []
|
||||
data = data[:n_chan]
|
||||
for channelindex in range(data.shape[0]):
|
||||
if len(data[channelindex]) > 0:
|
||||
data_list = []
|
||||
# 设计陷波滤波器,去除50Hz成分
|
||||
is50filter = True
|
||||
if is50filter:
|
||||
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||
|
||||
else:
|
||||
data_list.extend(data[channelindex].tolist())
|
||||
|
||||
data_list = data_list[-1000:]
|
||||
# 执行FFT
|
||||
fft_result = np.fft.fft(data_list)
|
||||
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||
|
||||
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||
|
||||
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||
max_index = np.argmax(fft_magnitude[1:])
|
||||
|
||||
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||
freq_index = max_index + 1
|
||||
|
||||
# 获取最大幅值
|
||||
max_magnitude = fft_magnitude[freq_index]
|
||||
|
||||
# 阻抗
|
||||
import math
|
||||
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||
result *= 0.44 * 100 # 统一放大100倍
|
||||
impedanceList.append(int(result))
|
||||
# print(max_magnitude, result)
|
||||
else:
|
||||
impedanceList.append(0)
|
||||
impedances = np.array(impedanceList)
|
||||
return impedances
|
||||
def getData(self,count):
|
||||
'''
|
||||
获取最新的数据
|
||||
@param count: 每通道返回的最数值数目
|
||||
@return: 所有通道的最新count个数值
|
||||
'''
|
||||
data=None
|
||||
try:
|
||||
with self.RingBufferLock:
|
||||
data = self.__ringBuffer.getData(count)
|
||||
except:
|
||||
print("锁:读取异常")
|
||||
# self.RingBufferLock.release()
|
||||
|
||||
|
||||
return data
|
||||
def GetDataLenCount(self):
|
||||
'''
|
||||
获取最新缓存中每个通道的数量
|
||||
@return:
|
||||
'''
|
||||
return self.__ringBuffer.nUpdate
|
||||
|
||||
def ResetAll(self):
|
||||
'''
|
||||
清空缓存
|
||||
@return:
|
||||
'''
|
||||
with self.RingBufferLock:
|
||||
self.__ringBuffer.resetAllPara()
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Usage
|
||||
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
|
||||
Linker.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.005)
|
||||
if(Linker.count()>0):
|
||||
# print(Linker.ringBuffer.nUpdate)
|
||||
t = Linker.getData()
|
||||
print(t.shape[1], Linker.count())
|
||||
# Linker.ringBuffer.nUpdate=0
|
||||
# time.sleep(0.2)
|
||||
except KeyboardInterrupt:
|
||||
Linker.stop()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user