import numpy as np
# Fix the global NumPy seed so both frameworks below see identical data.
np.random.seed(2021)
# Input batch, later folded into a single sample with 5 * 3 channels.
x = np.random.standard_normal(size=(5, 3, 4, 4))
# One 3x3 bank of 1x1 kernels for each of the 5 groups.
xx = np.random.standard_normal(size=(5, 3, 3, 1, 1))
# megengine
import megengine.functional as F
import megengine.module as M
from megengine import tensor  # was misspelled `tenso`, which fails at import

# Wrap the shared NumPy arrays as MegEngine tensors.
# NOTE(review): x / xx are float64 NumPy arrays; confirm the dtype MegEngine
# assigns to these tensors before comparing numerically with PyTorch.
x_mm = tensor(x)
x_ww = tensor(xx)
# Fold the batch axis into the channel axis so each original sample becomes
# its own convolution group: (5, 3, 4, 4) -> (1, 15, 4, 4).
x_mm = x_mm.reshape(1, -1, x_mm.shape[2], x_mm.shape[3])
# A 5-D weight is passed for the grouped convolution, with the group count
# on the leading axis (see the MegEngine conv2d docs for the layout). xx
# already has this shape; the reshape just makes the layout explicit.
x_ww = x_ww.reshape(5, 3, 3, 1, 1)
# One group per leading weight entry; 1x1 kernels, stride 1, no padding.
y_mm = F.conv2d(x_mm, weight=x_ww, stride=1, padding=0, groups=x_ww.shape[0])
print(y_mm.shape)
# Split the channel axis back into (sample, channel) for the comparison.
mm = y_mm.reshape(5, 3, 4, 4)
# pytorch
import torch
import torch.nn.functional as F

# Wrap the same NumPy arrays as torch tensors.
x_mp, x_wp = torch.from_numpy(x), torch.from_numpy(xx)
# Fold the batch axis into the channel axis, keeping the spatial dims as-is.
x_mp = x_mp.reshape(1, -1, *x_mp.shape[2:])
# Flatten the (group, out-per-group) axes into one leading axis: a 4-D weight
# of shape (15, 3, 1, 1) to go with groups=5.
x_wp = x_wp.reshape(5 * 3, 3, 1, 1)
# Grouped 1x1 convolution: 5 groups, stride 1, no padding.
y_mp = F.conv2d(x_mp, x_wp, stride=1, padding=0, groups=5)
print(y_mp.shape)
# Split the channel axis back into (sample, channel).
pp = y_mp.reshape(5, 3, 4, 4)
# Print the MegEngine and PyTorch results one after the other.
print(mm, pp)
# Similar questions  (stray page-navigation text; not part of the script)