在 https://blog.csdn.net/fengbingchun/article/details/149878692 中介绍了深度学习中的模型知识蒸馏,这里通过已训练的DenseNet分类模型,基于响应的知识蒸馏实现通过教师模型生成学生模型:

      1. 依赖的模块如下所示:

import argparse
import colorama
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import ast
import time
from pathlib import Path

      2. 支持的输入参数如下所示:

def parse_args():parser = argparse.ArgumentParser(description="model knowledge distillation")parser.add_argument("--task", required=True, type=str, choices=["train", "predict"], help="specify what kind of task")parser.add_argument("--src_model", type=str, help="source model name")parser.add_argument("--dst_model", type=str, help="distilled model name")parser.add_argument("--classes_number", type=int, default=2, help="classes number")parser.add_argument("--mean", type=str, help="the mean of the training set of images")parser.add_argument("--std", type=str, help="the standard deviation of the training set of images")parser.add_argument("--labels_file", type=str, help="one category per line, the format is: index class_name")parser.add_argument("--images_path", type=str, help="predict images path")parser.add_argument("--epochs", type=int, default=500, help="number of training")parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")parser.add_argument("--drop_rate", type=float, default=0.2, help="dropout rate")parser.add_argument("--dataset_path", type=str, help="source dataset path")parser.add_argument("--temperature", type=float, default=2.0, help="temperature, higher the temperature, the better it expresses the teacher's knowledge: [2.0, 4.0]")parser.add_argument("--alpha", type=float, default=0.7, help="teacher weight coefficient, generally, the larger the alpha, the more dependent on the teacher's guidance: [0.5, 0.9]")args = parser.parse_args()return args

      3. 定义学生网络类StudentModel:

class StudentModel(nn.Module):def __init__(self, classes_number=2, drop_rate=0.2):super().__init__()self.features = nn.Sequential( # four convolutional blocksnn.Conv2d(3, 32, 3, padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),nn.Conv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),nn.Conv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1, 1)))self.classifier = nn.Sequential(nn.Linear(256, 128),nn.ReLU(inplace=True),nn.Dropout(drop_rate),nn.Linear(128, classes_number))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x

      (1).为知识蒸馏设计的轻量级卷积神经网络,用作student模型来学习复杂teacher模型的知识。

      (2).self.features为特征提取器,通过4个卷积块逐步提取从低级到高级的特征,因为使用了nn.AdaptiveAvgPool2d,可以支持不同大小彩色图像的输入。

      (3).self.classifier为分类器,2个全连接层(第一个全连接层降维,第二个全连接层分类),dropout层防止过拟合。

      4. 训练代码如下:

def print_student_model_parameters():student_model = StudentModel()print("student model parameters: ", student_model)tensor = torch.rand(1, 3, 224, 224)student_model.eval()output = student_model(tensor)print(f"output: {output}; output.shape: {output.shape}")def _str2tuple(value):if not isinstance(value, tuple):value = ast.literal_eval(value) # str to tuplereturn valuedef _load_dataset(dataset_path, mean, std, batch_size):mean = _str2tuple(mean)std = _str2tuple(std)train_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std), # RGB])train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)val_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std), # RGB])val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories int the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"val_loader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=0)return len(train_dataset), len(val_dataset), train_loader, val_loaderdef _distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):# hard label loss(student model vs. true label)hard_loss = nn.CrossEntropyLoss()(student_logits, labels)# soft label loss(student model vs. teacher model)soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_logits / temperature, dim=1),F.softmax(teacher_logits / temperature, dim=1)) * (temperature ** 2)return alpha * soft_loss + (1 - alpha) * hard_lossdef train(src_model, dst_model, device, classes_number, drop_rate, mean, std, dataset_path, epochs, lr, temperature, alpha):teacher_model = models.densenet121(weights=None)teacher_model.classifier = nn.Linear(teacher_model.classifier.in_features, classes_number)teacher_model.load_state_dict(torch.load(src_model, weights_only=False, map_location="cpu"))teacher_model.to(device)student_model = StudentModel(classes_number, drop_rate).to(device)train_dataset_num, val_dataset_num, train_loader, val_loader = _load_dataset(dataset_path, mean, std, 4)optimizer = optim.Adam(student_model.parameters(), lr)highest_accuracy = 0.minimum_loss = 100.for epoch in range(epochs):epoch_start = time.time()train_loss = 0.0train_acc = 0.0val_loss = 0.0val_acc = 0.0student_model.train()teacher_model.eval()for _, (inputs, labels) in enumerate(train_loader):inputs = inputs.to(device)labels = labels.to(device)with torch.no_grad():teacher_outputs = teacher_model(inputs)student_outputs = student_model(inputs)loss = _distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predictions = torch.max(student_outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))train_acc += acc.item() * inputs.size(0)student_model.eval()with torch.no_grad():for _, (inputs, labels) in enumerate(val_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = student_model(inputs)loss = nn.CrossEntropyLoss()(outputs, labels)val_loss += loss.item() * inputs.size(0)_, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))val_acc += acc.item() * inputs.size(0)avg_train_loss = train_loss / train_dataset_numavg_train_acc = train_acc / train_dataset_numavg_val_loss = val_loss / val_dataset_numavg_val_acc = val_acc / val_dataset_numepoch_end = time.time()print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.6f}, accuracy:{avg_train_acc:.6f}; validation loss:{avg_val_loss:.6f}, accuracy:{avg_val_acc:.6f}; time:{epoch_end-epoch_start:.2f}s")if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:torch.save(student_model.state_dict(), dst_model)highest_accuracy = avg_val_accminimum_loss = avg_val_lossif avg_val_loss < 0.0001 or avg_val_acc > 0.9999:print(colorama.Fore.YELLOW + "stop training early")torch.save(student_model.state_dict(), dst_model)break

      (1).使用PyTorch中的类ImageFolder和DataLoader加载数据集。

      (2).蒸馏损失:

      1).硬标签(one-hot)损失函数:交叉熵,nn.CrossEntropyLoss,来源于ground truth。

      2).软标签(teacher预测概率分布)损失函数:KL散度,nn.KLDivLoss,来源于teacher输出,软标签携带了更多类别间的相似度信息,有助于student模型更好地泛化。

      3).temperature和权重比例alpha参数:temperature越高,soft标签越"软",更能表达teacher的知识;通常alpha越大,表示更依赖teacher的指导。

      (3).训练代码中加入了早停机制,当验证损失小于指定的值或验证准确度大于指定的值时停止训练。

      训练输出结果如下图所示:

      5. 预测代码如下所示:

def _parse_labels_file(labels_file):classes = {}with open(labels_file, "r") as file:for line in file:idx_value = []for v in line.split(" "):idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the lineassert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"classes[int(idx_value[0])] = idx_value[1]return classesdef _get_images_list(images_path):image_names = []p = Path(images_path)for subpath in p.rglob("*"):if subpath.is_file():image_names.append(subpath)return image_namesdef predict(model_name, labels_file, images_path, mean, std, device):classes = _parse_labels_file(labels_file)assert len(classes) != 0, "the number of categories can't be 0"image_names = _get_images_list(images_path)assert len(image_names) != 0, "no images found"mean = _str2tuple(mean)std = _str2tuple(std)model = StudentModel(len(classes)).to(device)model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu"))model.to(device)model.eval()with torch.no_grad():for image_name in image_names:input_image = Image.open(image_name)preprocess = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std) # RGB])input_tensor = preprocess(input_image) # (c,h,w)input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model, (1,c,h,w)input_batch = input_batch.to(device)output = model(input_batch)probabilities = torch.nn.functional.softmax(output[0], dim=0) # the output has unnormalized scores, to get probabilities, you can run a softmax on itmax_value, max_index = torch.max(probabilities, dim=0)print(f"{image_name.name}\t{classes[max_index.item()]}\t{max_value.item():.4f}")

     预测输出结果如下图所示:

      6. 入口函数如下所示:

if __name__ == "__main__":colorama.init(autoreset=True)args = parse_args()# print_student_model_parameters()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")if args.task == "train":train(args.src_model, args.dst_model, device, args.classes_number, args.drop_rate, args.mean, args.std,args.dataset_path, args.epochs, args.lr, args.temperature, args.alpha)else:predict(args.dst_model, args.labels_file, args.images_path, args.mean, args.std, device)print(colorama.Fore.GREEN + "====== execution completed ======")

      以上代码中不通过蒸馏直接通过StudentModel也可以生成分类模型,使用蒸馏的优势在于:

      (1).可以学习到教师模型的"软知识",从头训练时,学生模型只能依赖硬标签(one-hot),无法知道类别之间的相似性。蒸馏中,学生会去拟合教师模型的软标签,这有助于学生模型泛化。

      (2).从头训练时,学生模型需要足够多且多样化的数据才能学到泛化能力。有了教师模型,即使训练数据不多,学生也能通过教师的软标签进行监督。

      (3).通过蒸馏会使学生模型训练更稳定,收敛更快。

      (4).从头训练学生模型可能达不到教师模型的性能,而蒸馏能让学生模型模仿教师模型的决策边界,从而在计算量相同的情况下获得更好的准确率。

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/94982.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/94982.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/94982.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据可视化-82】中国城市幸福指数可视化分析:Python + PyEcharts 打造炫酷城市幸福指数可视化大屏

&#x1f9d1; 博主简介&#xff1a;曾任某智慧城市类企业算法总监&#xff0c;目前在美国市场的物流公司从事高级算法工程师一职&#xff0c;深耕人工智能领域&#xff0c;精通python数据挖掘、可视化、机器学习等&#xff0c;发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

TikTok网页版访问障碍破解:IP限制到高效运营的全流程指南

在跨境电商与社媒运营的数字化浪潮中&#xff0c;TikTok网页版因其多账号管理便捷性、内容采集高效性等优势&#xff0c;成为从业者的核心工具&#xff0c;然而“页面空白”“地区不支持” 等访问问题却频繁困扰用户。一、TikTok网页版的核心应用场景与技术特性&#xff08;一&…

spring的知识点:容器、AOP、事物

一、Spring 是什么? Spring 是一个开源的 Java 企业级应用框架,它的核心目标是简化 Java 开发。 它不是单一的工具,而是一个 “生态系统”,包含了很多模块(如 Spring Core、Spring Boot、Spring MVC 等),可以解决开发中的各种问题(如对象管理、Web 开发、事务控制等)…

HTML ISO-8859-1:深入解析字符编码标准

HTML ISO-8859-1:深入解析字符编码标准 引言 在HTML文档中,字符编码的选择对于确保网页内容的正确显示至关重要。ISO-8859-1是一种广泛使用的字符编码标准,它定义了256个字符,覆盖了大多数西欧语言。本文将深入探讨HTML ISO-8859-1的原理、应用及其在现代网页开发中的重要…

【计算机网络 | 第4篇】分组交换

文章目录前言&#x1f95d;电路交换&#x1f34b;电路交换技术的优缺点电路交换的资源分配机制报文交换&#x1f34b;报文交换技术的优缺点存储转发技术分组交换&#x1f426;‍&#x1f525;分组交换的过程分组交换解决的关键问题传输过程的关键参数工作原理分组传输时延计算网…

LLM - AI大模型应用集成协议三件套 MCP、A2A与AG-UI

文章目录1. 引言&#xff1a;背景与三协议概览2. MCP&#xff08;Model Context Protocol&#xff09;起源与动因架构与规范要点开发实践3. A2A&#xff08;Agent-to-Agent Protocol&#xff09;起源与动因架构与规范要点开发实践4. AG-UI&#xff08;Agent-User Interaction P…

机器学习DBSCAN密度聚类

引言 在机器学习的聚类任务中&#xff0c;K-means因其简单高效广为人知&#xff0c;但它有一个致命缺陷——假设簇是球形且密度均匀&#xff0c;且需要预先指定簇数。当数据存在任意形状的簇、噪声点或密度差异较大时&#xff0c;K-means的表现往往不尽如人意。这时候&#xff…

RecyclerView 缓存机制

一、四级缓存体系1. Scrap 缓存&#xff08;临时缓存&#xff09;位置&#xff1a;mAttachedScrap 和 mChangedScrap作用&#xff1a;存储当前屏幕可见但被标记为移除的 ViewHolder用于局部刷新&#xff08;如 notifyItemChanged()&#xff09;特点&#xff1a;生命周期短&…

大模型SSE流式输出技术

文章目录背景&#xff1a;为什么需要流式输出SSE 流式输出很多厂商还是小 chunk背景&#xff1a;为什么需要流式输出 大模型的响应通常很长&#xff0c;比如几百甚至几千个 token&#xff0c;如果等模型一次性生成完才返回&#xff1a; 延迟高&#xff1a;用户要等很久才能看…

[Flutter] v3.24 AAPT:错误:未找到资源 android:attr/lStar。

推荐超级课程&#xff1a; 本地离线DeepSeek AI方案部署实战教程【完全版】Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速入门实战 前提 将 Flutter 升级到 3.24.4 后&#xff0c;构建在我的本地电脑上通过&#xff0c;但Github actions 构建时失败。 Flutt…

go语言标准库学习, fmt标准输出,Time 时间,Flag,Log日志,Strconv

向外输出 fmt包实现了类似C语言printf和scanf的格式化I/O。主要分为向外输出内容和获取输入内容两大部分。 内置输出 不需要引入标准库&#xff0c;方便 package mainfunc main() {print("我是控制台打印&#xff0c;我不换行 可以自己控制换行 \n我是另一行")prin…

ElementUI之表格

文章目录使用ElementUI使用在线引入的方式表格1. 带状态表格row-class-name"Function({row, rowIndex})/String"2. 固定表头(height"string/number"属性)2.1 属性的取值2.2 动态响应式高度使用演示2.3 ​​自定义滚动条样式​​2.4 表头高度定制获取一行信…

K8S 的 Master组件

K8S 的 Master 组件有哪些&#xff1f;每个组件的作用&#xff1f; K8s 大脑的 4 大核心模块&#xff0c;掌控全局&#xff01; Kubernetes 集群的 Master&#xff08;主节点&#xff09; 就像一座 指挥中心&#xff0c;负责整个集群的调度、管理和控制。它由 4 大核心组件组成…

如何 让ubuntu 在root 下安装的docker 在 普通用户下也能用

在 Ubuntu 系统中&#xff0c;如果 Docker 是以 root 用户安装的&#xff0c;普通用户默认无法直接使用 Docker 命令&#xff08;会报权限错误&#xff09;。要让普通用户也能使用 Docker&#xff0c;可以按照以下步骤操作&#xff1a;方法 1&#xff1a;将用户加入 docker 用户…

模板方法模式:优雅封装算法骨架

目录 一、模板方法模式 1、结构 2、特性 3、优缺点 3.1、优点 3.2、缺点 4、使用场景 5、实现示例 5.1、抽象类 5.2、实现类 5.3、测试类 一、模板方法模式 模板方法模式&#xff08;Template Method Pattern&#xff09;是一种行为设计模式&#xff0c;它在一个方…

韦东山STM32_HAl库入门教程(SPI)学习笔记[09]内容

&#xff08;1&#xff09;SPI程序层次一、核心逻辑&#xff1a;“SPI Flash 操作” 是怎么跑起来的&#xff1f;要读写 SPI Flash&#xff0c;需同时理解 硬件连接&#xff08;怎么接线&#xff09; 和 软件分层&#xff08;谁负责发指令、谁负责控制逻辑&#xff09;&#xf…

线上Linux服务器的优化设置、系统安全与网络安全策略

一、Linux服务器的优化设置 线上Linux的优化配置序号基础优化配置内容说明1最小化安装系统【仅安装需要的&#xff0c;按需安装、不用不装】&#xff0c;必须安装的有基本开发环境、基本网络包、基本应用包。2ssh登录策略优化 Linux服务器上的ssh服务端配置文件是【/et…

基于人眼视觉特性的相关图像增强基础知识介绍

目录 1. 传统的灰度级动态范围优化配置方法 2.基于视觉特性的灰度级动态范围调整优化 1. 传统的灰度级动态范围优化配置方法 传统的灰度级动态范围调整方法主要包括线性动态范围调整及非线性动态 范围调整。线性动态范围调整是最简单的灰度级动态范围调整方法&#xff0c;观察…

Selenium使用超全指南

&#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快概述selenium是网页应用中最流行的自动化测试工具&#xff0c;可以用来做自动化测试或者浏览器爬虫等。官网地址为&#xff1a;相对于另外一款web自动化测试工具QT…

Go通道操作全解析:从基础到高并发模式

一、channel类型 Go 语言中的通道(channel)是一种特殊的类型。它类似于传送带或队列,遵循先进先出(FIFO)原则,确保数据收发顺序的一致性。每个通道都是特定类型的导管,因此在声明时必须指定其元素类型。 channel是一种类型, 一种引用类型。 声明通道类型的格式如下:…