欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > python打卡day53

python打卡day53

2025/6/19 8:48:42 来源:https://blog.csdn.net/m0_62284115/article/details/148750529  浏览:    关键词:python打卡day53

对抗生成网络

  1. 对抗生成网络的思想:关注损失从何而来
  2. 生成器、判别器
  3. nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
  4. leakyReLU介绍:避免relu的神经元失活现象

作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异。

数据预处理

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")LATENT_DIM = 10     # 潜在空间的维度,这里根据任务复杂程度任选
EPOCHS = 10000      # 训练的回合数,一般需要比较长的时间
BATCH_SIZE = 32     # 每批次训练的样本数
LR = 0.0002         # 学习率
BETA1 = 0.5         # Adam优化器的参数# 检查是否有可用的GPU,否则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")#预处理
data=pd.read_csv('heart.csv')
# print(data)
X = data.drop(['target'], axis=1)  
y = data['target']  # 
# 数据缩放到 [-1, 1]
scaler = MinMaxScaler(feature_range=(-1, 1)) 
X_scaled = scaler.fit_transform(X) # 转换为 PyTorch Tensor 并创建 DataLoader
# 注意需要将数据类型转为 float
real_data_tensor = torch.from_numpy(X_scaled).float() 
dataset = TensorDataset(real_data_tensor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)print(f"成功加载并预处理数据。用于训练的样本数量: {len(X_scaled)}")
print(f"数据特征维度: {X_scaled.shape[1]}")

生成器

# --- 3. 构建模型 ---# (A) 生成器 (Generator)
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(LATENT_DIM, 16),nn.ReLU(),nn.Linear(16, 32),nn.ReLU(),nn.Linear(32, 13),# 最后的维度只要和目标数据对齐即可nn.Tanh() # 输出范围是 [-1, 1])def forward(self, x):return self.model(x) # 因为没有像之前一样做定义x=某些东西,所以现在可以直接输出模型

判别器

# (B) 判别器 (Discriminator)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(13, 32),nn.LeakyReLU(0.2), # LeakyReLU 是 GAN 中的常用选择nn.Linear(32, 16),nn.LeakyReLU(0.2), # 负斜率参数为0.2nn.Linear(16, 1), # 这里最后输出1个神经元,所以用sigmoid激活函数nn.Sigmoid() # 输出 0 到 1 的概率)def forward(self, x):return self.model(x)

训练

# 实例化模型并移动到指定设备
generator = Generator().to(device)
discriminator = Discriminator().to(device)print(generator)
print(discriminator)# --- 4. 定义损失函数和优化器 ---criterion = nn.BCELoss() # 二元交叉熵损失# 分别为生成器和判别器设置优化器
g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=(BETA1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=(BETA1, 0.999))# --- 5. 执行训练循环 ---print("\n--- 开始训练 ---")
for epoch in range(EPOCHS):for i, (real_data,) in enumerate(dataloader):# 将数据移动到设备real_data = real_data.to(device)current_batch_size = real_data.size(0)# 创建真实和虚假的标签real_labels = torch.ones(current_batch_size, 1).to(device)fake_labels = torch.zeros(current_batch_size, 1).to(device)# ---------------------#  训练判别器# ---------------------d_optimizer.zero_grad() # 梯度清零# (1) 用真实数据训练real_output = discriminator(real_data)d_loss_real = criterion(real_output, real_labels)# (2) 用假数据训练noise = torch.randn(current_batch_size, LATENT_DIM).to(device)# 使用 .detach() 防止在训练判别器时梯度流回生成器,这里我们未来再说fake_data = generator(noise).detach() fake_output = discriminator(fake_data)d_loss_fake = criterion(fake_output, fake_labels)# 总损失并反向传播d_loss = d_loss_real + d_loss_faked_loss.backward()d_optimizer.step()# ---------------------#  训练生成器# ---------------------g_optimizer.zero_grad() # 梯度清零# 生成新的假数据,并尝试"欺骗"判别器noise = torch.randn(current_batch_size, LATENT_DIM).to(device)fake_data = generator(noise)fake_output = discriminator(fake_data)# 计算生成器的损失,目标是让判别器将假数据误判为真(1)g_loss = criterion(fake_output, real_labels)# 反向传播并更新生成器g_loss.backward()g_optimizer.step()# 每 1000 个 epoch 打印一次训练状态if (epoch + 1) % 1000 == 0:print(f"Epoch [{epoch+1}/{EPOCHS}], "f"Discriminator Loss: {d_loss.item():.4f}, "f"Generator Loss: {g_loss.item():.4f}")print("--- 训练完成 ---")

可视化

# --- 6. 生成新数据并进行可视化对比 ---print("\n--- 生成并可视化结果 ---")
# 将生成器设为评估模式
generator.eval()# 使用 torch.no_grad() 来关闭梯度计算
with torch.no_grad():num_new_samples = 50noise = torch.randn(num_new_samples, LATENT_DIM).to(device)generated_data_scaled = generator(noise)# 将生成的数据从GPU移到CPU,并转换为numpy数组
generated_data_scaled_np = generated_data_scaled.cpu().numpy()# 逆向转换回原始尺度
generated_data = scaler.inverse_transform(generated_data_scaled_np)
real_data_original_scale = scaler.inverse_transform(X_scaled)# 可视化对比
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('真实数据 vs. GAN生成数据 的特征分布对比 (PyTorch)', fontsize=16)# 从数据中获取特征列名(排除目标列)
feature_names = data.drop(['target'], axis=1).columns.tolist()for i, ax in enumerate(axes.flatten()):ax.hist(real_data_original_scale[:, i], bins=10, density=True, alpha=0.6, label='Real Data')ax.hist(generated_data[:, i], bins=10, density=True, alpha=0.6, label='Generated Data')ax.set_title(feature_names[i])ax.legend()plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()# 将生成的数据与真实数据并排打印出来看看
print("\n前5个真实样本 (Setosa):")
print(pd.DataFrame(real_data_original_scale[:5], columns=feature_names))print("\nGAN生成的5个新样本:")
print(pd.DataFrame(generated_data[:5], columns=feature_names))

前5个真实样本 (Setosa):age  sex   cp  trestbps   chol  fbs  restecg  thalach  exang  oldpeak  \
0  63.0  1.0  3.0     145.0  233.0  1.0      0.0    150.0    0.0      2.3   
1  37.0  1.0  2.0     130.0  250.0  0.0      1.0    187.0    0.0      3.5   
2  41.0  0.0  1.0     130.0  204.0  0.0      0.0    172.0    0.0      1.4   
3  56.0  1.0  1.0     120.0  236.0  0.0      1.0    178.0    0.0      0.8   
4  57.0  0.0  0.0     120.0  354.0  0.0      1.0    163.0    1.0      0.6   slope   ca  thal  
0    0.0  0.0   1.0  
1    0.0  0.0   2.0  
2    2.0  0.0   2.0  
3    2.0  0.0   2.0  
4    2.0  0.0   2.0  GAN生成的5个新样本:age  sex        cp    trestbps        chol  fbs   restecg  \
0  62.215790  1.0  1.589591  130.429031  274.840942  0.0  0.753500   
1  40.936096  1.0  0.002826  100.742523  241.846802  0.0  0.977905   
2  48.687664  1.0  0.000000  104.236229  341.495422  0.0  0.000000   
3  52.472134  1.0  0.333481  106.300339  252.162643  0.0  0.985957   
4  65.265450  1.0  2.882269  120.447281  276.935150  0.0  0.000000   thalach         exang   oldpeak     slope        ca      thal  
0  146.546661  0.000000e+00  3.115193  0.945533  0.767616  3.000000  
1  173.175735  2.980232e-08  0.558755  2.000000  0.075293  2.999915  
2  158.998764  1.000000e+00  3.843602  0.959458  0.192766  3.000000  
3  169.315262  0.000000e+00  0.179108  2.000000  0.529046  2.050252  
4  164.952011  0.000000e+00  2.238417  0.965755  0.493075  3.000000

@浙大疏锦行

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词