位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • 金税盘抵减增值税的账务处理
  • 个人经营所得税起征点是多少
  • 土地增值税间接费用扣除
  • 设立登记核准通知书未签名不能办理其他业务
  • 金税四期如何应对
  • 可供出售金融资产是指什么
  • 公司购买软件著作权
  • 子公司之间可以相互交易吗
  • 汇算清缴弥补以前年度亏损报表怎么填
  • 固定资产转到别人账户
  • 建筑服务 1%
  • 一般代开增值税多少个点?
  • 公司收生育津贴怎么入账
  • 营改增后开不了增值税发票怎么样办?
  • 公司在外地
  • 2021年沙子水泥税率是多少
  • 生产企业出口发票数量开错有什么影响
  • 进项税留抵需要转出吗
  • 出租设备计提折旧记什么科目
  • 即征即退的增值税计入其他收益吗
  • 持有至到期投资减值准备
  • 出口货物收汇核销流程
  • 月末累计折旧怎么结转
  • linux acl配置命令
  • 员工罚款算收入吗
  • macOSCatalina10.15.5正式版值得升级吗 macOSCatalina10.15.5更新了什么
  • 印花税申报后怎么缴费
  • 期间费用为什么叫期间费用
  • 超额累进税率包括
  • 免抵退税务处理
  • 房地产企业增值税扣除土地成本
  • 用盈余公积弥补以前年度亏损
  • 爬虫10054
  • php验证码识别
  • 面试官:一千万是真的吗
  • php对数组进行排序
  • linuxshel
  • 2022年最新公务员职务职级对照表
  • 会计调整以前年度遗留问题查不出来说明怎么写
  • 购买方已认证怎么开具红字信息表
  • 代发工资范围
  • 个税少扣了怎么账平掉
  • 财政补贴收入账务处理
  • mysql索引优化的方案
  • mysql主从配置详解
  • 劳务分包劳务费付款比例
  • 将自产产品用于公益事业
  • 办公费税前扣除标准2023年
  • 金税盘入什么科目
  • 小规模纳税人缴纳本月未交增值税
  • 新办企业在建期间账务处理
  • 私立医院怎么做不被骗的事情
  • 小规模纳税人附加税减半征收
  • 返还垫付征地款
  • 折现率算现值公式
  • 工业企业成本占比多少合适
  • 库存商品销售后怎么做账
  • 会计什么情况下不用继续教育
  • sql server数据库恢复
  • sqlserver数据库维护
  • 优化什么建立生育支持政策体系
  • win10怎么查看windows版本
  • 苹果电脑支付在哪里设置
  • CentOS 6.x 到 CentOS 7 的升级过程
  • centos无法挂载位置
  • scrnsave.exe
  • windows中的hosts文件
  • ssh 使用代理
  • linux三种用户权限
  • win7无法注册dll文件
  • win10电脑开机蓝屏怎么进入安全模式
  • 定时删除一个文件怎么删
  • vs2019怎么用gets
  • 启用opengl的文档窗口
  • windows设置开机启动软件
  • python第三方库在哪下载
  • 相机调取失败
  • asynctask缺点
  • javascript简明教程
  • 河南城乡居民医疗保险电话
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设