Main takeaway

  • 逻辑存储与内存存储 (Contiguous);
  • Stride和Size
  • View, Reshape 和 Permute

⭐: 本文是我最近写代码时的思考, 如果存在不合理或者更好的说明方式欢迎评论区提出.


一些小问题, 如果你都能知道答案以及原理, 那可能本文讲的对你来说有点浅了:

  1. c = torch.arange(6), 如果我希望把c给修改成一个size为(2,3)的张量, 从逻辑上说至少是有两种修改的方法的:

    1.  # 方法1
       tensor([[0, 1, 2],
               [3, 4, 5]])
      
       # 方法2
       tensor([[0, 2, 4],
               [1, 3, 5]])
      
       # 这里的问题就是我如果要把[0, ..., 5]塞到一个(2,3)的框里, 我有几种塞法.
       # 所以甚至胡乱塞也是一种"reshape"的方法, 但显然torch不是这么做的.
       tensor([[1, 4, 3],
               [5, 2, 0]])
      
    2. 那么torch里使用的是哪一种? 为什么要这样?
  2. 假设我有50个智能体, 每个智能体生成了长度为42的动作序列, 一个动作是6维张量, 我们记为actions, actions.shape = (42, 50, 6). 现在如果我希望对于每一个智能体, 把连续的7个action合并成一个6x7=42维的张量, 我应该怎么写呢?

    1. actions.reshape(6, 7, 50, 6).permute(0, 2, 1, 3).reshape(6, 50, -1)
    2. actions.reshape(7, 6, 50, 6).permute(1, 2, 0, 3).reshape(6, 50, -1)
  3. 对于torch里的张量, 什么是size, stride, contiguous, 它们又和view, reshape, permute有什么关系?

0. Preliminaries

  1. 对于一个$3$维张量$x$, 我们记它三个维度分别为$\mathrm{dim}_1$, $\mathrm{dim}_2$, $\mathrm{dim}_3$, 假如三个维度的大小 (dim_size)分别为$2, 3, 4$, 我们记$x.\mathrm{size} = (2,3,4)$

    1.  >>> x = torch.arange(24).view(2,3,4)
       >>> x
       tensor([[[ 0,  1,  2,  3],
                [ 4,  5,  6,  7],
                [ 8,  9, 10, 11]],
      
               [[12, 13, 14, 15],
                [16, 17, 18, 19],
                [20, 21, 22, 23]]])
      
  2. 取张量中的某个元素我们叫做索引(index), 我们对张量做index的时候最好把它想成"多维数组". 这样就很容易理解为什么$x[i][j][k] = x.\mathrm{permute}(2, 1, 0)[k][j][i]$了. 同时我们把$(i,j,k)$记为索引的坐标.

1. 逻辑存储与内存存储 (Contiguous)

对于张量$x$, 我们把它的矩阵表示称为逻辑存储, 而把它在内存中存放的方式记为内存存储.

例如:

>>> x = torch.arange(12).reshape(3,4)
>>> x
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

则$x$的逻辑存储形式为:

Fig. 1: 逻辑存储

而$x$的内存存储形式为:

Fig. 2: 内存存储

在torch/numpy中, 即使是高维张量在内存中也是存储在一块连续的内存区域中, 同时会记录一些元信息来描述数组的"形态", 例如起始地址, 步长 (stride), 大小 (size)等.

在对高维张量进行索引时我们采用起始地址 + 地址偏移量的计算方式, 而地址偏移量就需要用到stride和size的信息 (后文会提到具体的计算方式).

1.1. 逻辑存储的行优先展开和列优先展开

对于张量$x$, 我们如果想把它展开成一维张量, 我们需要以某种形式遍历$x$的所有元素, 我们现在描述最常见的两种遍历模式: 行优先展开 (row major) 和 列优先展开 (column major).

行优先展开: 从张量的最后一维度开始, 向前展开.

  • 假设$x.\mathrm{shape} = (4,3,2)$, 展开的顺序: $(0, 0, 0) \rightarrow (0, 0, 1) \rightarrow(0, 1, 0) \rightarrow(0, 1, 1) \rightarrow (0,2, 0) \rightarrow (0, 2, 1) \rightarrow (1, 0, 0) \rightarrow \cdots$.

    • 很类似进位的过程.

列优先展开: 从张量的第一维开始, 向后展开.

一个形象的描述是: 对于逻辑存储为$\mathrm{Fig. 1}$中的张量, 其行优先展开和列优先展开的结果分别为:

Fig. 3: 逻辑存储的行优先展开
Fig. 4: 逻辑存储的列优先展开

1.2. C-contiguous 和 Fortran-contiguous

我们前文提到了, 不管你高维张量$x$具体的形状如何, 它都是被存储在一块连续的内存地址中. 而contiguous用来描述张量逻辑存储和内存存储之间的关系.

  • 如果张量$x$的行优先展开形式和其内存存储一致, 则我们称之为C-contiguous. Numpy, Pytorch中的contiguous指的就是C-contiguous, 以及下文中contiguous默认指C-contiguous.
  • 如果张量$x$的列优先展开形式和其内存存储一致, 则我们称之为Fortran-contiguous. Matlab, Fortran中的contiguous指的是Fortran-contiguous.

这里我们可以回答第一个问题了:

>>> c = torch.arange(6).view(2,3)
>>> 那么c=?
# 方法1
tensor([[0, 1, 2],
        [3, 4, 5]])

# 方法2
tensor([[0, 2, 4],
        [1, 3, 5]])

首先torch.arange(6)是一个contiguous的tensor, 所以我们可以知道它在内存中是$[0, \cdots, 5]$的形式. 而view的输入和输出都需要是contiguous的tensor (后文会提到), 所以我们需要在"保证其逻辑存储和内存存储一致的前提下", 将$x$的size改变成$(2,3)$. 不难看出方法1得到的tensor满足contiguous, 而方法2得到的就不满足contiguous了.

如果我们有能力去选择tensor究竟是满足C-contiguous还是Fortran-contiguous, 那我们应该根据操作中是row-wise的操作多还是column-wise的操作多来判断. 如果row-wise的操作多, 那么显然C-contiguous的存储方式更划算, 因为逻辑意义上相邻的元素在内存中也是相邻的, 从索引计算或者cache的角度来说都更划算.

2. Stride, Size

由于高维张量在内存中都是被存储在一段连续的内存空间中, 所以我们需要一些额外的"元信息"用来描述高维张量的形态, stride和size就是其中的两个.

举例来说, $x$是一个二维张量:

>>> x = torch.arange(9).view(3,3)
>>> x.size()
torch.Size([3, 3])
>>> x.stride()
(3, 1)

$\mathrm{size}$比较好理解, 就是每个维度的大小. 而$\mathrm{stride}$则是在我们需要对张量进行索引的时候起作用.

当我们需要索引$x[i][j]$时, 它的地址为起始地址 + offset, 而$\mathrm{offset} = i * \mathrm{stride}[0] + j*\mathrm{stride}[1]$. 故对于一个$n$维张量来说, 它的$\mathrm{stride}$是一个$n$维度tuple, 且$\mathrm{stride}[i]$的意思是当我们沿着$\dim_i$去索引下一个元素的时候, 在内存空间上要跳过几个元素 (offset).

当我们在使用view去修改tensor的时候, 其实我们并没有修改tensor在内存中的存储, 而只是通过修改stride和size来描述张量形状的变化:

>>> x1 = x.view(3,3)
>>> x.data_ptr() == x1.data_ptr()
True

pytorch的官方文档中, 还给出了使用用$\mathrm{stride}$来描述contiguous的方式:

对于$n$维度张量$x$, 如果$\forall i = 0, \cdots, n-2$, 我们都满足:

$$ \mathrm{stride}[i] =\mathrm{stride}[i+1] \times \mathrm{size}[i+1] $$

且$\mathrm{stride}[n-1] = 1$. 则我们称张量$x$是contiguous的.

但只要理解了$\mathrm{stride}$的意义, 你会发现这只是"张量$x$的行优先展开形式和其内存存储一致"的另一种描述方式.

3. View, ReshapePermute

View:

Returns a new tensor with the same data as the self tensor but of a different shape.

  • view要求输入和输出的tensor都是contiguous的, 否则会throw exception. 换言之, 你不管对一个tensor使用了多少次view, 你都只是在改变$\mathrm{stride}$和$\mathrm{size}$, 并没有修改这个tensor的内存存储. 所以你最后只要$\mathrm{view}(-1)$它还是会回到下面这个样子.

所以对某个tensor做view之前, 你不妨在脑海中把它按照行优先展开还原成这种一维的形式, 再去思考view的结果.

Reshape:

Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of nput. Otherwise, it will be a copy

  • 对于contiguous的输入, reshape等于view. 而对于incontiguous的输入, reshape等于tensor.contigous().view. 其中contiguous()会开辟一块新的内存空间, 将incontiguous的张量按照行优先展开的方式存储进去. 所以reshape是有可能修改内存存储的结构的.

    • >>> x = torch.arange(6).view(2,3).T # incontiguous
      >>> x.reshape(1, 6).view(-1)
      tensor([0, 3, 1, 4, 2, 5])
      # -----
      >>> x1 = torch.arange(6).view(2,3).T # incontiguous
      >>> x1.reshape(3, 2).view(-1) 
      # RuntimeError, 因为这里x1的size本身就是(3,2), 所以reshape直接返回了, 
      # 导致x1.reshape(3, 2)还是incontiguous的.
      
  • 所以对于contiguous的tensor作为输入, 经过无数次reshape也不会影响其内存存储方式. 但如果我们对incontiguous的tensor做了reshape, 则tensor的内存存储方式可能会发生变化.

Permute

Returns a view of the original tensor input with its dimensions permuted.

  • 虽然permuteview一样, 都是修改stride和size, 但并不改变内存存储方式 (因为他们本质上都是返回tensor的一个view). 但是permute并不保证返回的tensor是contiguous的.

    • 换言之permute().contiguous()就有可能修改内存存储方式了.
  • 举例来说:

    • >>> a = torch.arange(24).reshape(2,3,4)
      >>> b = a.reshape(3,2,4).permute(1,0,2)
      
      >>> print(a, a.size(), a.stride())
      >>> print('-' * 40)
      >>> print(b, b.size(), b.stride())
      >>> print(a.data_ptr() == b.data_ptr())
      
      tensor([[[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]],
      
              [[12, 13, 14, 15],
               [16, 17, 18, 19],
               [20, 21, 22, 23]]]) torch.Size([2, 3, 4]) (12, 4, 1)
      ----------------------------------------
      tensor([[[ 0,  1,  2,  3],
               [ 8,  9, 10, 11],
               [16, 17, 18, 19]],
      
              [[ 4,  5,  6,  7],
               [12, 13, 14, 15],
               [20, 21, 22, 23]]]) torch.Size([2, 3, 4]) (4, 8, 1) # 通过stride和size我们可以知道b是incontiguous的
      True
      
  • 上述permute本质上是进行了如下的操作, 可以看出permute之后的tensor的行优先展开结果和内存存储不一致了:

所以我们现在可以知道文章开始时的问题(2)的答案应该是第一个, 如果还是有不明白的地方可以想一想:

>>> x = torch.arange(6)
>>> print((x.view(2,3) == x.view(3,2).T).all())
False

References