import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline
def clones(module, N):
    """ Produce N identical layers. """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
2.1 Encoder
The encoder is a stack of 6 identical layers:
class Encoder(nn.Module):
    """ Core encoder is a stack of N layers. """
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """ Pass the input (and mask) through each layer in turn. """
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
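Encoder applies a final LayerNorm that is not shown in this excerpt. A minimal sketch of a layer-normalization module consistent with the usage above (the parameter names a_2/b_2 and the eps default are assumptions):

class LayerNorm(nn.Module):
    """ Layer normalization over the last (feature) dimension. """
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learned scale
        self.b_2 = nn.Parameter(torch.zeros(features))  # learned shift
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2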
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    See http://jmlr.org/papers/v15/srivastava14a.html for dropout detail
    and https://arxiv.org/abs/1512.03385 for residual connection detail.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """ Apply residual connection to any sublayer with the same size. """
        return x + self.dropout(sublayer(self.norm(x)))
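make_model below builds the Encoder from an EncoderLayer that does not appear in this excerpt. A sketch of an encoder layer that wires a self-attention sub-layer and a feed-forward sub-layer through two SublayerConnections, assuming that interface:

class EncoderLayer(nn.Module):
    """ One encoder layer: self-attention followed by a feed-forward network. """
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size  # read by Encoder as layer.size

    def forward(self, x, mask):
        # First sub-layer: multi-head self-attention with residual + norm.
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        # Second sub-layer: position-wise feed-forward with residual + norm.
        return self.sublayer[1](x, self.feed_forward)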
class MultiHeadAttention(nn.Module):
    """ Build Multi-Head Attention sub-layer. """
    def __init__(self, h, d_model, dropout=0.1):
        """
        :params h: int, number of heads
        :params d_model: model size
        :params dropout: rate of dropout
        """
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        # According to the paper, d_v always equals d_k
        # and d_v = d_k = d_model / h = 64
        self.d_k = d_model // h
        self.h = h
        # following K, Q, V and `Concat`, so we need 4 linears
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        Implement Multi-Head Attention.

        :params query: query embedding matrix, Q in above figure left
        :params key: key embedding matrix, K in above figure left
        :params value: value embedding matrix, V in above figure left
        :params mask: sub-sequence mask
        """
        if mask is not None:
            # same mask applied to all heads
            mask = mask.unsqueeze(1)
        n_batch = query.size(0)

        # 1. Do all the linear projections in batch from d_model to h x d_k
        query, key, value = [l(x).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (query, key, value))]

        # 2. Apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3. `Concat` using a view and apply a final linear
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linears[-1](x)
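The forward pass above relies on a scaled dot-product attention routine that is not shown in this excerpt. A minimal sketch of such a function (the name and signature are assumptions chosen to match the call above):

def attention(query, key, value, mask=None, dropout=None):
    """ Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V. """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # masked positions get a large negative score before the softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn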
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    Construct Transformer model.

    :params src_vocab: source language vocabulary
    :params tgt_vocab: target language vocabulary
    :params N: number of encoder or decoder stacks
    :params d_model: dimension of model input and output
    :params d_ff: dimension of feed forward layer
    :params h: number of attention heads
    :params dropout: rate of dropout
    """
    c = copy.deepcopy
    attn = MultiHeadAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff)
    position = PositionEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform(p)
    return model
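As a quick sanity check, a small model can be instantiated directly; this assumes the other components referenced above (EncoderDecoder, Decoder, PositionwiseFeedForward, etc.) are defined earlier in the document, and the vocabulary sizes are arbitrary placeholders:

# Toy example: tiny vocabularies and only 2 stacked layers.
tmp_model = make_model(src_vocab=10, tgt_vocab=10, N=2)
print(tmp_model)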
class MyIterator(data.Iterator):
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn
                    )
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size, self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))


def rebatch(pad_idx, batch):
    """ Fix order in torchtext to match ours. """
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    return Batch(src, trg, pad_idx)
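MyIterator sorts within pools of 100 batches and relies on a batch_size_fn to do dynamic batching by token count. A sketch of such a function and of constructing the train/validation iterators; the dataset names train/val, the BATCH_SIZE value, and the device argument are assumptions:

global max_src_in_batch, max_tgt_in_batch

def batch_size_fn(new, count, sofar):
    """ Keep augmenting the batch and track total tokens (including padding). """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

BATCH_SIZE = 12000
train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0, repeat=False,
                        sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=True)
valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0, repeat=False,
                        sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=False)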
# Skip if not interested in multigpu.
class MultiGPULossCompute:
    "A multi-gpu loss compute and train function."
    def __init__(self, generator, criterion, devices, opt=None, chunk_size=5):
        # Send out to different gpus.
        self.generator = generator
        self.criterion = nn.parallel.replicate(criterion, devices=devices)
        self.opt = opt
        self.devices = devices
        self.chunk_size = chunk_size

    def __call__(self, out, targets, normalize):
        total = 0.0
        generator = nn.parallel.replicate(self.generator, devices=self.devices)
        out_scatter = nn.parallel.scatter(out, target_gpus=self.devices)
        out_grad = [[] for _ in out_scatter]
        targets = nn.parallel.scatter(targets, target_gpus=self.devices)

        # Divide generating into chunks.
        chunk_size = self.chunk_size
        for i in range(0, out_scatter[0].size(1), chunk_size):
            # Predict distributions
            out_column = [[Variable(o[:, i:i+chunk_size].data,
                                    requires_grad=self.opt is not None)]
                          for o in out_scatter]
            gen = nn.parallel.parallel_apply(generator, out_column)

            # Compute loss.
            y = [(g.contiguous().view(-1, g.size(-1)),
                  t[:, i:i+chunk_size].contiguous().view(-1))
                 for g, t in zip(gen, targets)]
            loss = nn.parallel.parallel_apply(self.criterion, y)

            # Sum and normalize loss
            l = nn.parallel.gather(loss, target_device=self.devices[0])
            l = l.sum()[0] / normalize
            total += l.data[0]

            # Backprop loss to output of transformer
            if self.opt is not None:
                l.backward()
                for j, l in enumerate(loss):
                    out_grad[j].append(out_column[j][0].grad.data.clone())

        # Backprop all loss through transformer.
        if self.opt is not None:
            out_grad = [Variable(torch.cat(og, dim=1)) for og in out_grad]
            o1 = out
            o2 = nn.parallel.gather(out_grad, target_device=self.devices[0])
            o1.backward(gradient=o2)
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return total * normalize
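For single-GPU training the same interface can be implemented far more simply. A sketch of such a loss-compute wrapper, assuming the same generator/criterion/opt arguments and the old Tensor .data[0] indexing used above:

class SimpleLossCompute:
    "A simple single-GPU loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        # Project decoder output to vocabulary, compute normalized loss, and step.
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1)) / norm
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.data[0] * norm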
if False:
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))
    for epoch in range(10):
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter),
                  model_par,
                  MultiGPULossCompute(model.generator, criterion,
                                      devices=devices, opt=model_opt))
        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter),
                         model_par,
                         MultiGPULossCompute(model.generator, criterion,
                                             devices=devices, opt=None))
        print(loss)
else:
    model = torch.load("iwslt.pt")
Once the model is trained, we can use it to translate. Here we translate the first sentence of the validation set as an example:
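The loop below calls greedy_decode, which does not appear in this excerpt. A minimal greedy decoding sketch, assuming the model exposes encode/decode/generator and that a subsequent_mask helper exists as elsewhere in the document:

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """ Decode one sentence by always picking the most probable next token. """
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(memory, src_mask, Variable(ys),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys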
for i, batch in enumerate(valid_iter):
    src = batch.src.transpose(0, 1)[:1]
    src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
    out = greedy_decode(model, src, src_mask, max_len=60,
                        start_symbol=TGT.vocab.stoi["<s>"])
    print("Translation:", end="\t")
    for i in range(1, out.size(1)):
        sym = TGT.vocab.itos[out[0, i]]
        if sym == "</s>":
            break
        print(sym, end=" ")
    print()
    print("Target:", end="\t")
    for i in range(1, batch.trg.size(0)):
        sym = TGT.vocab.itos[batch.trg.data[i, 0]]
        if sym == "</s>":
            break
        print(sym, end=" ")
    print()
    break
tgt_sent = trans.split()

def draw(data, x, y, ax):
    seaborn.heatmap(data, xticklabels=x, square=True, yticklabels=y,
                    vmin=0.0, vmax=1.0, cbar=False, ax=ax)

for layer in range(1, 6, 2):
    fig, axs = plt.subplots(1, 4, figsize=(20, 10))
    print("Encoder Layer", layer + 1)
    for h in range(4):
        draw(model.encoder.layers[layer].self_attn.attn[0, h].data,
             sent, sent if h == 0 else [], ax=axs[h])
    plt.show()

for layer in range(1, 6, 2):
    fig, axs = plt.subplots(1, 4, figsize=(20, 10))
    print("Decoder Self Layer", layer + 1)
    for h in range(4):
        draw(model.decoder.layers[layer].self_attn.attn[0, h].data[:len(tgt_sent), :len(tgt_sent)],
             tgt_sent, tgt_sent if h == 0 else [], ax=axs[h])
    plt.show()

    print("Decoder Src Layer", layer + 1)
    fig, axs = plt.subplots(1, 4, figsize=(20, 10))
    for h in range(4):
        draw(model.decoder.layers[layer].src_attn.attn[0, h].data[:len(tgt_sent), :len(sent)],
             sent, tgt_sent if h == 0 else [], ax=axs[h])
    plt.show()