Merge branch 'master' of http://47.98.56.110:7001/lizhao/bci_algo
This commit is contained in:
@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
|
||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||
if mask is not None:
|
||||
fill_value = torch.finfo(torch.float32).min
|
||||
fill_value = torch.finfo(torch.float64).min
|
||||
energy.mask_fill(~mask, fill_value)
|
||||
|
||||
scaling = self.emb_size ** (1 / 2)
|
||||
|
||||
@@ -71,7 +71,7 @@ class MultiHeadAttention(nn.Module):
|
||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||
if mask is not None:
|
||||
fill_value = torch.finfo(torch.float32).min
|
||||
fill_value = torch.finfo(torch.float64).min
|
||||
energy.mask_fill(~mask, fill_value)
|
||||
|
||||
scaling = self.emb_size ** (1 / 2)
|
||||
|
||||
@@ -11,7 +11,7 @@ class ParadigmRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.nUpdate = 0
|
||||
|
||||
@@ -18,7 +18,7 @@ class FilterRingBuffer:
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
|
||||
self.total_samples = 0 # 已写入总点数
|
||||
self.lock = threading.Lock() # 线程安全锁
|
||||
|
||||
@@ -174,7 +174,7 @@ class zmqServer(threading.Thread):
|
||||
return
|
||||
|
||||
# 转置为上位机需要的[50, 通道数]格式
|
||||
filtered_data = filtered_data.T.astype(np.float32)
|
||||
filtered_data = filtered_data.T.astype(np.float64)
|
||||
send_buf = filtered_data.tobytes()
|
||||
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG")
|
||||
self.data_send_queue.put(send_buf)
|
||||
@@ -292,7 +292,7 @@ class zmqServer(threading.Thread):
|
||||
return
|
||||
|
||||
# 零拷贝解析 + 维度转换
|
||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||
data_np = np.frombuffer(data_bytes, dtype=np.float64)
|
||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||
data_np = data_np.T.astype(np.float64)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
||||
|
||||
def build_packet(global_sample_idx):
|
||||
"""
|
||||
生成一包 [5, 66] 的 float32 数据
|
||||
生成一包 [5, 66] 的 float64 数据
|
||||
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
|
||||
:return: np.ndarray shape [5, 66]
|
||||
"""
|
||||
@@ -32,13 +32,13 @@ def build_packet(global_sample_idx):
|
||||
eeg = np.tile(eeg, (1, 64)) # [5, 64]
|
||||
|
||||
# Ch64: 标签值通道,初始化为 0
|
||||
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
|
||||
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||
|
||||
# Ch65: 标签序号通道,初始化为 0
|
||||
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
|
||||
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||
|
||||
# 拼成 [5, 66]
|
||||
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float32)
|
||||
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
|
||||
return packet
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user