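In general, different image tasks and models call for different data augmentation methods. Some fairly general ones, however, are affine transforms, color jitter, horizontal/vertical flips, and random crops.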
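Random cropping changes the annotations too: boxes have to be shifted into the crop's coordinate frame and clipped to its borders. Below is a minimal NumPy sketch, assuming boxes are stored as an (N, 4) array of [x1, y1, x2, y2] pixel coordinates; the function `random_crop` and its parameters are hypothetical, not part of any library.

```python
import numpy as np


def random_crop(image, boxes, crop_h, crop_w, rng=None):
    """Randomly crop an H x W (x C) image to crop_h x crop_w and shift the boxes.

    boxes: (N, 4) array of [x1, y1, x2, y2] (assumed format).
    Boxes are clipped to the crop; boxes left with zero area are dropped.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    if crop_h > h or crop_w > w:
        raise ValueError("crop size must not exceed image size")
    # top-left corner of the crop window, sampled uniformly
    y0 = int(rng.integers(0, h - crop_h + 1))
    x0 = int(rng.integers(0, w - crop_w + 1))
    cropped = image[y0:y0 + crop_h, x0:x0 + crop_w]
    # move boxes into the crop's frame, then clip them to its borders
    shifted = boxes.astype(np.float32) - np.array([x0, y0, x0, y0], dtype=np.float32)
    shifted[:, [0, 2]] = shifted[:, [0, 2]].clip(0, crop_w)
    shifted[:, [1, 3]] = shifted[:, [1, 3]].clip(0, crop_h)
    keep = (shifted[:, 2] > shifted[:, 0]) & (shifted[:, 3] > shifted[:, 1])
    return cropped, shifted[keep]
```

Dropping boxes that end up entirely outside the window is only one possible policy; another is to resample the window until the objects of interest survive.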
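Affine transforms include rotation, translation, shear, and scaling. A key property of affine transforms is that straight lines remain straight lines after the transform (and parallel lines remain parallel).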
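One way to handle these four operations uniformly is to write each as a 3x3 matrix in homogeneous coordinates, so any combination becomes a single matrix product. The sketch below uses hypothetical helper names (not from any library), composes all four, and checks that three collinear points stay collinear. Note that in image coordinates, where y points down, a positive angle in this `rotation` helper turns points clockwise on screen.

```python
import numpy as np


def rotation(deg):
    t = np.deg2rad(deg)
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float64)


def translation(tx, ty):
    return np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]], dtype=np.float64)


def shear_x(deg):
    # x' = x + tan(deg) * y, y' = y
    return np.array([[1, np.tan(np.deg2rad(deg)), 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)


def scale(sx, sy):
    return np.array([[sx, 0, 0], [0, sy, 0], [0, 0, 1]], dtype=np.float64)


def apply_affine(M, pts):
    """Apply a 3x3 affine matrix to an (N, 2) array of points."""
    homo = np.hstack([pts, np.ones((len(pts), 1))])
    return (homo @ M.T)[:, :2]


# compose: the rightmost matrix is applied first
M = translation(5, -3) @ rotation(30) @ shear_x(20) @ scale(1.5, 0.8)

# three collinear points remain collinear: the 2D cross product stays ~0
pts = np.array([[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]])
out = apply_affine(M, pts)
d1, d2 = out[1] - out[0], out[2] - out[0]
cross = d1[0] * d2[1] - d1[1] * d2[0]
print(abs(cross) < 1e-9)  # True
```

To use such a matrix with `cv2.warpAffine`, pass only its top two rows, `M[:2]`, since that function expects a 2x3 matrix.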
```python
import numpy as np
import torchvision
from PIL import Image


class Compose(object):
    """Chain transforms that each take and return the same tuple of items."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *items):
        for t in self.transforms:
            items = t(*items)
        return items


# example: wrap a torchvision transform so it follows the (image, data) interface
class ColorJitter(torchvision.transforms.ColorJitter):
    def __call__(self, image, data):
        img = Image.fromarray(image)
        # data are annotations, e.g. points or bboxes; color jitter leaves them unchanged,
        # but geometric transforms (flip, rotation, etc.) must update them accordingly
        return np.array(super().__call__(img)), data


class Dataset(object):
    def gen_transforms(self):
        transforms = [
            # ... (the list of transforms is elided in the original)
        ]
        self.transforms = Compose(transforms)

    def __getitem__(self, index):
        # ...
        image, data = self.transforms(image, data)
        # ...
```
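Following the `(image, data)` convention of the `ColorJitter` wrapper, a geometric transform must also update the annotations. Here is a hypothetical `RandomHorizontalFlip` (not a torchvision class), assuming `data` is an (N, 2) array of (x, y) pixel-index coordinates, so column x maps to w - 1 - x; with continuous coordinates whose pixel edges run from 0 to w, the mapping would be w - x instead.

```python
import numpy as np


class RandomHorizontalFlip(object):
    """Flip an image left-right with probability p and mirror its keypoints.

    Assumes `data` is an (N, 2) array of (x, y) pixel-index coordinates.
    """

    def __init__(self, p=0.5, rng=None):
        self.p = p
        self.rng = np.random.default_rng() if rng is None else rng

    def __call__(self, image, data):
        if self.rng.random() >= self.p:
            return image, data  # no flip: pass both through unchanged
        w = image.shape[1]
        flipped = np.ascontiguousarray(image[:, ::-1])
        points = np.asarray(data, dtype=np.float32).copy()
        # a pixel at column x moves to column w - 1 - x
        points[:, 0] = (w - 1) - points[:, 0]
        return flipped, points
```

An instance can go into the `transforms` list passed to `Compose`. For landmark tasks with left/right pairs (e.g. eyes), the point order usually has to be swapped after a flip as well; this sketch does not do that.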
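```python
import cv2
import numpy as np

# img: an H x W x 3 uint8 image, e.g. loaded with cv2.imread
theta = 30  # shear angle in degrees
tan_theta = np.tan(np.deg2rad(theta))  # np.tan expects radians
# method 1: write the 2x3 affine (x-shear) matrix directly
shear = np.array([
    [1, tan_theta, 0],
    [0, 1, 0]], dtype=np.float32)

# method 2: define the transform by where 3 points are mapped, e.g. the
# red, blue and green points of the Wikipedia shear figure (image corners)
h, w = img.shape[:2]
point1 = np.float32([[0, 0], [w, 0], [0, h]])
point2 = np.float32([[0, 0], [w, 0], [h * tan_theta, h]])
M_shear = cv2.getAffineTransform(point1, point2)
# shear and M_shear are the same matrix (up to float precision)
# the sheared bottom-right corner (homogeneous coordinates) gives the output size
point = np.array([w, h, 1], dtype=np.float64)
sheared_point = M_shear.dot(point)
# choose borderValue according to the MEAN used in your preprocessing
img = cv2.warpAffine(img, M_shear, (int(sheared_point[0]), int(sheared_point[1])),
                     borderMode=cv2.BORDER_CONSTANT, borderValue=(128, 128, 128))
```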
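The same 2x3 matrix passed to `cv2.warpAffine` can be applied to the annotations so they stay aligned with the warped image. A NumPy sketch with hypothetical helpers `transform_points` and `transform_boxes`, assuming boxes are [x1, y1, x2, y2]: each box's four corners are warped and replaced by their axis-aligned bounding box.

```python
import numpy as np


def transform_points(M, points):
    """Apply a 2x3 affine matrix (as used by cv2.warpAffine) to (N, 2) points."""
    points = np.asarray(points, dtype=np.float32)
    homo = np.hstack([points, np.ones((len(points), 1), dtype=np.float32)])
    return homo @ np.asarray(M, dtype=np.float32).T


def transform_boxes(M, boxes):
    """Map [x1, y1, x2, y2] boxes through M and return their axis-aligned hulls."""
    boxes = np.asarray(boxes, dtype=np.float32)
    x1, y1, x2, y2 = boxes.T
    # the four corners of every box, shape (N, 4, 2)
    corners = np.stack([
        np.stack([x1, y1], axis=1), np.stack([x2, y1], axis=1),
        np.stack([x1, y2], axis=1), np.stack([x2, y2], axis=1)], axis=1)
    warped = transform_points(M, corners.reshape(-1, 2)).reshape(-1, 4, 2)
    # [xmin, ymin, xmax, ymax] of the warped corners
    return np.concatenate([warped.min(axis=1), warped.max(axis=1)], axis=1)
```

After a shear or rotation the resulting hull is larger than the warped box itself, so box annotations become looser than the object they enclose.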
In case of any infringement, please contact cloudcommunity@tencent.com to have it removed.