GAN实例基于神经网络

目录

1.前言

2.实验


1.前言

        需要了解GAN的原理查看对抗生成网络(GAN),DCGAN原理。

        采用手写数字识别数据集

2.实验

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/632565.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

怎么把照片变小做头像?多种方法教你图片改尺寸

现在在社交媒体平台或者是社交软件上,我们经常会去更改头像来展示自己,但是有时候我们拍摄的照片太大无法直接用作头像,这时候就需要去修改图片尺寸,将图片改大小到合适的数值才能使用,那么如何快速的将图片改大小呢&a…

在UBuntu上安装QT环境

一、UBuntu环境 二、官网下载QT https://download.qt.io/archive/qt/ 安装所需选择版本下载,可以现在windows下载在复制进去 三、安装QT 1、复制到ubuntu 2、打开终端,改变刚下载文件的权限 权限代号 r:读取权限,数字代号为 “…

手机图片恢复不求人:手动找回丢失的照片!

无论是外出旅行、聚会还是日常点滴,我们总是习惯用手机记录下来,让美好的瞬间定格在一张张照片中。然而,有时因为误删、清空缓存或是更换手机,那些珍贵的照片突然消失了。手机图片恢复有什么简单易行、容易上手的方法吗&#xff1…

MySQL创建存储过程函数(2)

DDL CREATE TABLE student (id int(11) NOT NULL AUTO_INCREMENT COMMENT 学号,createDate datetime DEFAULT NULL,userName varchar(20) DEFAULT NULL,pwd varchar(36) DEFAULT NULL,phone varchar(11) DEFAULT NULL,age tinyint(3) DEFAULT NULL,sex char(2) DEFAULT NULL,i…

上海初中生古诗文大会倒计时4个月:单选题真题示例和独家解析

现在距离2024年初中生古诗文大会还有4个多月时间,备考要趁早,因为知识点还是相对比较多的。这些知识点对于初中语文的学习也是很有帮助的。 今天我们继续来看10道选择题真题和详细解析,以下题目截取自我独家制作的在线真题集,都是…

教你五行代码实现大批量文件重命名

问题背景 文件夹里的大量文件,命名很乱,并且要重新命名为固定长度顺序的文件很麻烦。这里采用5行python实现大批量文件按要求统一命名。 现有文件夹列表 tulips 代码实现 main.py import os path rtulips/ for num, file in enumerate(os.listdir(…

狙击策略专用术语以及含义,WeTrade3秒讲解

想必各位交易高手对狙击策略不会陌生吧!但你想必不知道狙击策略的开发者为了推广狙击策略,在狙击策略基础的经典技术分析理论引入了自己的术语。今天WeTrade众汇和各位投资者继续了解狙击策略专用术语以及含义。 一.BL 银行级别(BL)是前一日线收盘的级别。时间是格…

央视的AI动画《AI我中华》宣传视频,原来用AI工具Stable Diffusion制作,竟然这么简单?

大家好,我是向阳。 前段时间,央视的《爱我中华》AI宣传短片火爆全网,有一个穿越转场效果非常惊艳!先来回顾回顾: 今天就先来详细讲解,如何利用Stable Diffusion制作这样的穿越转场视频。 如你还没有安装St…

脑中风也会出现眩晕?快速识别中风,一定要牢记这些!

眩晕是许多人都会经历的不适感,发作时仿佛整个世界都在旋转,可能还伴随着站立不稳、脚步虚浮、恶心等症状。然而,你可能不知道的是,这些症状在某些情况下可能是脑中风的前兆。如果不及时关注并采取相应措施,一旦发展为…

linux ndk编译搭建测试

一、ndk下载 NDK 下载 | Android NDK | Android Developers 二、ndk环境变量配置 ndk解压: unzip android-ndk-r26d-linux.zip 环境变量配置: export NDK_HOME/rd/own/test/android-ndk-r26d/ export PATH$PATH:$NDK_HOME 三、编译测试验证 …

003_PyCharm的安装与使用

如果你正在学习PyQt,本系列教程完全可以带你入门直至入土。 所谓从零开始,就是从软件安装、环境配置开始。 不跳过一个细节,不漏掉一行代码,不省略一个例图。 IDE 开始学习一个编程语言,我们肯定是首先得安装好它&…

【深度学习】【Lora训练2】StabelDiffusion,Lora训练过程,秋叶包,Linux,SDXL Lora训练

文章目录 一、如何为图片打标1.1. 打标工具1.1.1. 秋叶中使用的WD1.41.1.2. 使用BLIP21.1.3. 用哪一种 二、 Lora训练数据的要求2.1 图片要求2.2 图片的打标要求 三、 Lora的其他问题qa1qa2qa3qa4qa5 四、 对图片的处理细节4.1. 图片尺寸问题4.2. 图片内容选取问题4.3. 什么是一…

python批量生成防伪识别二维码

欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一.前言 二.代码 三.使用 四.总结 一.前言 二维码(QR Code)是一种矩阵条码技术,它使用黑白矩形图案来表示二进制数据,这些矩形图案可以被设备扫描并解读。二维码可以被用来存储

究极完整版!!Centos6.9安装最适配的python和yum,附带教大家如何写Centos6.9的yum.repos.d配置文件。亲测可行!

前言! 这里我真是要被Centos6.9给坑惨了,最刚开始学习linux的时候并没有在意那么的,没有考虑到选版本问题,直到23年下半年,官方不维护Centos6.9了,基本上当时配置的文件和安装的依赖都用不了了&#xff0c…

springboot 引用外配置json文件

场景 一些服务需要记录一些持久化的信息(没有数据库,redis,elasticsearch 可用) 我们就项目启动过程创建一个json 文件去记录工作内容的进程(json 可视化与改动非常方便) 实现效果 代码 application.yml…

HashTable, HashMap, ConcurrentHashMap 三者区别

目录 1. HashMap 2. HashTable 3. ConcurrentHashMap 1. HashMap HashMap 是 Java 中非常常用的一个数据结构,它主要用于存储 键值对(K,V)。 在JDK 1.7中,HashMap的实现是基于 Table数组 和 Entry链表 的组合。 从…

直播卖券有妙招:实景ai无人直播系统帮助商家自动化团购直播!

在数字化浪潮席卷的今天,直播卖券已成为商家推广和营销的重要手段。然而,如何高效、精准地利用直播卖券,让每一位观众都能沉浸在购物的乐趣中,成为商家们迫切需要解决的问题。幸运的是,实景AI无人直播系统应运而生&…

503 SERVICE_UNAVAILABLE “Unable to find instance for ***“的问题。

503 SERVICE_UNAVAILABLE "Unable to find instance for *****"问题 1. 问题描述 在完成网关微服务的配置后,网关能够被Nacos发现并正常运行,但是无法将请求转发到其他微服务。使用Postman进行请求时,返回503错误,而后…

Python程序设计 文件处理(二)

实验十二 文件处理 第1关:读取宋词文件,根据词人建立多个文件 读取wjcl/src/step1/宋词.txt文件, 注意:宋词文件的标题行的词牌和作者之间是全角空格(" ")可复制该空格 在wjcl/src/step3/cr文件夹下根据每…

畅捷通TPlus keyEdit.aspx、KeyInfoList.aspx SQL注入漏洞复现

前言 免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该文章仅供学习用途使用。 一、产…