欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 教育 > 培训 > 机器学习周报(12.2-12.8)

机器学习周报(12.2-12.8)

2025/10/20 5:23:05 来源:https://blog.csdn.net/weixin_51923997/article/details/144315913  浏览:    关键词:机器学习周报(12.2-12.8)

文章目录

    • 摘要
    • Abstract
  • Vision Transformer
    • 1 原理
    • 2 代码

摘要

本周学习了Vision Transformer (ViT) 的基本原理及其实现,并完成了基于PyTorch的模型训练、验证和预测任务。深入理解了ViT如何将图像分割成patch作为输入序列,并结合Transformer Encoder处理。通过迁移学习在花类数据集上训练模型,并验证了模型在预测任务中的优越性能。

Abstract

This week, I studied the fundamental principles and implementation of Vision Transformer (ViT) and completed model training, validation, and prediction tasks using PyTorch. I gained a deep understanding of how ViT splits an image into patches as input sequences and processes them using the Transformer Encoder. By leveraging transfer learning, I trained the model on a flower dataset and validated its superior performance in prediction tasks.

Vision Transformer

1 原理

  • 数据处理

我认为ViT的关键在于理解怎么将图片当作一个序列输入进模型之中。我们先看看ViT整体结构图,如下图所示

在这里插入图片描述
论文中提到将 224x224x3 的图像作为输入,将图像分为 16x16x3 大小的patch,也就是说将输入图像分为了 224 × 224 × 3 16 × 16 × 3 = 196 \frac{224×224×3}{16×16×3}=196 16×16×3224×224×3=196 个patch。其中每个patch拉直之后的维度为 16×16×3=768维,也就是Linear Projection of Flattened Patches层下面分割的小图像。

在具体实现中,使用卷积核大小为 16x16x3 、步距为16、卷积核个数为768的卷积层,就能将3维图像转换为Transformer所需要的输入token[组数,维度]。

  • 全连接层
    上述[196,768]的token将传入Linear Projection of Flattened Patches层,该层是 768x768 的全连接层,该层输出认为 196x768 。

  • 位置编码
    将经过全连接层后的输出进行位置编码,其位置编码和Transformer中的时序编码有异曲同工之妙,前者可以通过位置编码表示出token之间关于原输入图像的一些位置信息,后者可以表示输入先后的时序信息。
    该模型位置编码通过类似于坐标的形式表达,直接于输入相加,不改变维度大小。如下图所示:
    在这里插入图片描述

进行位置编码后,还需要加上一个特殊字符(最左输入0*),输入总组数从之前的196变为197,传入Transformer Encoder的token为[197,768]。

  • Transformer Encoder
    在这里插入图片描述
    ViT采用的是Transformer中编码器进行叠加,但其中的参数数量有所不同。
    经过位置编码和加入特殊字符的token[197,768]传入编码器,首先经过层归一化,再经过多头自注意力。这里的多头自注意力是采用12个头,也就是将768维分为12份,每份(Q、K、V)64维度,计算之后再进行合并为768维。

ViT中的编码器仍是采用残差连接,再经过一次层归一化后,就进入单个Transformer Encoder的最后一层MLP(多层感知机)。MLP将经过多头自注意力的输出维度升高4倍,即从768变为3072,最后再将维度降至768维

ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。

  • 输出
    ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。

在这里插入图片描述

最后,通过全连接层和softmax进行概率输出即可

2 代码

在理解完ViT的原理之后,我们来看看PyTorch代码如何实现。这里以ViT-base模型,输入图像 224x224x3,patch大小 16x16x3 为例

花类数据集:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzy

训练模型代码如下,需要自行更改数据集路径和权重路径。

import os
import math
import argparseimport torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsfrom my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("../weights") is False:os.makedirs("../weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes, has_logits=False).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)# 删除不需要的权重# del_keys = ['head.weight', 'head.bias'] if model.has_logits \#     else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']del_keys = ['head.weight', 'head.bias']for k in del_keys:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head, pre_logits外,其他权重全部冻结if "head" not in name and "pre_logits" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "../weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=10)parser.add_argument('--batch-size', type=int, default=8)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data_path', type=str, default='../data/flower_photos', help='path to dataset')parser.add_argument('--model-name', default='', help='create model name')# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='../weights/vit_base_patch16_224.pth', help='path to initial weights')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=True)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

训练结果如下:

在这里插入图片描述

因为是迁移学习的原因,只需要进行微调即可,所以9epoch之后准确率就达到97.9%了。

每训练一个epoch,就会将训练模型保存至weights文件夹,如下图所示

在这里插入图片描述

通过上述代码的训练之后,我们可以将保存的模型model-9.pth引入预测代码进行预测啦!需自行更改权重路径,以及需要测试的图片路径。

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom vit_model import vit_base_patch16_224_in21k as create_modeldef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(254),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])# load imageimg_path = "../data/Image/flower.png"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)img2 = imgplt.imshow(img)plt.show()img = img.convert('RGB')img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)  # [1, 3, 224, 224]# read class_indictjson_path = 'class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = create_model(num_classes=5, has_logits=False).to(device)  # num_classes=5:表示模型将被训练来识别5个不同的类别;has_logits=False:模型不直接输出logits,在实际应用中,这通常意味着模型的输出层之后可能会跟随一个softmax激活函数# load model weightsmodel_weight_path = "../weights/model-9.pth"  # 采用第10轮训练的参数model.load_state_dict(torch.load(model_weight_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))plt.imshow(img2)plt.show()if __name__ == '__main__':main()

模型预测结果如下所示:
在这里插入图片描述

在这里插入图片描述

模型预测结果几乎100%为sunflowers

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词