爱因斯坦求和表示法
爱因斯坦求和表示法
以上所述内容解释了 torch.einsum
函数的使用规则,具体介绍了如何利用字母(通常是 [a-zA-Z])来标识输入张量的维度,并使用爱因斯坦求和约定简化复杂的张量运算。
表示方法
- 下标(Subscript)和输入张量维度的对应关系
- 在
torch.einsum
中,字母下标用来标识输入张量的每一个维度。多个输入张量的下标使用逗号(,
)分隔。例如'ij,jk'
表示两个二维张量,分别具有维度i x j
和j x k
。 - 下标相同的维度必须可以广播(broadcastable),也就是说它们的尺寸要么相同,要么其中一个为 1。唯一的例外情况是,当某个输入张量中有重复的下标时,这表示对这个张量的这些维度进行对角化操作,即取该张量在这些维度上的对角线。
- 在
- 输出结果中的下标
- 如果你不明确定义输出下标(使用箭头
->
),那么结果张量的下标将是那些在输入张量中只出现一次的下标,并且这些下标会按字母顺序排序。 - 你也可以使用箭头
->
显式地定义输出张量的下标。例如,公式'ij,jk->ki'
计算矩阵乘法并转置结果,使输出张量的维度为k x i
。
- 如果你不明确定义输出下标(使用箭头
- 元素相乘和求和规则
- 输入张量会在对应的维度上进行元素逐点相乘,然后对那些不出现在输出下标中的维度进行求和。这意味着,如果某个下标只出现在输入张量中,而不出现在输出张量中,那么这个下标对应的维度会被消去(求和)。
- 省略号(Ellipsis,
...
)用于广播...
可以用来表示未显式定义的维度,以方便处理维度不确定的张量。每个输入张量最多可以包含一个省略号,它覆盖那些没有被具体下标标识的维度。例如,5 维张量可以使用'ab...c'
来标识维度,...
代表第三和第四维。- 省略号不需要在不同张量中覆盖相同数量的维度,但这些维度的形状必须是可以广播的。
- 如果输出没有使用箭头
->
明确标识输出下标,那么省略号会在输出张量的最前面(即左边),然后是那些只出现一次的下标。
- 空字符串的特殊情况
- 空字符串
''
是一个有效的公式,通常用于标量(没有维度的值)。
- 空字符串
具体例子说明:
- 矩阵乘法
1
torch.einsum('ij,jk->ik', A, B)
- 这里的
ij
和jk
分别表示两个矩阵的维度。 - 计算的是普通的矩阵乘法,输出的维度是
ik
,即i x k
。
- 这里的
- 批量矩阵乘法(使用省略号)
1
torch.einsum('...ij,...jk->...ik', A, B)
- 这里
...
表示其他未显示的维度,通常是批次维度。这样可以在多个批次上进行矩阵乘法。 - 输出的维度是
...ik
,即在批次维度不变的情况下,进行矩阵乘法。
- 这里
- 提取对角线
1
torch.einsum('ii->i', A)
- 这里的
ii
表示一个方阵的对角线元素。该操作提取方阵的对角线元素,输出为一维张量,包含矩阵的对角线值。
- 这里的
Pytorch 官网例子
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# trace
torch.einsum('ii', torch.randn(4, 4))
# diagonal
torch.einsum('ii->i', torch.randn(4, 4))
# outer product
x = torch.randn(5)
y = torch.randn(4)
torch.einsum('i,j->ij', x, y)
# batch matrix multiplication
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
torch.einsum('bij,bjk->bik', As, Bs)
# with sublist format and ellipsis
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
# batch permute
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# equivalent to torch.nn.functional.bilinear
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)
总结:
torch.einsum
提供了一种非常灵活和强大的方式来执行复杂的张量运算。通过定义下标和公式,可以方便地进行矩阵乘法、求和、提取对角线等操作,极大地简化了代码的复杂性。
这个函数的主要特点是:
- 下标标识维度,相同下标表示在该维度进行操作。
- 重复下标表示对该维度进行求和。
- 可以通过省略号
...
来进行批次操作或广播操作。
This post is licensed under CC BY 4.0 by the author.