update float32 to float64

This commit is contained in:
2026-06-08 15:47:25 +08:00
parent ac0de93e31
commit 31d91d6cc7
6 changed files with 10 additions and 10 deletions

View File

@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(torch.float32).min
fill_value = torch.finfo(torch.float64).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1 / 2)

View File

@@ -71,7 +71,7 @@ class MultiHeadAttention(nn.Module):
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(torch.float32).min
fill_value = torch.finfo(torch.float64).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1 / 2)

View File

@@ -11,7 +11,7 @@ class ParadigmRingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.currentPtr = 0
self.readPtr = 0
self.nUpdate = 0

View File

@@ -18,7 +18,7 @@ class FilterRingBuffer:
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
self.total_samples = 0 # 已写入总点数
self.lock = threading.Lock() # 线程安全锁

View File

@@ -174,7 +174,7 @@ class zmqServer(threading.Thread):
return
# 转置为上位机需要的[50, 通道数]格式
filtered_data = filtered_data.T.astype(np.float32)
filtered_data = filtered_data.T.astype(np.float64)
send_buf = filtered_data.tobytes()
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG")
self.data_send_queue.put(send_buf)
@@ -292,7 +292,7 @@ class zmqServer(threading.Thread):
return
# 零拷贝解析 + 维度转换
data_np = np.frombuffer(data_bytes, dtype=np.float32)
data_np = np.frombuffer(data_bytes, dtype=np.float64)
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
data_np = data_np.T.astype(np.float64)

View File

@@ -18,7 +18,7 @@ PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
def build_packet(global_sample_idx):
"""
生成一包 [5, 66] 的 float32 数据
生成一包 [5, 66] 的 float64 数据
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
:return: np.ndarray shape [5, 66]
"""
@@ -31,13 +31,13 @@ def build_packet(global_sample_idx):
eeg = np.tile(eeg, (1, 64)) # [5, 64]
# Ch64: 标签值通道,初始化为 0
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
# Ch65: 标签序号通道,初始化为 0
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
# 拼成 [5, 66]
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float32)
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
return packet