知识点回顾:

  1. 预训练的概念
  2. 常见的分类预训练模型
  3. 图像预训练模型的发展史
  4. 预训练的策略
  5. 预训练代码实战:resnet18

一、预训练的概念

我们之前在训练中发现,准确率最开始随着epoch的增加而增加。随着循环的更新,参数在不断发生更新。

所以参数的初始值对训练结果有很大的影响:

  1. 如果最开始的初始值比较好,后续训练轮数就会少很多
  2. 很有可能陷入局部最优值,不同的初始值可能导致陷入不同的局部最优值

所以很自然的想到,如果最开始能有比较好的参数,即可能导致未来训练次数少,也可能导致未来训练避免陷入局部最优解的问题。这就引入了一个概念,即预训练模型。

如果别人在某些和我们目标数据类似的大规模数据集上做过训练,我们可以用他的训练参数来初始化我们的模型,这样我们的模型就比较容易收敛。

为了帮助你们理解,这里提出几个自问自答的问题。

  1. 那为什么要选择类似任务的数据集预训练的模型参数呢?

因为任务差不多,他提取特征的能力才有用,如果任务相差太大,他的特征提取能力就没那么好。
所以本质预训练就是拿别人已经具备的通用特征提取能力来接着强化能力使之更加适应我们的数据集和任务。

  1. 为什么要求预训练模型是在大规模数据集上训练的,小规模不行么?
    因为提取的是通用特征,所以如果数据集数据少、尺寸小,就很难支撑复杂任务学习通用的数据特征。比如你是一个物理的博士,让你去做小学数学题,很快就能上手;但是你是一个小学数学速算高手,让你做物理博士的课题,就很困难。所以预训练模型一般就挺强的。

我们把用预训练模型的参数,然后接着在自己数据集上训练来调整该参数的过程叫做微调,这种思想叫做迁移学习。把预训练的过程叫做上游任务,把微调的过程叫做下游任务。

现在再来看下之前一直用的cifar10数据集,他是不是就很明显不适合作为预训练数据集?

  1. 规模过小:仅 10 万张图像,且尺寸小(32x32),无法支撑复杂模型学习通用视觉特征;
  2. 类别单一:仅 10 类(飞机、汽车等),泛化能力有限;

这里给大家介绍一个常常用来做预训练的数据集,ImageNet,ImageNet 1000 个类别,有 1.2 亿张图像,尺寸 224x224,数据集大小 1.4G,下载地址:http://www.image-net.org/。

二、 经典的预训练模型

2.1 CNN架构预训练模型

模型预训练数据集核心特点在CIFAR10上的适配要点
AlexNetImageNet首次引入ReLU/局部响应归一化,参数量6000万+需修改首层卷积核大小(原11x11→适配32x32)
VGG16ImageNet纯卷积堆叠,结构统一,参数量1.38亿冻结前10层卷积,仅微调全连接层
ResNet18ImageNet残差连接解决梯度消失,参数量1100万直接适配32x32输入,需调整池化层步长
MobileNetV2ImageNet深度可分离卷积,参数量350万+轻量级设计,适合计算资源有限的场景

2.2 Transformer类预训练模型

适用于较大尺图像(如224x224),在CIFAR10上需上采样图像尺寸调整Patch大小

模型预训练数据集核心特点在CIFAR10上的适配要点
ViT-BaseImageNet-21K纯Transformer架构,参数量8600万图像Resize至224x224,Patch大小设为4x4
Swin TransformerImageNet-22K分层窗口注意力,参数量8000万+需调整窗口大小适配小图像
DeiTImageNet结合CNN归纳偏置,参数量2200万轻量级Transformer,适合中小尺寸图像

2.3 自监督预训练模型

无需人工标注,通过 pretext task(如掩码图像重建)学习特征,适合数据稀缺场景。

模型预训练方式典型数据集在CIFAR10上的优势
MoCo v3对比学习ImageNet无需标签即可迁移,适合无标注数据
BEiT掩码图像建模ImageNet-22K特征语义丰富,微调时收敛更快

三、常见的分类预训练模型介绍

3.1 预训练模型的发展史

模型年份提出团队关键创新点层数参数量ImageNet Top-5错误率典型应用场景预训练权重可用性
LeNet-51998Yann LeCun等首个CNN架构,卷积层+池化层+全连接层,Sigmoid激活函数7~60KN/A手写数字识别(MNIST)无(历史模型)
AlexNet2012Alex Krizhevsky等ReLU激活函数、Dropout、数据增强、GPU训练860M15.3%大规模图像分类PyTorch/TensorFlow官方支持
VGGNet2014Oxford VGG团队统一3×3卷积核、多尺度特征提取、结构简洁16/19138M/144M7.3%/7.0%图像分类、目标检测基础骨干网络PyTorch/TensorFlow官方支持
GoogLeNet2014GoogleInception模块(多分支并行卷积)、1×1卷积降维、全局平均池化225M6.7%大规模图像分类PyTorch/TensorFlow官方支持
ResNet2015何恺明等残差连接(解决梯度消失)、Batch Normalization18/50/15211M/25M/60M3.57%/3.63%/3.58%图像/视频分类、检测、分割PyTorch/TensorFlow官方支持
DenseNet2017Gao Huang等密集连接(每层与后续所有层相连)、特征复用、参数效率高121/1698M/14M2.80%小数据集、医学图像处理PyTorch/TensorFlow官方支持
MobileNet2017Google深度可分离卷积(减少75%计算量)、轻量级设计284.2M7.4%移动端图像分类/检测PyTorch/TensorFlow官方支持
EfficientNet2019Google复合缩放(同时优化深度、宽度、分辨率)、NAS搜索最佳配置B0-B75.3M-66M2.6% (B7)高精度图像分类(资源受限场景)PyTorch/TensorFlow官方支持

上图的层数,代表该模型不同的版本resnet有resnet18、resnet50、resnet152;efficientnet有efficientnet-b0、efficientnet-b1、efficientnet-b2、efficientnet-b3、efficientnet-b4等

其中ImageNet Top - 5 准确率是图像分类任务里的一种评估指标 ,用于衡量模型在 ImageNet 数据集上的分类性能,模型对图像进行分类预测,输出所有类别(共 1000 类 )的概率,取概率排名前五的类别,只要这五个类别里包含人工标注的正确类别,就算预测正确。

模型架构演进关键点总结

  1. 深度突破:从LeNet的7层到ResNet152的152层,残差连接解决了深度网络的训练难题。 ----没上过我复试班cv部分的自行去了解下什么叫做残差连接,很重要!
  2. 计算效率:GoogLeNet(Inception)和MobileNet通过结构优化,在保持精度的同时大幅降低参数量。
  3. 特征复用:DenseNet的密集连接设计使模型能更好地利用浅层特征,适合小数据集。
  4. 自动化设计:EfficientNet使用神经架构搜索(NAS)自动寻找最优网络配置,开创了AutoML在CNN中的应用。

预训练模型使用建议

任务需求推荐模型理由
快速原型开发ResNet50/18结构平衡,预训练权重稳定,社区支持完善
移动端部署MobileNetV3参数量小,计算高效,专为移动设备优化
高精度分类(资源充足)EfficientNet-B7目前ImageNet准确率领先,适合GPU/TPU环境
小数据集或特征复用需求DenseNet密集连接设计减少过拟合,特征复用能力强
多尺度特征提取Inception-ResNet结合Inception多分支和ResNet残差连接,适合复杂场景

这些模型的预训练权重均可通过主流框架(如PyTorch的torchvision.models、Keras的applications模块)直接加载,便于快速迁移到新任务。
总结:CNN 架构发展脉络

  1. 早期探索(1990s-2010s):LeNet 验证 CNN 可行性,但受限于计算和数据。
  2. 深度学习复兴(2012-2015):AlexNet、VGGNet、GoogLeNet 通过加深网络和结构创新突破性能。
  3. 超深网络时代(2015 年后):ResNet 解决退化问题,开启残差连接范式,后续模型围绕效率(MobileNet)、特征复用(DenseNet)、多分支结构(Inception)等方向优化。

3.1 预训练模型的训练策略

那么什么模型会被选为预训练模型呢?比如一些调参后表现很好的cnn神经网络(固定的神经元个数+固定的层数等)。

所以调用预训练模型做微调,本质就是 用这些固定的结构+之前训练好的参数 接着训练

所以需要找到预训练的模型结构并且加载模型参数

相较于之前用自己定义的模型有以下几个注意点

  1. 需要调用预训练模型和加载权重
  2. 需要resize 图片让其可以适配模型
  3. 需要修改最后的全连接层以适应数据集

其中,训练过程中,为了不破坏最开始的特征提取器的参数,最开始往往先冻结住特征提取器的参数,然后训练全连接层,大约在5-10个epoch后解冻训练。

主要做特征提取的部分叫做backbone骨干网络;负责融合提取的特征的部分叫做Featue Pyramid Network(FPN);负责输出的预测部分的叫做Head。

首先复用下之前的代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 创建数据加载器(可调整batch_size)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 训练函数(支持学习率调度器)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()  # 设置为训练模式train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 5. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 6. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()
# 导入ResNet模型
from torchvision.models import resnet18# 定义ResNet18模型(支持预训练权重加载)
def create_resnet18(pretrained=True, num_classes=10):# 加载预训练模型(ImageNet权重)model = resnet18(pretrained=pretrained)# 修改最后一层全连接层,适配CIFAR-10的10分类任务in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)# 将模型转移到指定设备(CPU/GPU)model = model.to(device)return model
# 创建ResNet18模型(加载ImageNet预训练权重,不进行微调)
model = create_resnet18(pretrained=True, num_classes=10)
model.eval()  # 设置为推理模式# 测试单张图片(示例)
from torchvision import utils# 从测试数据集中获取一张图片
dataiter = iter(test_loader)
images, labels = dataiter.next()
images = images[:1].to(device)  # 取第1张图片# 前向传播
with torch.no_grad():outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 显示图片和预测结果
plt.imshow(utils.make_grid(images.cpu(), normalize=True).permute(1, 2, 0))
plt.title(f"预测类别: {predicted.item()}")
plt.axis('off')
plt.show()

在 CIFAR-10 数据集 中,类别标签是固定的 10 个,分别对应:

标签(数字)类别名称说明
0airplane飞机
1automobile汽车(含轿车、卡车等)
2bird鸟类
3cat
4deer鹿
5dog
6frog青蛙
7horse
8ship
9truck卡车(重型货车等)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义ResNet18模型
def create_resnet18(pretrained=True, num_classes=10):model = models.resnet18(pretrained=pretrained)# 修改最后一层全连接层in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return model.to(device)# 5. 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""# 冻结/解冻除fc层外的所有参数for name, param in model.named_parameters():if 'fc' not in name:param.requires_grad = not freeze# 打印冻结状态frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始冻结卷积层if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解冻控制:在指定轮次后解冻所有层if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解冻后调整优化器(可选)optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合model.train()  # 设置为训练模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函数:训练模型
def main():# 参数设置epochs = 40  # 总训练轮次freeze_epochs = 5  # 冻结卷积层的轮次learning_rate = 1e-3  # 初始学习率weight_decay = 1e-4  # 权重衰减# 创建ResNet18模型(加载预训练权重)model = create_resnet18(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 开始训练(前5轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()
使用设备: cuda
Files already downloaded and verified
已冻结模型卷积层参数 (11176512/11181642 参数)
Epoch 1/40 | Batch 100/782 | 单Batch损失: 2.2746
Epoch 1/40 | Batch 200/782 | 单Batch损失: 2.0283
Epoch 1/40 | Batch 300/782 | 单Batch损失: 1.7735
Epoch 1/40 | Batch 400/782 | 单Batch损失: 1.9917
Epoch 1/40 | Batch 500/782 | 单Batch损失: 1.6662
Epoch 1/40 | Batch 600/782 | 单Batch损失: 1.8309
Epoch 1/40 | Batch 700/782 | 单Batch损失: 1.7485
Epoch 1 完成 | 训练损失: 1.9633 | 训练准确率: 29.96% | 测试准确率: 32.28%
Epoch 2/40 | Batch 100/782 | 单Batch损失: 1.9622
Epoch 2/40 | Batch 200/782 | 单Batch损失: 1.9096
Epoch 2/40 | Batch 300/782 | 单Batch损失: 1.9117
Epoch 2/40 | Batch 400/782 | 单Batch损失: 1.6654
Epoch 2/40 | Batch 500/782 | 单Batch损失: 1.8214
Epoch 2/40 | Batch 600/782 | 单Batch损失: 2.0345
Epoch 2/40 | Batch 700/782 | 单Batch损失: 2.1190
Epoch 2 完成 | 训练损失: 1.8675 | 训练准确率: 33.81% | 测试准确率: 32.79%
Epoch 3/40 | Batch 100/782 | 单Batch损失: 1.7294
Epoch 3/40 | Batch 200/782 | 单Batch损失: 1.8884
Epoch 3/40 | Batch 300/782 | 单Batch损失: 2.0095
Epoch 3/40 | Batch 400/782 | 单Batch损失: 1.7234
Epoch 3/40 | Batch 500/782 | 单Batch损失: 1.8481
Epoch 3/40 | Batch 600/782 | 单Batch损失: 1.8441
...
Epoch 40/40 | Batch 500/782 | 单Batch损失: 0.3312
Epoch 40/40 | Batch 600/782 | 单Batch损失: 0.4169
Epoch 40/40 | Batch 700/782 | 单Batch损失: 0.1733
Epoch 40 完成 | 训练损失: 0.2713 | 训练准确率: 90.29% | 测试准确率: 86.30%训练完成!最终测试准确率: 86.30%

几个明显的现象

1. 解冻后几个epoch即可达到之前cnn训练20轮的效果,这是预训练的优势

2. 由于训练集用了 RandomCrop(随机裁剪)、RandomHorizontalFlip(随机水平翻转)、ColorJitter(颜色抖动)等数据增强操作,这会让训练时模型看到的图片有更多 “干扰” 或变形。比如一张汽车图片,训练时可能被裁剪成只显示局部、颜色也有变化,模型学习难度更高;而测试集是标准的、没增强的图片,模型预测相对轻松,就可能出现训练集准确率暂时低于测试集的情况,尤其在训练前期增强对模型影响更明显。随着训练推进,模型适应增强后会缓解。

3. 最后收敛后的效果超过非预训练模型的80%,大幅提升

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/918647.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/918647.shtml
英文地址,请注明出处:http://en.pswp.cn/news/918647.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java Stream API 中常用方法复习及项目实战示例

在最近的练手项目中,对于stream流的操作愈加频繁,我也越来越感觉stream流在处理数据是的干净利落,因此写博客用来记录最近常用的方法以便于未来的复习。map() 方法map()是一个中间操作(intermediate operation)&#x…

从零开始手搓一个GPT大语言模型:从理论到实践的完整指南(一)

现在人工智能飞速发展时代,LLM绝对可以算是人工智能领域得一颗明珠,也是现在许多AI项目落地得必不可少得一个模块,可以说,不管你之前得研究领域是AI得哪个方向,现在都需要会一些LLM基础,在这个系列&#xf…

Redis ubuntu下载Redis的C++客户端

1. 安装 redis-plus-plus C 操作 Redis 的库有很多,这里选择使用 redis-plus-plus,这个库的功能强大,使用简单。 Github 地址:GitHub - sewenew/redis-plus-plus: Redis client written in C 访问不了Github 地址的可以使用Ste…

nm命令和nm -D命令参数

出现这种差异的原因在于:动态库中的符号分为两种类型: 常规符号表(regular symbol table):通常用于静态链接和调试,默认不包含在动态库中(除非显式保留)。动态符号表(dyn…

Windows下cuda的安装和配置

今天开始做一个cuda教程。由于本人主要在windows下使用visual studio进行开发,因此这里讲一下windows下的cuda开发环境。 下载cuda_toolkit 从网站https://developer.nvidia.com/cuda-toolkit中下载,先选择Download Now,然后跳转到如下页面&#xff1a…

【代码随想录day 19】 力扣 450.删除二叉搜索树中的节点

视频讲解:https://www.bilibili.com/video/BV1tP41177us/?share_sourcecopy_web&vd_sourcea935eaede74a204ec74fd041b917810c 文档讲解:https://programmercarl.com/0450.%E5%88%A0%E9%99%A4%E4%BA%8C%E5%8F%89%E6%90%9C%E7%B4%A2%E6%A0%91%E4%B8%A…

智慧养老丨实用科普+避坑指南:科技如何让晚年生活更安全舒适?

随着老龄化社会的到来,智慧养老产品逐渐成为改善老年人生活质量的重要工具。从智能手表到便携洗浴机,科技正为老年人的健康、安全与生活便利提供创新解决方案。我们这次主要介绍四类典型智慧养老产品,结合真实体验给出选购建议,并…

系统垃圾清理批处理脚本 (BAT)

系统垃圾清理批处理脚本 (BAT) 以下是一个Windows系统垃圾清理的批处理脚本,它可以清理常见的系统临时文件、缓存和日志等: echo off title 系统垃圾清理工具 color 0a echo. echo 正在清理系统垃圾文件,请稍候... echo.:: 清理临时文件 echo…

Terraform的零基础学习教程

一、Terraform 是什么? Terraform 是由 HashiCorp 开发的开源工具,用于自动化管理云基础设施(如 AWS、Azure、GCP 等)。 核心特点: 基础设施即代码(IaC):用代码定义和管理资源。跨…

429. N 叉树的层序遍历(中等)题解

题目描述给定一个 N 叉树,返回其节点值的层序遍历。(即从左到右,逐层遍历)。树的序列化输入是用层序遍历,每组子节点都由 null 值分隔(参见示例)。示例 1:输入:root [1,…

Java 课程,每天解读一个简单Java之题目:输入一行字符,分别统计出其中英文字母、空格、数字和其它字符的个数。

package ytr250813;import java.io.IOException;public class CharacterCounter {public static void main(String[] args) throws IOException {// 初始化计数器变量int letterCount 0; // 英文字母计数器int spaceCount 0; // 空格计数器int digitCount 0; // 数字计数器i…

GitLab CI + Docker 自动构建前端项目并部署 — 完整流程文档

一、环境准备1. 服务器准备一台Linux服务器(CentOS/Ubuntu皆可),推荐至少4核8GB内存已安装 Docker(及 Docker 服务已启动)已安装 GitLab Runner2. 服务器上安装 Docker (如果没装)# CentOS9以下…

LCP 17. 速算机器人

目录 题目链接: 题目: 解题思路: 代码: 总结: 题目链接: LCP 17. 速算机器人 - 力扣(LeetCode) 题目: # LCP 17. 速算机器人 小扣在秋日市集发现了一款速算机器人。…

Spring cloud集成ElastictJob分布式定时任务完整攻略(含snakeyaml报错处理方法)

ElasticJob 是一款轻量级、可扩展的分布式定时任务解决方案,基于 Quartz 二次开发,支持任务分片、失效转移、任务追踪等功能,非常适合在 Spring Cloud 微服务场景中使用。我将带你完成 Spring Cloud 集成 ElasticJob 的全过程,并分…

了解 Linux 中的 /usr 目录以及 bin、sbin 和 lib 的演变

Linux 文件系统层次结构是一个复杂且引人入胜的体系,其根源深植于类 Unix 操作系统的历史之中。在这一结构的核心,/usr 目录是一个至关重要的组成部分,随着时间的推移,它经历了显著的演变。与此同时,/bin、/sbin、/lib…

高级IO(五种IO模型介绍)

文章目录一、IO为什么慢?一、阻塞IO二、非阻塞IO三、信号驱动IO四、IO多路复用五、异步IO一、IO为什么慢? IO操作往往都是和外设交互,比如键盘、鼠标、打印机、磁盘。而最常见的就是内存与磁盘的交互,要知道磁盘是机械设备&#…

第十二节:粒子系统:海量点渲染

第十二节:粒子系统:海量点渲染 引言 粒子系统是创造动态视觉效果的神器,从漫天繁星到熊熊火焰,从魔法特效到数据可视化,都离不开粒子技术。Three.js提供了强大的粒子渲染能力,可轻松处理百万级粒子。本文将…

LeetCode Day5 -- 二叉树

目录 1. 啥时候用二叉树? (1)典型问题 (2)核心思路 2. BFS、DFS、BST 2.1 广度优先搜索BFS (1)适用任务 (2)解决思路​​:使用队列逐层遍历 2.2 深度…

<c1:C1DateTimePicker的日期时间控件,控制日期可以修改,时间不能修改,另外控制开始时间的最大值比结束时间小一天

两个时间控件 <c1:C1DateTimePicker Width"170" EditMode"DateTime" CustomDateFormat"yyyy-MM-dd" CustomTimeFormat"HH:mm:ss" Style"{StaticResource ListSearch-DateTimePicker}" x:Name"dateTimePicker"…

文件完整性监控工具:架构和实现

文件完整性监控(FIM)作为一道关键的防御层&#xff0c;确保系统和网络中文件及文件夹的完整性与安全性。文件完整性监控工具通过监控关键文件的变更并检测未经授权的修改&#xff0c;提供关于潜在安全漏洞、恶意软件感染和内部威胁的早期警报。为了使文件完整性监控工具发挥实效…