首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

无监督学习神经网络——自编码

自编码是一种无监督学习的神经网络,主要应用在特征提取,对象识别,降维等。自编码器将神经网络的隐含层看成是一个编码器和解码器,输入数据经过隐含层的编码和解码,到达输出层时,确保输出的结果尽量与输入数据保持一致。也就是说,隐含层是尽量保证输出数据等于输入数据的。 这样做的一个好处是,隐含层能够抓住输入数据的特点,使其特征保持不变。例如,假设输入层有100个神经元,隐含层只有50个神经元,输出层有100个神经元,通过自动编码器算法,只用隐含层的50个神经元就找到了100个输入层数据的特点,能够保证输出数据和输入数据大致一致,就大大降低了隐含层的维度。 既然隐含层的任务是尽量找输入数据的特征,也就是说,尽量用最少的维度来代表输入数据,因此,隐含层各层之间的参数构成的参数矩阵,应该尽量是个稀疏矩阵,即各层之间有越多的参数为0就越好。

fromtorchimportoptim

fromtorchimportnnasnn

fromtorch.autogradimportVariable

fromtorch.utilsimportdata

fromtorchvisionimportdatasets

fromtorchvisionimporttransforms

#超参数

epochs =10

batch_size =64

lr =0.005

n_test_img =5

classAutoEncoder(nn.Module):

def__init(self):

super(AutoEncoder, self).__init__()

# 压缩

self.encoder = nn.Sequential(

nn.Linear(28*28,128),

nn.Tanh(),

nn.Linear(128,64),

nn.Tanh(),

nn.Linear(64,12),

nn.Tanh(),

nn.Linear(12,3)

)

# 解压

self.decoder = nn.Sequential(

nn.Linear(3,12),

nn.Tanh(),

nn.Linear(12,64),

nn.Tanh(),

nn.Linear(64,128),

nn.Tanh(),

nn.Linear(128,28*28),

nn.Sigmoid()

)

defforward(self, x):

encoded = self.encoder(x)

decoded = self.decoder(encoded)

returnencoded, decoded

img_transform = transforms.Compose([

transforms.ToTensor()

])

if__name__ =='__main__':

train_data = datasets.MNIST(root='./data', train=True,transform=img_transform, download=True)

train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

autoencoder = AutoEncoder()

# 训练模型

optimizer = optim.Adam(autoencoder.parameters(), lr=lr)

loss = nn.MSELoss()

forepochinrange(epochs):

forsetp, (x, y)inenumerate(train_loader):

b_x = Variable(x.view(-1,28*28))

b_y = Variable(x.view(-1,28*28))

b_label = Variable(y)

encoded, decoded = autoencoder(b_x)

loss_data = loss(decoded, b_y)

optimizer.zero_grad()

loss.backward()

optimizer.step()

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20171220G0I10200?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券