iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >pytorch 实现变分自动编码器的操作
  • 115
分享到

pytorch 实现变分自动编码器的操作

2024-04-02 19:04:59 115人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。 这个例子是用MNIST数据集生成为例子 # -*- coding:

本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。

这个例子是用MNIST数据集生成为例子


# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
""" 
import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transfORMs as tfs
from torchvision.utils import save_image 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差 
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
    
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu)
 
#可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练
 
#下面开始训练 
reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))                    
          

补充:PyTorch 深度学习快速入门——变分自动编码器

变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。

回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。

其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。

这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。

一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。

KL divergence 的公式如下

重参数 为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1

所以最后我们可以将我们的 loss 定义为下面的函数,由均方误差和 KL divergence 求和得到一个总的 loss


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD

用 mnist 数据集来简单说明一下变分自动编码器


import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu) 
 
Variable containing:  Columns 0 to 9  -0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742  Columns 10 to 19  -0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]

可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练 下面开始训练


reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))
  
epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343

变分自动编码器虽然比一般的自动编码器效果要好,而且也限制了其输出的编码 (code) 的概率分布,但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss,这个方式并不好,生成对抗网络中,我们会讲一讲这种方式计算 loss 的局限性,然后会介绍一种新的训练办法,就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: pytorch 实现变分自动编码器的操作

本文链接: https://www.lsjlt.com/news/126648.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorch 实现变分自动编码器的操作
    本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。 这个例子是用MNIST数据集生成为例子 # -*- coding: ...
    99+
    2024-04-02
  • Pytorch实现将label变成onehot编码的两种方式
    目录前言使用scatter_获得one hot 编码使用tensor.index_select获得one hot编码第二种针对分割网络的one_hot编码总结由于Pytorch不像T...
    99+
    2023-02-01
    Pytorch label one hot编码 one hot编码 label one hot编码
  • pytorch实现textCNN的具体操作
    1. 原理 2014年的一篇文章,开创cnn用到文本分类的先河。 Convolutional Neural Networks for Sentence Classification ...
    99+
    2024-04-02
  • Python实现自动发消息自定义内容的操作代码
    目录一、效果二、开发环境三、关键步骤解析总结有时候让了解放双手,让电脑来帮我们自动发一些我们想要发的消息,挺省力的,比如说白天写好了演讲稿,晚上要在群里进行文字演讲,那么我们就可以用...
    99+
    2024-04-02
  • Typora自动编号的具体操作
    概述 在使用Typora写比较长的文章时,需要给章节编号,方便区分层次。如果手动编号,一旦章节顺序改变,很多章节的编号都需要一一手动修改,极其麻烦。 Typora官方提供了自动编号的...
    99+
    2024-04-02
  • Pytorch实现全连接层的操作
    全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC。 FC的准则很简单:神经网络中除输入层之外的每个节点都和上一...
    99+
    2024-04-02
  • Python实现自动化网页操作
    编程语言:python 集成开发环境(IDE):Visual Studio Code 配置方法参照Visual Studio Code配置Python编程环境 目录 1 准备1.1 安装...
    99+
    2023-09-02
    python 自动化 chrome selenium
  • python神经网络pytorch中BN运算操作自实现
    BN 想必大家都很熟悉,来自论文: 《Batch Normalization Accelerating Deep Network Training by Reducing Inter...
    99+
    2024-04-02
  • 使用python怎么实现mysql自动增删分区操作
    本篇文章给大家分享的是有关使用python怎么实现mysql自动增删分区操作,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。连接mysql#!/usr/bin/python#-*...
    99+
    2023-06-14
  • c# list部分操作实现代码
    C# Linq获取两个List或数组的差集交集 复制代码 代码如下:List<int> list1 = newList<int>();list1.Add(1)...
    99+
    2022-11-15
    c# list
  • python+selenium的web自动化上传操作的实现
    目录一、关于上传操作二、input标签三、第三方库pywin32四、第三方工具pyautogui总结一、关于上传操作 上传有两种情况: 如果是input可以直接输入路径的,那么直接使...
    99+
    2024-04-02
  • pytorch实现模型剪枝的操作方法
    目录一,剪枝分类1.1,非结构化剪枝1.2,结构化剪枝1.3,本地与全局修剪二,PyTorch 的剪枝2.1,pytorch 剪枝工作原理2.2,局部剪枝2.2.1,局部非结构化剪枝...
    99+
    2023-02-24
    pytorch模型剪枝 pytorch剪枝
  • R语言变量重编码、重命名的操作
    1、变量重编码 重编码涉及根据同一个变量和/或其他变量的现有值创建新值的过程,如将符合某个条件的值重新赋值等,这里主要介绍两种常见的方法: #第一种方法 per <- da...
    99+
    2024-04-02
  • PyTorch 实现L2正则化以及Dropout的操作
    了解知道Dropout原理 如果要提高神经网络的表达或分类能力,最直接的方法就是采用更深的网络和更多的神经元,复杂的网络也意味着更加容易过拟合。 于是就有了Dropout,...
    99+
    2024-04-02
  • Typora自动编号的具体操作是怎样的
    这期内容当中小编将会给大家带来有关Typora自动编号的具体操作是怎样的,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。概述在使用Typora写比较长的文章时,需要给章节编号,方便区分层次。如果手动编号,一...
    99+
    2023-06-21
  • pytorch中LN(LayerNorm)及Relu和其变相输出操作的示例分析
    这篇文章主要介绍pytorch中LN(LayerNorm)及Relu和其变相输出操作的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!主要就是了解一下pytorch中的使用layernorm这种归一化之后的数据...
    99+
    2023-06-15
  • Sharding-JDBC自动实现MySQL读写分离的示例代码怎么编写
    Sharding-JDBC自动实现MySQL读写分离的示例代码怎么编写,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。一、ShardingSphere和Shard...
    99+
    2023-06-25
  • PHP Linux脚本操作实例:实现自动化部署
    近年来,随着软件行业的快速发展,部署工作在开发流程中变得愈发重要。为了提高效率,许多开发团队都开始采用自动化部署来简化繁琐的部署过程。在这其中,PHP语言在Linux环境下的脚本操作成为了一种常见的实现方式。本文将介绍如何使用PHP脚本在L...
    99+
    2023-10-21
    Linux PHP 自动化部署
  • 使用Python怎么操作Excel实现自动分组合并单元格
    这篇文章主要介绍了使用Python怎么操作Excel实现自动分组合并单元格,编程网小编觉得不错,现在分享给大家,也给大家做个参考,一起跟随编程网小编来看看吧!df.to_excel('test.xlsx',index=Fal...
    99+
    2023-06-06
  • jQuery实现鼠标拖动div改变位置、大小的操作
    本篇内容主要讲解“jQuery实现鼠标拖动div改变位置、大小的操作”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“jQuery实现鼠标拖动div改变位置、大小的操作”吧!  ...
    99+
    2023-06-14
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作