Loading [MathJax]/jax/input/TeX/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【机器学习】--- 自监督学习

【机器学习】--- 自监督学习

作者头像
Undoom
发布于 2024-09-23 13:32:56
发布于 2024-09-23 13:32:56
11700
代码可运行
举报
文章被收录于专栏:学习学习
运行总次数:0
代码可运行

1. 引言

机器学习近年来的发展迅猛,许多领域都在不断产生新的突破。在监督学习和无监督学习之外,自监督学习(Self-Supervised Learning, SSL)作为一种新兴的学习范式,逐渐成为机器学习研究的热门话题之一。自监督学习通过从数据中自动生成标签,避免了手工标注的代价高昂,进而使得模型能够更好地学习到有用的表示。

自监督学习的应用领域广泛,涵盖了图像处理自然语言处理、音频分析等多个方向。本篇博客将详细介绍自监督学习的核心思想、常见的自监督学习方法及其在实际任务中的应用。我们还将通过具体的代码示例来加深对自监督学习的理解。

2. 自监督学习的核心思想

自监督学习的基本理念是让模型通过从数据本身生成监督信号进行训练,而无需人工标注。常见的方法包括生成对比任务、预测数据中的某些属性或部分等。自监督学习的关键在于设计出有效的预训练任务,使模型在完成这些任务的过程中能够学习到数据的有效表示。

2.1 自监督学习与监督学习的区别

在监督学习中,模型的训练需要依赖大量的人工标注数据,而无监督学习则没有明确的标签。自监督学习介于两者之间,它通过从未标注的数据中创建监督信号,完成预训练任务。通常,自监督学习的流程可以分为两步:

  1. 预训练:利用自监督任务对模型进行预训练,使模型学习到数据的有效表示。
  2. 微调:将预训练的模型应用到具体任务中,通常需要进行一些监督学习的微调。
2.2 常见的自监督学习任务

常见的自监督任务包括:

  • 对比学习(Contrastive Learning):从数据中生成正样本和负样本对,模型需要学会区分正负样本。
  • 预文本任务(Pretext Tasks):如图像块预测、顺序预测、旋转预测等任务。
2.3 自监督学习的优点

自监督学习具备以下优势:

  • 减少对人工标注的依赖:通过生成任务标签,大大降低了数据标注的成本。
  • 更强的泛化能力:在大量未标注的数据上进行预训练,使模型能够学习到通用的数据表示,提升模型在多个任务上的泛化能力。
3. 自监督学习的常见方法

在自监督学习中,研究者设计了多种预训练任务来提升模型的学习效果。以下是几种常见的自监督学习方法。

3.1 对比学习(Contrastive Learning)

对比学习是目前自监督学习中最受关注的一个方向。其基本思想是通过构造正样本对(相似样本)和负样本对(不同样本),让模型学习区分样本之间的相似性。典型的方法包括SimCLR、MoCo等。

SimCLR 的实现
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
import numpy as np

# SimCLR数据增强
class SimCLRTransform:
    def __init__(self, size):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=(3, 3)),
            transforms.ToTensor()
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)

# 定义对比损失
class NTXentLoss(nn.Module):
    def __init__(self, temperature):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)
        sim_matrix = torch.mm(z, z.t()) / self.temperature
        mask = torch.eye(2 * batch_size, dtype=torch.bool).to(sim_matrix.device)
        sim_matrix.masked_fill_(mask, -float('inf'))
        
        positives = torch.cat([torch.diag(sim_matrix, batch_size), torch.diag(sim_matrix, -batch_size)], dim=0)
        negatives = sim_matrix[~mask].view(2 * batch_size, -1)
        
        logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
        labels = torch.zeros(2 * batch_size).long().to(logits.device)
        
        loss = nn.CrossEntropyLoss()(logits, labels)
        return loss

# 定义模型架构
class SimCLR(nn.Module):
    def __init__(self, base_model, projection_dim=128):
        super(SimCLR, self).__init__()
        self.backbone = base_model
        self.projector = nn.Sequential(
            nn.Linear(self.backbone.fc.in_features, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        h = self.backbone(x)
        z = self.projector(h)
        return z

# 模型训练
def train_simclr(model, train_loader, epochs=100, lr=1e-3, temperature=0.5):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = NTXentLoss(temperature)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x_i, x_j in train_loader:
            optimizer.zero_grad()
            z_i = model(x_i)
            z_j = model(x_j)
            loss = criterion(z_i, z_j)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader)}')

# 示例:在CIFAR-10上进行SimCLR训练
from torchvision.datasets import CIFAR10

train_dataset = CIFAR10(root='./data', train=True, transform=SimCLRTransform(32), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

resnet_model = models.resnet18(pretrained=False)
simclr_model = SimCLR(base_model=resnet_model)

train_simclr(simclr_model, train_loader)

以上代码展示了如何实现SimCLR对比学习模型。通过数据增强生成正样本对,使用NT-Xent损失函数来区分正负样本对,进而让模型学习到有效的数据表示。

3.2 预文本任务(Pretext Tasks)

除了对比学习,预文本任务也是自监督学习中的一种重要方法。常见的预文本任务包括图像块预测、旋转预测、Jigsaw拼图任务等。我们以Jigsaw拼图任务为例,展示如何通过打乱图像块顺序,让模型进行重新排序来学习图像表示。

Jigsaw任务的实现
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import random

# 定义Jigsaw数据预处理
class JigsawTransform:
    def __init__(self, size, grid_size=3):
        self.size = size
        self.grid_size = grid_size
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor()
        ])

    def __call__(self, x):
        x = self.transform(x)
        blocks = self.split_into_blocks(x)
        random.shuffle(blocks)
        return torch.cat(blocks, dim=1), torch.tensor([i for i in range(self.grid_size ** 2)])

    def split_into_blocks(self, img):
        c, h, w = img.size()
        block_h, block_w = h // self.grid_size, w // self.grid_size
        blocks = []
        for i in range(self.grid_size):
            for j in range(self.grid_size):
                block = img[:, i*block_h:(i+1)*block_h, j*block_w:(j+1)*block_w]
                blocks.append(block.unsqueeze(0))
        return blocks

# 定义Jigsaw任务模型
class JigsawModel(nn.Module):
    def __init__(self, base_model):
        super(JigsawModel, self).__init__()
        self.backbone = base_model
        self.classifier = nn.Linear(base_model.fc.in_features, 9)

    def forward(self, x):
        features = self.backbone(x)
        out = self.classifier(features)
        return out

# 示例:在CIFAR-10上进行Jigsaw任务训练
train_dataset = CIFAR10(root='./data', train=True, transform=JigsawTransform(32), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

jigsaw_model = JigsawModel(base_model=resnet_model)

# 训练过程同样可以采用类似SimCLR的方式进行

Jigsaw任务通过打乱图像块并要求模型恢复原始顺序来学习图像的表示,训练方式与

普通的监督学习任务相似,核心是构建预训练任务并生成标签。

4. 自监督学习的应用场景

自监督学习目前在多个领域得到了成功的应用,包括但不限于:

  • 图像处理:通过预训练任务学习到丰富的图像表示,进而提升在图像分类、目标检测等任务上的表现。
  • 自然语言处理:BERT等模型的成功应用展示了自监督学习在文本任务中的巨大潜力。
  • 时序数据分析:例如在视频处理、音频分析等领域,自监督学习也展示出了强大的能力。
5. 结论

自监督学习作为机器学习中的一个新兴热点,极大地推动了无标注数据的利用效率。通过设计合理的预训练任务,模型能够学习到更加通用的数据表示,进而提升下游任务的性能。在未来,自监督学习有望在更多实际应用中发挥重要作用,帮助解决数据标注昂贵、难以获取的难题。

在这篇文章中,我们不仅阐述了自监督学习的基本原理,还通过代码示例展示了如何实现对比学习和Jigsaw任务等具体方法。通过深入理解这些技术,读者可以尝试将其应用到实际任务中,从而提高模型的表现。

参考文献
  1. Chen, Ting, et al. “A simple framework for contrastive learning of visual representations.” International conference on machine learning. PMLR, 2020.
  2. Gidaris, Spyros, and Nikos Komodakis. “Unsupervised representation learning by predicting image rotations.” International Conference on Learning Representations. 2018.
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-09-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
常用正则表达式
/^([0-9]{1,}\.[0-9]{1,}|[0-9]{1,})$/    数字或小数点
Java架构师必看
2020/04/22
6660
全网最全正则实战指南,拿走不谢
最近有很多小伙伴问我为啥会有那么多的时间写文章,录视频,好吧,今天我就给大家分享下我平时工作中会经常使用的一些小工具吧。
冰河
2023/11/29
2690
全网最全正则实战指南,拿走不谢
十分钟学会正则表达式
https://segmentfault.com/a/1190000038502198
@超人
2021/02/26
4120
[Regex]Get正则表达式
原文链接:http://blog.csdn.net/humanking7/article/details/51175937
祥知道
2020/03/10
5260
知识总结:常用正则表达式正则表达式
正则表达式 常用正则表达式大全!(例如:匹配中文、匹配html) 匹配中文字符的正则表达式: [u4e00-u9fa5]  评注:匹配中文还真是个头疼的事,有了这个表达式就好办了 匹配双字节字符(包
牛客网
2018/04/28
1K0
第177天:常用正则表达式(最全)
常用正则表达式 1 <script> 2 /* 常用正则表达式大全!(例如:匹配中文、匹配html) 3 4 匹配中文字符的正则表达式: [u4e00-u9fa5] 5 评注:匹配中文还真是个头疼的事,有了这个表达式就好办了 6 匹配双字节字符(包括汉字在内):[^x00-xff] 7 评注:可以用来计算字符串的长度(一个双字节字符长度计2,ASCII字符计1) 8 匹配空白行的正则表达式:ns*r 9 评注:可以用来删除空白行 10
半指温柔乐
2018/09/11
8670
C#常见正则表达式
"^\d+$" //非负整数(正整数 + 0) "^[0-9]*[1-9][0-9]*$" //正整数 "^((-\d+)|(0+))$" //非正整数(负整数 + 0) "^-[0-9]*[1
恋喵大鲤鱼
2018/08/03
7410
PHP 正则表达式及常用正则汇总
正则表达式用于字符串处理、表单验证等场合,实用高效。现将一些常用的表达式收集于此,以备不时之需。
V站CEO-西顾
2018/06/10
3.9K2
版本号的正则表达式-常见正则表达式大全
  评注:腾讯QQ号从10000开始   匹配中国邮政编码:[1-9]d{5}(?!d)   评注:中国邮政编码为6位数字   匹配身份证:d{15}|d{18}   评注:中国的身份证为15位或18
宜轩
2022/12/29
9100
C#常用正则表达式整理
C#常用正则表达式 非负整数(正整数 + 0): "^\d+$" 正整数 "^[0-9][1-9][0-9]$" 非正整数(负整数 + 0)"^((-\d+)|(0+))$" 负整数 "^-[0-9]
大师级码师
2021/10/27
6440
正则表达式学习心得
正则表达式算是一门通用的东西,前端后端都能用得到,在某些时候正则表达式也是很方便。
神秘人9527
2022/11/20
2900
常用的20个正则表达式
正则表达式,一个十分古老而又强大的文本处理工具,仅仅用一段非常简短的表达式语句,便能够快速实现一个非常复杂的业务逻辑。熟练地掌握正则表达式的话,能够使你的开发效率得到极大的提升。
小小工匠
2021/08/16
3.3K0
一起来了解一下正则表达式
在维基百科中,正则表达式被形容是“使用单个字符串来描述、匹配一系列匹配某个句法规则的字符串。在很多文本编辑器里,正则表达式通常被用来检索、替换那些匹配某个模式的文本。”
软测小生
2019/07/24
6790
一起来了解一下正则表达式
关于常用的正则表达式的分享
  1.正则表达式,又称规则表达式。(英语:Regular Expression,在代码中常简写为regex、regexp或RE),计算机科学的一个概念。正则表达式通常被用来检索、替换那些符合某个模式(规则)的文本。
用户7053485
2020/03/12
1.2K0
study - 一文入门正则表达式
如图所示的正则,将日期和时间都括号括起来。这个正则中一共有两个分组,日期是第 1 个,时间是第 2 个。
stark张宇
2023/03/04
5950
正则表达式 至少6位-字母,数字,下划线或者数字的正则表达式
  一、校验数字的表达式   数字:^[0-9]*$   n位的数字:^\d{n}$   至少n位的数字:^\d{n,}$   m-n位的数字:^\d{m,n}$   零和非零开头的数字:^(0|1-
宜轩
2022/12/29
3.8K0
正则表达式总结
正则表达式是对字符串(包括普通字符(例如,a 到 z 之间的字母)和特殊字符(称为“元字符”))操作的一种逻辑公式,就是用事先定义好的一些特定字符、及这些特定字符的组合,组成一个“规则字符串”,这个“规则字符串”用来表达对字符串的一种过滤逻辑。正则表达式是一种文本模式,模式描述在搜索文本时要匹配的一个或多个字符串。
Dream城堡
2018/09/10
9450
常用正则表达式
  这些正则皆为日常开发总结,一般常用的都用,来源有来自自己总结的,还有的是从网上记录下来的,希望对大家有个帮助,完好正则提高程序性能!
追逐时光者
2019/08/28
1.5K0
Java正则表达式大全(参考)
ma布
2024/10/21
1170
微信小程序正则表达式
Email地址:^\w+([-+.]\w+)@\w+([-.]\w+).\w+([-.]\w+)*$
江一铭
2022/06/17
1K0
相关推荐
常用正则表达式
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验