-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFNET.py
More file actions
115 lines (76 loc) · 3.41 KB
/
FNET.py
File metadata and controls
115 lines (76 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Kinda like using the best of latest researchs
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class SwiGLU(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
# SwiGLU: Swish(xW1) ⊙ (xW3) W2
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim, with a learned gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # eps sits inside the sqrt to keep the division numerically stable.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # rms(x) = sqrt(mean(x^2) + eps), computed per feature vector.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * (x / rms)
class FnetBlock(nn.Module):
    """FNet mixing block: Fourier token mixing + SwiGLU MLP, each sublayer
    wrapped in a pre-RMSNorm and a residual connection."""

    def __init__(self, embed_dim):
        super().__init__()
        self.rmsnorm1 = RMSNorm(embed_dim)
        self.rmsnorm2 = RMSNorm(embed_dim)
        self.mlp = SwiGLU(embed_dim, embed_dim * 4)

    def forward(self, x):
        # Token mixing: real part of the FFT along the sequence axis (dim=1).
        # NOTE(review): the FNet paper transforms both the sequence and the
        # hidden axes; this block transforms only dim=1 — confirm intent.
        mixed = torch.fft.fft(self.rmsnorm1(x), dim=1).real
        x = x + mixed
        # Position-wise feed-forward sublayer with its own residual.
        return x + self.mlp(self.rmsnorm2(x))
class FNET(nn.Module):
    """FNet-style language model with tied input/output embeddings.

    Token + learned positional embeddings feed a stack of FnetBlocks,
    followed by a final RMSNorm and a projection to vocabulary logits.

    Args:
        embed_dim: embedding / hidden size.
        context_length: maximum sequence length (size of the position table).
        vocab_size: vocabulary size.
        num_layers: number of FnetBlock layers.
        lr: unused here; kept for backward compatibility with existing callers.
    """

    def __init__(self, embed_dim, context_length, vocab_size, num_layers=3, lr=0.0001):
        super().__init__()
        self.context_length = context_length
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.pos_embeddings = nn.Embedding(context_length, embed_dim)
        self.blocks = nn.ModuleList([FnetBlock(embed_dim) for _ in range(num_layers)])
        self.norm = RMSNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding
        # matrix (both are (vocab_size, embed_dim)).
        self.output.weight = self.word_embeddings.weight

    def forward(self, input_ids, attention_mask: Optional[torch.Tensor] = None):
        """Return vocabulary logits of shape (batch, seq_len, vocab_size).

        Args:
            input_ids: (batch, seq_len) token indices; seq_len must not
                exceed `context_length`.
            attention_mask: optional (batch, seq_len) 0/1 mask; positions
                marked 0 have their embeddings zeroed before the blocks.

        Raises:
            ValueError: if seq_len exceeds `context_length`.
        """
        seq_len = input_ids.size(1)
        if seq_len > self.context_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {self.context_length}"
            )
        # Generalization: index only the first seq_len positions, so inputs
        # shorter than context_length work (the original always indexed the
        # full table and crashed on the embedding add for shorter inputs).
        positions = torch.arange(seq_len, device=input_ids.device)
        embs = self.word_embeddings(input_ids) + self.pos_embeddings(positions)
        # Bug fix: `if attention_mask:` raises for any multi-element tensor
        # ("ambiguous truth value"); test against None instead. The original
        # branch also discarded the caller's mask, built a (T, T) causal
        # matrix, and expand_as over (B, T, D) was shape-invalid unless
        # B == T — apply the provided padding mask per position instead.
        if attention_mask is not None:
            embs = embs * attention_mask.unsqueeze(-1).to(embs.dtype)
        for block in self.blocks:
            embs = block(embs)
        return self.output(self.norm(embs))
if __name__ == "__main__":
    # Smoke-test a single mixing block on random features.
    features = torch.randn((1, 5, 512))
    fnet_block = FnetBlock(512)
    block_out = fnet_block(features)
    print(block_out.shape)
    # Smoke-test the full model on random token ids.
    token_ids = torch.randint(0, 20002, (1, 10))
    model = FNET(512, 10, 20002)
    logits = model(token_ids)
    print(logits.shape)