Einsum 表示法是对张量的复杂操作的一种优雅方式,本质上是使用特定领域的语言。 一旦理解并掌握了 einsum,可以帮助我们更快地编写更简洁高效的代码。
Einsum 是爱因斯坦求和(Einstein summation)的缩写,是一种求和的方法,在处理关于坐标的方程式时非常有效。在 numpy、TensorFlow 和 Pytorch 中都有相关实现,本文通过 Pytorch 实现 Transformer 中的多头注意力来介绍 einsum 在深度学习模型中的应用。
1. 矩阵乘法
假设有两个矩阵:
A=[123456],B=[789101112]A=[123456],B=⎡⎢⎣789101112⎤⎥⎦我们想求两个矩阵的乘积。
- 第一步:
- 第二步:
- 第三步:
- 第四步:
2. Einstein Notation
爱因斯坦标记法又称爱因斯坦求和约定(Einstein summation convention),基本内容是:
当两个变量具有相同的角标时,则遍历求和。在此情况下,求和号可以省略。
比如,计算两个向量的乘积, a,b∈RIa,b∈RI:
c=∑iaibi=aibic=∑iaibi=aibi计算两个矩阵的乘积, AA ∈RI×K∈RI×K,BB ∈RK×J∈RK×J。用爱因斯坦求和符号表示,可以写成:
cij=∑kAikBkj=AikBkjcij=∑kAikBkj=AikBkj在深度学习中,通常使用的是更高阶的张量之间的变换。比如在一个 batch 中包含 NN 个训练样本的,最大长度是 TT,词向量维度为 KK 的张量,即 T∈RN×T×KT∈RN×T×K,如果想让词向量的维度映射到 QQ 维,则定义一个 W∈RK×QW∈RK×Q:
Cntq=∑kTntkWkq=TntkWkqCntq=∑kTntkWkq=TntkWkq在图像处理中,通常在一个 batch 的训练样本中包含 NN 张图片,每张图片长为 TT,宽为 KK,颜色通道为 MM,即 T∈RN×T×K×MT∈RN×T×K×M 是一个 4d 张量。如果我想进行三个操作:
- 将 KK 投影成 QQ 维;
- 对 TT 进行求和;
- 将 MM 和 NN 进行转置。
用爱因斯坦标记法可以表示成:
Cmqn=∑t∑kTntkmWkq=TntkmWkqCmqn=∑t∑kTntkmWkq=TntkmWkq需要注意的是,爱因斯坦标记法是一种书写约定,是为了将复杂的公式写得更加简洁。它本身并不是某种运算符,具体运算还是要回归到各种算子上。
3. einsum
- Numpy:
np.einsum
- Pytorch:
torch.einsum
- TensorFlow:
tf.einsum
以上三种 einsum
都有相同的特性 einsum(equation, operands)
:
equation
:字符串,用来表示爱因斯坦求和标记法的;operands
:一些列张量,要运算的张量。
其中 口
是一个占位符,代表的是张量维度的字符。比如:
1 | np.einsum('ij,jk->ik', A, B) |
A
和 B
是两个矩阵,将 ij,jk->ik
分成两部分:ij, jk
和 ik
,那么 ij
代表的是输入矩阵 A
的第 i
维和第 j
维,jk
代表的是 B
第 j
维和第 k
维,ik
代表的是输出矩阵的第 i
维和第 k
维。注意 i, j, k
可以是任意的字符,但是必须保持一致。换句话说,einsum
实际上是直接操作了矩阵的维度(角标)。上例中表示的是, A
和 B
的乘积。
3.1 矩阵转置
Bji=AijBji=Aij1 | import torch |
3.2 求和
b=∑i∑jAij=Aijb=∑i∑jAij=Aij1 | a = torch.arange(6).reshape(2, 3) |
3.3 列求和
bj=∑iAij=Aijbj=∑iAij=Aij1 | a = torch.arange(6).reshape(2, 3) |
3.4 行求和
bi=∑jAij=Aijbi=∑jAij=Aij1 | a = torch.arange(6).reshape(2, 3) |
3.5 矩阵-向量乘积
ci=∑kAikbk=Aikbkci=∑kAikbk=Aikbk1 | a = torch.arange(6).reshape(2, 3) |
3.6 矩阵-矩阵乘积
Cij=∑kAikBkj=AikBkjCij=∑kAikBkj=AikBkj1 | a = torch.arange(6).reshape(2, 3) |
3.7 点积
c=∑iaibi=aibi1 | a = torch.arange(3) |
3.8 Hardamard 积
Cij=AijBij1 | a = torch.arange(6).reshape(2, 3) |
3.9 外积
Cij=aibj1 | a = torch.arange(3) |
3.10 Batch 矩阵乘积
Cijl=∑kAijkBikl=AijkBikl1 | a = torch.randn(3,2,5) |
3.11 张量收缩
假设有两个张量 A∈RI1×⋯×In 和 B∈RJ1×⋯×Jm。比如 n=4,m=5,且 I2=J3 和 I3=J5。我们可以计算两个张量的乘积,得到新的张量 C∈RI1×I4×J1×J2×J4:
Cpstuv=∑q∑rApqrsBtuqvr=ApqrsBtuqvr1 | a = torch.randn(2,3,5,7) |
3.12 双线性变换
Dij=∑k∑lAikBjklCil=AikBjklCil1 | a = torch.randn(2,3) |
4. einops
尽管 einops
是一个通用的包,这里哦我们只介绍 einops.rearrange
。同 einsum
一样,einops.rearrange
也是操作矩阵的角标的,只不过函数的参数正好相反,如下图所示。
rearrange
传入的参数是一个张量列表,那么后面字符串的第一维表示列表的长度。
1 | qkv = torch.rand(2,128,3*512) # dummy data for illustration only |
5. Scale dot product self-attention
第一步:创建一个线性投影。给定输入 X∈Rb×t×d,其中 b 表示 batch size,t 表示 sentence length,d 表示 word dimension。
Q=XWQ,K=XWK,V=XWV1
2
3
4
5to_qvk = nn.Linear(dim, dim * 3, bias=False) # init only
# Step 1
qkv = to_qvk(x) # [batch, tokens, dim*3 ]
# decomposition to q,v,k
q, k, v = tuple(rearrange(qkv, 'b t (d k) -> k b t d ', k=3))第二步:计算点积,mask,最后计算 softmax。
dot_score=softmax(QK⊤√dk)1
2
3
4
5
6
7# Step 2
# Resulting shape: [batch, tokens, tokens]
scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k) * self.scale_factor
if mask is not None:
assert mask.shape == scaled_dot_prod.shape[1:]
scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
attention = torch.softmax(scaled_dot_prod, dim=-1)第三步:计算注意力得分与 V 的乘积。
Attention(Q,K,V)=softmax(QK⊤√dk)V1
torch.einsum('b i j , b j d -> b i d', attention, v)
将上面三步综合起来:
1 | import numpy as np |
6. Multi-Head Self-Attention
第一步:为每一个头创建一个线性投影 Q,K,V。
1
2to_qvk = nn.Linear(dim, dim_head * heads * 3, bias=False) # init only
qkv = self.to_qvk(x)第二步:将 Q,K,V 分解,并分配给每个头。
1
2
3
4# Step 2
# decomposition to q,v,k and cast to tuple
# [3, batch, heads, tokens, dim_head]
q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d ', k=3, h=self.heads))第三步:计算注意力得分
1
2
3
4
5
6
7# Step 3
# resulted shape will be: [batch, heads, tokens, tokens]
scaled_dot_prod = torch.einsum('b h i d , b h j d -> b h i j', q, k) * self.scale_factor
if mask is not None:
assert mask.shape == scaled_dot_prod.shape[2:]
scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
attention = torch.softmax(scaled_dot_prod, dim=-1)第四步:注意力得分与 V 相乘
1
2# Step 4. Calc result per batch and per head h
out = torch.einsum('b h i j , b h j d -> b h i d', attention, v)第五步:将所有的头合并
1
out = rearrange(out, "b h t d -> b t (h d)")
第六步:线性变换
1
2
3self.W_0 = nn.Linear( _dim, dim, bias=False) # init only
# Step 6. Apply final linear transformation layer
self.W_0(out)
最终实现 MHSA:
1 | import numpy as np |
Reference
Einstein Summation in Numpy, OLEXA BILANIUK
A basic introduction to NumPy’s einsum, Alex Riley
- EINSUM IS ALL YOU NEED - EINSTEIN SUMMATION IN DEEP LEARNING, Tim Rocktäschel
- Understanding einsum for Deep learning: implement a transformer with multi-head self-attention from scratch, Nikolas Adaloglou
请问您的公式推导和图片有论文出处吗?
@cvychen 可以看下最后参考资料给出的链接哈
我在官方code里面没有找到两个用来选策略的mixture factor,alpha和beta...请问您知道这两个是实现在哪里了吗? https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/nat/levenshtein_transformer.py
torch.einsum('b i d , b j d -> b i j', q, k) 这个语法用矩阵乘法写是不是比较麻烦
我当时写的时候也没有开源代码,后来没有关注了,所以也不太清楚哎·
总的来说,einsum是一个非常方便的算子
一点小纠正:2.5节“Explicit Sparse Transformer虽然实现了不连续的自适应稀疏化自注意力”应该是Adaptively Sparse Transformers