系列文章目录

pytorch学习笔记(一)-- pytorch深度学习框架基本知识了解

pytorch学习笔记(二)-- pytorch模型开发步骤详解

pytorch学习笔记(三)-- TensorBoard的介绍

pytorch学习笔记(四)-- TorchVision 物体检测微调教程

pytorch学习笔记(五)-- 计算机视觉的迁移学习

文章目录

系列文章目录

文章目录

前言

一、加载数据 

二、训练模型

三、可视化模型预测

四、卷积网络微调

   微调 ConvNet:

ConvNet 作为固定特征提取器:

五、自定义图像的推理

总结


前言

        在本章节,您将学习如何使用迁移学习训练卷积神经网络进行图像分类。您可以在 cs231n notes 笔记中阅读有关迁移学习的更多信息。

        一般来说,大家都不会从头开始训练卷积神经网络,而是先在较大的数据集上做预训练,差不多成熟了,然后再把卷积网络在自己的任务上做初始化,或者特征提取器。

这两种主要的迁移学习场景如下:

  • 微调 ConvNet:我们不是使用随机初始化,而是使用预训练网络(例如在 imagenet 1000 数据集上训练的网络)来初始化网络。其余训练看起来与往常一样。
  • ConvNet 作为固定特征提取器:在这里,我们将冻结除最终全连接层之外的所有网络的权重。最后一个全连接层将被替换为具有随机权重的新层,并且只训练这一层。

一、加载数据 

        我们使用torchvision 和 torch.utils.data数据包进行数据加载,今天的任务是训练一个模型用来分辨蚂蚁和蜜蜂,我们有120张蚂蚁和蜜蜂的照片用于训练,以及75张用于测试蜜蜂和蚂蚁的照片。

数据集下载路径:MyDataset: 数据集仓库,包括各种网站搜刮的,以及一些自定义的数据。方便后续神经网络的训练 - Gitee.com

        通常,如果从头开始训练,这个数据集太小了,无法进行推广。由于我们使用迁移学习,我们就可以相当好地进行推广。

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#可视化部分图片,确认下效果
def imshow(inp, title=None):"""Display image for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

二、训练模型

编写一个通用函数来训练模型。

  • 安排这个learning rate
  • 保存模型

       参数 scheduler 是来自 torch.optim.lr_scheduler 的 LR 调度程序对象,关于这个scheduler在后续模型优化的章节会讲到,也是一个非常强大的功能,这里就先不赘述了。

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()# Create a temporary directory to save training checkpointswith TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')torch.save(model.state_dict(), best_model_params_path)best_acc = 0.0for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), best_model_params_path)print()time_elapsed = time.time() - sinceprint(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'Best val Acc: {best_acc:4f}')# load best model weightsmodel.load_state_dict(torch.load(best_model_params_path, weights_only=True))return model

三、可视化模型预测

        经过上一节的训练,我们现在看看模型的预测结果怎么样。作为传统的程序开发者,我们习惯通过打印来验证结果,但是这玩意儿只有开发兄弟能懂哈。Pytorch毕竟是基于Python的深度学习框架,工具包是应用尽有,所以我们可以以可视化的图形来显示预测的效果。说个题外话,将工作结果可视化这个习惯,各位开发兄弟得学起来,以后就可以一手抓开发,一手抓产品,既可以跟开发同事一起奋斗,又可以跟老板以及客户吹牛逼,路就走宽了。

def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title(f'predicted: {class_names[preds[j]]}')imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)

四、卷积网络微调

        上三节都是准备工作,这一节,我们就讲一下两种迁移学习的使用。

   微调 ConvNet:

#加载一个预训练的模型并且重置全连接层
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)
visualize_model(model_ft)

ConvNet 作为固定特征提取器:

        注意:这里,我们需要冻结除最后一层之外的所有网络。我们需要设置requires_grad = False来冻结参数,这样梯度就不会在backward()中计算。

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():param.requires_grad = False# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)
visualize_model(model_conv)
plt.ioff()
plt.show()

五、自定义图像的推理

        使用训练的模型进行自定义图片的预测并且显示预测图片对应的标签

def visualize_model_predictions(model,img_path):was_training = model.trainingmodel.eval()img = Image.open(img_path)img = data_transforms['val'](img)img = img.unsqueeze(0)img = img.to(device)with torch.no_grad():outputs = model(img)_, preds = torch.max(outputs, 1)ax = plt.subplot(2,2,1)ax.axis('off')ax.set_title(f'Predicted: {class_names[preds[0]]}')imshow(img.cpu().data[0])model.train(mode=was_training)visualize_model_predictions(model_conv,img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)plt.ioff()
plt.show()

总结

    迁移学习其实是实际工作中用的非常多的一种神经网络开发方法,对于开发者来说,从头构建一个模型,开发难度很大,并且个人很难去实现它的训练,这个需要庞大的数据集以及场景测试。            

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89801.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89801.shtml
英文地址,请注明出处:http://en.pswp.cn/web/89801.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字IC后端培训教程之数字后端项目典型项目案例解析

数字IC后端低功耗设计实现案例分享(3个power domain,2个voltage domain) Q1: 电路如下图,clk是一个很慢的时钟test_clk(属于DFT的),DFF1与and 形成一个clock gating check。跑pr 发现,时钟树综合CTS阶段(C…

2025 Data Whale x PyTorch 安装学习笔记(Windows 版)

一、Anaconda 的安装与基本操作 1. 安装 Anaconda/miniconda 官方链接:Anaconda | Individual Edition 根据系统版本选择合适的安装包下载并安装。 2. 检验安装 打开 “开始” 菜单,找到 “Anaconda Prompt”(一般在 Anaconda3 文件夹…

mac OS上docker安装zookeeper

拉取镜像:$ docker pull zookeeper:3.5.7 3.5.7: Pulling from library/zookeeper 3.5.7: Pulling from library/zookeeper 3.5.7: Pulling from library/zookeeper no matching manifest for linux/arm64/v8 in the manifest list entries报错:由于时M3…

设备通过4G网卡接入EasyCVR视频融合平台,出现无法播放的问题排查和解决

EasyCVR视频融合平台作为支持多协议接入、多设备集中管理的综合性视频解决方案,可实现各类终端设备的视频流汇聚与实时播放。近期收到用户反馈,在EasyCVR平台接入设备后出现视频流无法播放的情况。为帮助更多用户快速排查同类问题,现将具体处…

板凳-------Mysql cookbook学习 (十二--------3)

第二章 抽象数据类型和python类 2.5类定义实例: 学校人事管理系统中的类 import datetimeclass PersonValueError(ValueError):"""自定义异常类"""passclass PersonTypeError(TypeError):"""自定义异常类""…

css flex 布局中 flex-direction为column,如何让子元素的宽度根据内容自动变化

在 display: flex 且 flex-direction: column 的布局中,默认情况下子元素会占满容器的宽度。要让子元素的宽度根据内容自适应,而不是自动拉伸填满父容器,你可以这样处理:✅ 解决方案一:设置子元素 align-self: start 或…

性能优化实践:Modbus 在高并发场景下的吞吐量提升(二)

四、Modbus 吞吐量提升实战策略4.1 优化网络配置选择合适的网络硬件是提升 Modbus 通信性能的基础。在工业现场,应优先选用高性能的工业级交换机和路由器。工业级交换机具备更好的抗干扰能力和稳定性,其背板带宽和包转发率更高,能够满足高并发…

上传ipa到appstore的几种工具

无论是用原生开发也好,使用uniapp或flutter开发也好,最好打包好的APP是需要上架appstore的。而在app store connect上架的时候,需要上传ipa文件到app store的构建版本上。因此,需要上传工具。下面分析下几种上传工具的优缺点&…

数控调压BUCK电路 —— 基于TPS56637(TI)

0 前言 本文基于 TI 的 TPS56637 实现一个支持调压的 BUCK 电路,包含从零开始详细的 原理解析、原理图、PCB 及 实测数据 本文属于《DIY迷你数控电源》系列,本系列我们一起实现一个简单的迷你数控电源 我是 LNY,一个在对嵌入式的所有都感兴…

prometheus UI 和node_exporter节点图形化Grafana

prometheus UI 和node_exporter节点图形化Grafana 先简单的安装一下 进行时间的同步操作安装Prometheus之前必须要先安装ntp时间同步,因为prometheus server对系统时间的准确性要求很高,必须保证本机时间实时同步。# 用crontab进行定时的时间的同步 yum …

RabbitMQ—TTL、死信队列、延迟队列

上篇文章: RabbitMQ—消息可靠性保证https://blog.csdn.net/sniper_fandc/article/details/149311576?fromshareblogdetail&sharetypeblogdetail&sharerId149311576&sharereferPC&sharesourcesniper_fandc&sharefromfrom_link 目录 1 TTL …

LVS 集群技术详解与实战部署

目录 引言 一、实验环境准备 二、理论基础:集群与 LVS 核心原理 2.1 集群与分布式 2.2 LVS 核心原理 LVS 的 4 种工作模式 LVS 调度算法 三、LVS 部署工具:ipvsadm 命令详解 四、实战案例:LVS 部署详解 案例 1:NAT 模式…

前端vue3获取excel二进制流在页面展示

excel二进制流在页面展示安装xlsx在页面中定义一个div来展示html数据定义二进制流请求接口拿到数据并展示安装xlsx npm install xlsx import * as XLSX from xlsx;在页面中定义一个div来展示html数据 <div class"file-input" id"file-input" v-html&qu…

android 信息验证动画效果

layout_check_pro <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:id"id/parent"android:layout_width"wrap_content"android:layout_…

【iOS】继承链

文章目录前言什么是继承链OC中的根类关于NSProxy关键作用1.方法查找与动态绑定2. 消息转发3. **类型判断与多态**继承链的底层实现元类的继承链总结前言 在objective-c中&#xff0c;继承链是类与类之间通过父类&#xff08;Superclass&#xff09;关系形成的一层层继承结构&am…

论文阅读:Instruct BLIP (2023.5)

文章目录InstructBLIP&#xff1a;迈向通用视觉语言模型的指令微调研究总结一、研究背景与目标二、核心方法数据构建与划分模型架构训练策略三、实验结果零样本性能消融实验下游任务微调定性分析可视化结果展示四、结论与贡献InstructBLIP&#xff1a;迈向通用视觉语言模型的指…

Elasticsearch+Logstash+Filebeat+Kibana部署【7.1.1版本】

目录 一、准备阶段 二、实验阶段 1.配置kibana主机 2.配置elasticsearch主机 3.配置logstash主机 4.配置/etc/filebeat/filebeat.yml 三、验证 1.开启Filebeat 2.在logstash查看 3.浏览器访问kibana 一、准备阶段 1.准备四台主机kibana、es、logstash、filebeat 2.在…

Vue开发前端报错:‘vue-cli-service‘ 不是内部或外部命令解决方案

1.Bug: 最近调试一个现有的Vue前端代码&#xff0c;发现如下错误&#xff1a; vue-cli-service’ 不是内部或外部命令&#xff0c;也不是可运行的程序 或批处理文件。 2.Bug原因&#xff1a; 导入的工程缺少依赖包&#xff1a;即缺少node_modules文件夹 3.解决方案&#xff1…

AI生态,钉钉再「出招」

如果说之前钉钉的AI生态加持更多的围绕资源和商业的底层助力&#xff0c;那么如今这种加持则是向更深层次进化&#xff0c;即真正的AI模型训练能力加持&#xff0c;为垂类大模型创业者提供全方位的助力&#xff0c;提高创业成功率和模型产品商业化确定性。作者|皮爷出品|产业家…

XSS GAME靶场

要求用户不参与&#xff0c;触发alert(1337) 目录 Ma Spaghet! Jefff Ugandan Knuckles Ricardo Milos Ah Thats Hawt Ligma Mafia Ok, Boomer Exmaple 1 - Create Example 2 - Overwrite Example 3 - Overwrite2 toString Ma Spaghet! <h2 id"spaghet&qu…