先看函数参数:torch.flatten(input, start_dim=0, end_dim=-1)input: 一个 tensor,即要被“推平”的 tensor。...我们要先来看一下 tensor 中的 shape 是怎么样的:t = torch.tensor([[[1, 2, 2, 1], [3, 4, 4, 3],...4, 4, 3], [1, 2, 3, 4]], [[5, 6, 6, 5], [7, 8, 8, 7], [5, 6, 7, 8]]])torch.Size...示例代码:x = torch.flatten(t, start_dim=1)print(x, x.shape)y = torch.flatten(t, start_dim=0, end_dim=1)print...pytorch中的 torch.nn.Flatten 类和 torch.Tensor.flatten 方法其实都是基于上面的 torch.flatten 函数实现的。