import torch
import torch.nn as nn
class GRU(nn.Module):
    # Word-level language model: embedding -> multi-layer GRU -> vocabulary logits.
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        # x: (batch, seq_len) token ids; hidden: (num_layers, batch, hidden_size)
        x = self.embedding(x)
        out, hidden = self.gru(x, hidden)
        out = self.fc(out)   # (batch, seq_len, vocab_size)
        return out, hidden
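
# A quick sanity check of the tensor shapes flowing through the model (not part of
# the original listing); the sizes below are arbitrary assumptions chosen only for
# illustration.
_model = GRU(vocab_size=100, embed_size=8, hidden_size=16, num_layers=2)
_x = torch.randint(0, 100, (4, 10))        # (batch=4, seq_len=10) token ids
_h0 = torch.zeros(2, 4, 16)                # (num_layers, batch, hidden_size)
_logits, _h1 = _model(_x, _h0)
print(_logits.shape, _h1.shape)            # torch.Size([4, 10, 100]) torch.Size([2, 4, 16])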
def init_hidden(num_layers, batch_size, hidden_size):
    # Cast to int so the sizes are valid even if they were computed by division.
    num_layers = int(num_layers)
    batch_size = int(batch_size)
    hidden_size = int(hidden_size)
    return torch.zeros(num_layers, batch_size, hidden_size)

vocab_size = 10000
embed_size = 256
hidden_size = 512
num_layers = 3
batch_size = 256

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n = torch.cuda.device_count() if torch.cuda.is_available() else 1   # number of GPUs to parallelise over
model = GRU(vocab_size, embed_size, hidden_size, num_layers).to(device)
model = nn.DataParallel(model)   # replicate across all visible GPUs; each batch is split along dim 0
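
# The training loop below references dataloader, loss_function, optimizer and num_epochs,
# which this listing does not define. A minimal sketch of that setup follows, assuming a
# corpus of pre-tokenised ids split into fixed-length blocks; the placeholder corpus,
# block length, epoch count and learning rate are assumptions, not part of the original.
from torch.utils.data import DataLoader, TensorDataset

block_length = 350                                        # tokens per block (assumed)
num_epochs = 10                                           # assumed
corpus = torch.randint(0, vocab_size, (25600, block_length + 1))      # placeholder token ids
dataset = TensorDataset(corpus[:, :-1], corpus[:, 1:])    # inputs / next-token targets
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        drop_last=True)                   # drop_last: every batch matches batch_size

loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # learning rate is an assumption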
seq_length = 35   # length of each truncated-BPTT segment

for epoch in range(num_epochs):
    for batch, (block_input, block_target) in enumerate(dataloader):
        block_input, block_target = block_input.to(device), block_target.to(device)

        # DataParallel scatters every input along dim 0. For the GRU hidden state that
        # dimension is num_layers, so allocate num_layers * n layers and give each of
        # the n replicas a slice of batch_size // n sequences.
        hidden = init_hidden(num_layers * n, batch_size // n, hidden_size).to(device)

        truncated_length = block_input.size(1)    # number of tokens in this block
        num_steps = truncated_length // seq_length
        block_loss = 0

        for step in range(num_steps):
            start = step * seq_length
            end = start + seq_length
            x = block_input[:, start:end]
            y = block_target[:, start:end]

            # Truncated BPTT: keep the hidden state's values but cut the graph, so
            # gradients never flow back beyond the current segment.
            hidden = hidden.detach()

            output, hidden = model(x, hidden)
            loss = loss_function(output.reshape(-1, vocab_size), y.reshape(-1))
            block_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        block_loss /= num_steps
        if batch % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch {batch}, Loss: {block_loss:.4f}')
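
# After training, the same forward(x, hidden) interface can be used autoregressively.
# This is an illustrative sketch only: the greedy decoding, the start-token id 0 and
# the 20-step length are assumptions, not part of the original listing.
net = model.module if isinstance(model, nn.DataParallel) else model   # unwrap for single-sequence decoding
net.eval()
with torch.no_grad():
    token = torch.zeros(1, 1, dtype=torch.long, device=device)   # (batch=1, seq_len=1), assumed start id 0
    h = torch.zeros(num_layers, 1, hidden_size, device=device)
    generated = []
    for _ in range(20):
        logits, h = net(token, h)
        token = logits[:, -1, :].argmax(dim=-1, keepdim=True)    # greedy: most probable next token
        generated.append(token.item())
print(generated)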