欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 创投人物 > demo_GAN

demo_GAN

2025/7/16 5:42:03 来源:https://blog.csdn.net/yyfhq/article/details/142989398  浏览:    关键词:demo_GAN
# 导入PyTorch库,这是一个用于深度学习的开源库
import torch
# 导入PyTorch的神经网络模块(nn),用于定义神经网络结构
import torch.nn as nn
# 导入PyTorch的函数式模块(functional),提供了一些常用的激活函数和损失函数等
import torch.nn.functional as F
# 导入PyTorch的优化器模块(optim),用于定义优化算法,如梯度下降等
import torch.optim as optim
# 从PyTorch的数据加载器模块中导入DataLoader和TensorDataset类,用于加载和处理数据集
from torch.utils.data import DataLoader, TensorDataset
# 从torchvision库的实用工具模块中导入save_image函数,用于保存生成的图像
from torchvision.utils import save_image
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
# 导入os模块,用于处理文件和目录操作
import os
import matplotlib.pyplot as plt# 自注意力机制模块定义
class SelfAttention(nn.Module):def __init__(self, in_dim):super(SelfAttention, self).__init__()self.query = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)self.key = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, C, width, height = x.size()query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)key = self.key(x).view(batch_size, -1, width * height)energy = torch.bmm(query, key)attention = F.softmax(energy, dim=-1)value = self.value(x).view(batch_size, -1, width * height)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, C, width, height)out = self.gamma * out + xreturn out# Generator Model定义了一个名为Generator的神经网络模型,它继承自PyTorch框架中的nn.Module类
class Generator(nn.Module):def __init__(self, noise_dim, label_dim):super(Generator, self).__init__()self.label_dim = label_dim# 定义了一个名为self.fc的神经网络层序列,含三个层,输入层:随机噪声和标签,批量归一化层,漏洞型relu层self.fc = nn.Sequential(nn.Linear(noise_dim + label_dim, 1024 * 2 * 2),nn.BatchNorm1d(1024 * 2 * 2),# 这是一个Leaky ReLU激活函数层,它的作用是将负数的输入值乘以一个小的常数(这里是0.2),然后将结果作为输出,在原始数据上进行操作nn.LeakyReLU(0.2, inplace=True))# Hidden Layers: Deconv + BN + Leaky ReLUself.deconv_layers = nn.Sequential(nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # 2x2 -> 4x4nn.BatchNorm2d(512),nn.LeakyReLU(0.1, inplace=True),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8nn.BatchNorm2d(256),nn.LeakyReLU(0.1, inplace=True),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16nn.BatchNorm2d(128),nn.LeakyReLU(0.1, inplace=True),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64nn.BatchNorm2d(32),nn.LeakyReLU(0.2, inplace=True),# 增加一层,扩展到128x128nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 64x64 -> 128x128nn.BatchNorm2d(16),nn.LeakyReLU(0.2, inplace=True),nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1),  # 128x128 -> 128x128 (RGB)nn.Tanh()  # 输出层,范围[-1, 1])def forward(self, noise, labels):# 拼接噪声和标签x = torch.cat((noise, labels), dim=1)x = self.fc(x).view(-1, 1024, 2, 2)return self.deconv_layers(x)class Discriminator(nn.Module):def __init__(self, input_channels):super(Discriminator, self).__init__()# 第一层:并行卷积层(3×3和5×5卷积核),后续拼接self.conv1_3x3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1)),nn.LeakyReLU(0.1, inplace=True))self.conv1_5x5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=5, stride=2, padding=2)),nn.LeakyReLU(0.1, inplace=True))# (2) conv + BN + leaky Relu (dilation rate 1)self.conv2 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1)),nn.LeakyReLU(0.2, inplace=True),)# (3) conv + BN + leaky Relu + self-attention mechanismself.conv3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)),nn.LeakyReLU(0.2, inplace=True),SelfAttention(256))# (4) conv + BN + leaky Relu (parallel 3x3, 5x5, and 7x7 kernels)self.conv4_3x3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)),nn.LeakyReLU(0.2, inplace=True))self.conv4_5x5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)),nn.LeakyReLU(0.2, inplace=True))self.conv4_7x7 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=7, stride=2, padding=3)),nn.LeakyReLU(0.2, inplace=True))# (5) conv + BN + leaky Relu (dilation rate 3)self.conv5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=1536, out_channels=1024, kernel_size=3, stride=1, padding=3, dilation=3)),nn.LeakyReLU(0.1, inplace=True))# (6) 用卷积层替换全连接层,输出1x1特征图,并使用sigmoid激活函数self.fc = nn.utils.spectral_norm(nn.Linear(1024 * 4 * 4, 1))self.sigmoid = nn.Sigmoid()def forward(self, x):x1 = self.conv1_3x3(x)x2 = self.conv1_5x5(x)x = torch.cat((x1, x2), dim=1)x = self.conv2(x)x = self.conv3(x)x1 = self.conv4_3x3(x)x2 = self.conv4_5x5(x)x3 = self.conv4_7x7(x)x = torch.cat((x1, x2, x3), dim=1)x = self.conv5(x)x = nn.AvgPool2d(2)(x)x = x.view(x.size(0), -1)  # Flatten the tensorx = self.fc(x)x = self.sigmoid(x)return x# 设置超参数
noise_dim = 100  # 噪声维度
label_dim = 58  # 标签维度
batch_size =64  # 批大小
learning_rate = 0.0001
num_epochs = 500  # 训练轮数
output_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/MMSGAN"  # 生成图像保存路径# 确保输出目录存在
if not os.path.exists(output_dir):os.makedirs(output_dir)# 创建生成器和判别器
G = Generator(noise_dim=noise_dim, label_dim=label_dim).to('cuda')
D = Discriminator(input_channels=3).to('cuda')# TrafficSignDataset类,用于数据加载
class TrafficSignDataset(Dataset):def __init__(self, root_dir, labels_file, transform=None):self.root_dir = root_dirself.transform = transformself.image_paths = []self.labels = []with open(labels_file, 'r') as f:lines = f.readlines()for line in lines:img_name, label = line.strip().split()img_path = os.path.join(root_dir, img_name)self.image_paths.append(img_path)self.labels.append(int(label))def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 损失函数和优化器
criterion = nn.BCELoss()  # 二元交叉熵损失
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate*4,betas=(0.5, 0.999),weight_decay=1e-4)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# 设置学习率衰减参数
decay = 0.0001
num_epochs = 500# 训练循环
for epoch in range(num_epochs):# ... 训练过程 ...# 更新学习率lr_new_G = learning_rate * 4 / (1 + decay * num_epochs)lr_new_D = learning_rate / (1 + decay * num_epochs)for param_group in optimizer_G.param_groups:param_group['lr'] = lr_new_Gfor param_group in optimizer_D.param_groups:param_group['lr'] = lr_new_D# 定义图像预处理和数据增强
transform = transforms.Compose([transforms.Resize((128, 128)),  # 调整图像大小# 这个操作会将图像数据从0-255的整数值范围(如果是uint8类型)转换为0-1之间的浮点数范围,并且会将图像的形状从(H, W, C)转换为(C, H, W),其中H是高度,W是宽度,C是通道数。这样做是为了符合PyTorch模型的输入要求.transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到 [-1, 1]
])# 创建数据集和数据加载器
root_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct"
labels_file = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct/labels.txt"  # 标签文件路径
dataset = TrafficSignDataset(root_dir=root_dir, labels_file=labels_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 生成一批随机标签(整数)及其对应的独热编码(one-hot encoding),独热编码提供了一种方便的方式来表示真实标签,使得我们可以使用交叉熵损失等损失函数来计算预测值与真实值之间的差异。
def create_labels(batch_size, label_dim):labels = torch.randint(0, label_dim, (batch_size,))labels_one_hot = torch.zeros(batch_size, label_dim).scatter_(1, labels.view(-1, 1), 1)return labels.to('cuda'), labels_one_hot.to('cuda')def train():torch.cuda.empty_cache()# 初始化空列表,用于存储生成器和判别器的损失值
d_losses = []
g_losses = []# 训练循环
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):real_images = real_images.to('cuda')# 1. 训练判别器# 真实数据损失# 这行代码调用了一个名为create_labels的函数,该函数接收两个参数:real_images.size(0)表示真实图像的数量,label_dim表示标签的维度real_labels, real_labels_one_hot = create_labels(real_images.size(0), label_dim)real_outputs = D(real_images)noise_real = torch.rand_like(real_outputs) * -0.1real_loss = criterion(real_outputs, torch.full_like(real_outputs, 0.8) + noise_real)# 生成数据损失noise = torch.randn(real_images.size(0), noise_dim).to('cuda')fake_labels, fake_labels_one_hot = create_labels(real_images.size(0), label_dim)fake_images = G(noise, fake_labels_one_hot)fake_outputs = D(fake_images.detach())# 为假标签加入随机噪声(0, 0.1)noise_fake = torch.rand_like(fake_outputs) * 0.1fake_loss = criterion(fake_outputs, torch.full_like(fake_outputs, 0.2) + noise_fake)# 判别器总损失d_loss = real_loss + fake_lossoptimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 2. 训练生成器fake_outputs = D(fake_images)g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()# 追加损失值到列表中d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 打印损失值if i % 50 == 0:print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")# 每隔一定步保存生成的图像if i % 200 == 0:# 保存每一张生成的图像for idx in range(min(30, fake_images.size(0))):  # 遍历生成的每一张图像save_image(fake_images[idx],os.path.join(output_dir, f"epoch_{epoch + 1}_image_{idx + 1}.png"),normalize=True)  # 保存每一张图像,命名方式包括epoch, step, 和图像编号print("训练完成并保存生成图像。")# 绘制生成器和判别器的损失曲线
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss', color='blue')
plt.plot(g_losses, label='Generator Loss', color='red')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Generator and Discriminator Loss During Training')
plt.legend()
plt.grid()
plt.show()

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词