update float32 to float64

This commit is contained in:
2026-06-08 15:47:25 +08:00
parent ac0de93e31
commit 31d91d6cc7
6 changed files with 10 additions and 10 deletions

View File

@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(torch.float32).min
fill_value = torch.finfo(torch.float64).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1 / 2)

View File

@@ -71,7 +71,7 @@ class MultiHeadAttention(nn.Module):
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(torch.float32).min
fill_value = torch.finfo(torch.float64).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1 / 2)