流程

    • 定义自定义数据集类
    • 定义训练和验证的数据增强
    • 定义模型、损失函数和优化器
    • 训练循环,包括验证
    • 训练可视化
    • 整个流程
    • 模型评估
    • 高级功能扩展
      • 混合精度训练​
      • 分布式训练​

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传{:width=“50%” height=“50%”}

定义自定义数据集类

#======================
#1. 自定义数据集类
#======================
class CustomImageDataset(Dataset):def __init__(self, root_dir, transform=None):"""自定义数据集初始化:param root_dir: 数据集根目录:param transform: 数据增强和预处理"""self.root_dir = root_dirself.transform = transformself.classes = sorted(os.listdir(root_dir))self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}# 收集所有图像路径和标签self.image_paths = []self.labels = []for cls_name in self.classes:cls_dir = os.path.join(root_dir, cls_name)for img_name in os.listdir(cls_dir):if img_name.lower().endswith(('.jpg', '.png', '.jpeg')):self.image_paths.append(os.path.join(cls_dir, img_name))self.labels.append(self.class_to_idx[cls_name])def __len__(self):return len(self.image_paths)def __getitem__(self, idx):# 加载图像img_path = self.image_paths[idx]try:image = Image.open(img_path).convert('RGB')except Exception as e:print(f"Error loading image {img_path}: {e}")# 返回空白图像作为占位符image = Image.new('RGB', (224, 224), (0, 0, 0))# 应用数据增强和预处理if self.transform:image = self.transform(image)# 获取标签label = self.labels[idx]return image, label

定义训练和验证的数据增强

#======================
#2. 数据增强与预处理
#======================
def get_transforms():"""返回训练和验证的数据增强管道"""# 训练集增强(更丰富)train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 验证集预处理(无随机增强)val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])return train_transform, val_transform

定义模型、损失函数和优化器

#======================
#3. 模型定义
#======================
def create_model(num_classes):"""创建模型(使用预训练ResNet18)"""model = resnet18(pretrained=True)num_ftrs = model.fc.in_featuresmodel.fc = nn.Linear(num_ftrs, num_classes)return model

训练循环,包括验证

#======================
#4. 训练函数
#======================
def train_model(model, dataloaders, criterion, optimizer, scheduler, device, num_epochs=25, checkpoint_path='checkpoint.pth', resume=False):"""训练模型并支持中断恢复:param resume: 是否从检查点恢复训练"""# 训练历史记录history = {'train_loss': [], 'val_loss': [],'train_acc': [], 'val_acc': [],'epoch': 0, 'best_acc': 0.0}# 从检查点恢复start_epoch = 0if resume and os.path.exists(checkpoint_path):print(f"Loading checkpoint from {checkpoint_path}")checkpoint = torch.load(checkpoint_path)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])scheduler.load_state_dict(checkpoint['scheduler_state_dict'])history = checkpoint['history']start_epoch = history['epoch'] + 1print(f"Resuming training from epoch {start_epoch}")# 训练循环for epoch in range(start_epoch, num_epochs):print(f'Epoch {epoch+1}/{num_epochs}')print('-' * 10)# 更新历史记录history['epoch'] = epoch# 每个epoch都有训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train()  # 设置训练模式else:model.eval()   # 设置评估模式running_loss = 0.0running_corrects = 0# 迭代数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 梯度清零optimizer.zero_grad()# 前向传播with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 训练阶段反向传播和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)# 记录历史history[f'{phase}_loss'].append(epoch_loss)history[f'{phase}_acc'].append(epoch_acc.item())print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 保存最佳模型if phase == 'val' and epoch_acc > history['best_acc']:history['best_acc'] = epoch_acc.item()torch.save(model.state_dict(), 'best_model.pth')print(f"New best model saved with accuracy: {epoch_acc:.4f}")# 保存检查点(每个epoch结束后)checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'scheduler_state_dict': scheduler.state_dict(),'history': history}torch.save(checkpoint, checkpoint_path)print(f"Checkpoint saved at epoch {epoch+1}")print()# 保存最终模型torch.save(model.state_dict(), 'final_model.pth')print('Training finished!')return model, history

训练可视化

#======================
#5. 可视化训练历史
#======================
def plot_history(history):plt.figure(figsize=(12, 4))# 损失曲线plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Train Loss')plt.plot(history['val_loss'], label='Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.title('Training and Validation Loss')# 准确率曲线plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Train Accuracy')plt.plot(history['val_acc'], label='Validation Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.title('Training and Validation Accuracy')plt.tight_layout()plt.savefig('training_history.png')plt.show()

整个流程

#======================
#6. 主函数
#======================
def main():# 设置随机种子(确保可复现性)torch.manual_seed(42)np.random.seed(42)# 检查设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 创建数据增强管道train_transform, val_transform = get_transforms()# 创建数据集train_dataset = CustomImageDataset(root_dir='path/to/your/train_data',  # 替换为你的训练数据路径transform=train_transform)val_dataset = CustomImageDataset(root_dir='path/to/your/val_data',    # 替换为你的验证数据路径transform=val_transform)# 创建数据加载器train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True,num_workers=4,pin_memory=True)val_loader = DataLoader(val_dataset,batch_size=32,shuffle=False,num_workers=4,pin_memory=True)dataloaders = {'train': train_loader, 'val': val_loader}# 创建模型num_classes = len(train_dataset.classes)model = create_model(num_classes)model = model.to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 检查是否要恢复训练resume_training = Falsecheckpoint_path = 'checkpoint.pth'# 检查是否存在检查点文件if os.path.exists(checkpoint_path):print("Checkpoint file found. Do you want to resume training? (y/n)")response = input().lower()if response == 'y':resume_training = True# 开始训练start_time = time.time()model, history = train_model(model=model,dataloaders=dataloaders,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,num_epochs=25,checkpoint_path=checkpoint_path,resume=resume_training)end_time = time.time()# 保存训练历史with open('training_history.json', 'w') as f:json.dump(history, f, indent=4)# 打印训练时间training_time = end_time - start_timeprint(f"Total training time: {training_time//3600}h {(training_time%3600)//60}m {training_time%60:.2f}s")# 可视化训练历史plot_history(history)if __name__ == "__main__":main()

模型评估

#======================
#模型评估
#======================
def evaluate_model(model, dataloader, device):model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in dataloader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f'Test Accuracy: {accuracy:.2f}%')return accuracy
test_dataset = CustomImageDataset('path/to/test_data', transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
evaluate_model(model, test_loader, device)

高级功能扩展

混合精度训练​

from torch.cuda.amp import autocast, GradScaler
#在训练函数中添加
scaler = GradScaler()
#修改训练循环
with autocast():outputs = model(inputs)loss = criterion(outputs, labels)scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

分布式训练​

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
#初始化分布式环境
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
#包装模型
model = DDP(model.to(local_rank), device_ids=[local_rank])
#修改数据加载器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(..., sampler=train_sampler)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/93681.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/93681.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/93681.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Prompt工程:OCR+LLM文档处理的精准制导系统

在PDF OCR与大模型结合的实际应用中,很多团队会发现一个现象:同样的OCR文本,不同的Prompt设计会产生截然不同的提取效果。有时候准确率能达到95%,有时候却只有60%。这背后的关键就在于Prompt工程的精细化程度。 🎯 为什…

RecSys:粗排模型和精排特征体系

粗排 在推荐系统链路中,排序阶段至关重要,通常分为召回、粗排和精排三个环节。粗排作为精排前的预处理阶段,需要在效果和性能之间取得平衡。 双塔模型 后期融合:把用户、物品特征分别输入不同的神经网络,不对用户、…

spring声明式事务,finally 中return对事务回滚的影响

finally 块中使用 return 是一个常见的编程错误,它会: 跳过正常的事务提交流程。吞掉异常,使错误处理失效 导致不可预测的事务行为Java 中 finally 和 return 的执行机制:1. finally 块的基本特性 在 Java 中,finally …

WPF 打印报告图片大小的自适应(含完整示例与详解)

目标:在 FlowDocument 报告里,根据 1~6 张图片的数量, 自动选择 2 行 3 列 的最佳布局;在只有 1、2、4 张时保持“占满感”,打印清晰且不变形。规则一览:1 张 → 占满 23(大图居中)…

【AI大模型前沿】百度飞桨PaddleOCR 3.0开源发布,支持多语言、手写体识别,赋能智能文档处理

系列篇章💥 No.文章1【AI大模型前沿】深度剖析瑞智病理大模型 RuiPath:如何革新癌症病理诊断技术2【AI大模型前沿】清华大学 CLAMP-3:多模态技术引领音乐检索新潮流3【AI大模型前沿】浙大携手阿里推出HealthGPT:医学视觉语言大模…

迅为RK3588开发板Android12 制作使用系统签名

在 Android 源码 build/make/target/product/security/下存放着签名文件,如下所示:将北京迅为提供的 keytool 工具拷贝到 ubuntu 中,然后将 Android11 或 Android12 源码build/make/target/product/security/下的 platform.pk8 platform.x509…

Day08 Go语言学习

1.安装Go和Goland 2.新建demo项目实践语法并使用git实践版本控制操作 2.1 Goland配置 路径**:** GOPATH workspace GOROOT golang 文件夹: bin 编译后的可执行文件 pkg 编译后的包文件 src 源文件 遇到问题1:运行 ‘go build awesomeProject…

Linux-文件创建拷贝删除剪切

文章目录Linux文件相关命令ls通配符含义touch 创建文件命令示例cp 拷贝文件rm 删除文件mv剪切文件Linux文件相关命令 ls ls是英文单词list的简写,其功能为列出目录的内容,是用户最常用的命令之一,它类似于DOS下的dir命令。 Linux文件或者目…

RabbitMQ:交换机(Exchange)

目录一、概述二、Direct Exchange (直连型交换机)三、Fanout Exchange(扇型交换机)四、Topic Exchange(主题交换机)五、Header Exchange(头交换机)六、Default Exchange(…

【实时Linux实战系列】基于实时Linux的物联网系统设计

随着物联网(IoT)技术的飞速发展,越来越多的设备被连接到互联网,形成了一个庞大而复杂的网络。这些设备从简单的传感器到复杂的工业控制系统,都在实时地产生和交换数据。实时Linux作为一种强大的操作系统,为…

第五天~提取Arxml中描述信息New_CanCluster--Expert

🔍 ARXML描述信息提取:挖掘汽车电子设计的"知识宝藏" 在AUTOSAR工程中,描述信息如同埋藏在ARXML文件中的金矿,而New_CanCluster--Expert正是打开这座宝藏的密钥。本文将带您深度探索ARXML描述信息的提取艺术,解锁汽车电子设计的核心知识资产! 💎 为什么描述…

开源 C++ QT Widget 开发(一)工程文件结构

文章的目的为了记录使用C 进行QT Widget 开发学习的经历。临时学习,完成app的开发。开发流程和要点有些记忆模糊,赶紧记录,防止忘记。 相关链接: 开源 C QT Widget 开发(一)工程文件结构-CSDN博客 开源 C…

手写C++ string类实现详解

类定义cppnamespace ym {class string {private:char* _str; // 字符串数据size_t _size; // 当前字符串长度size_t _capacity; // 当前分配的内存容量static const size_t npos -1; // 特殊值,表示最大可能位置public:// 构造函数和析构函数string(…

C++信息学奥赛一本通-第一部分-基础一-第3章-第2节

C信息学奥赛一本通-第一部分-基础一-第3章-第2节 2057 星期几 #include <iostream>using namespace std;int main() {int day; cin >> day;switch (day) {case 1:cout << "Monday";break;case 2:cout << "Tuesday";break;case 3:c…

【leetcode 3】最长连续序列 (Longest Consecutive Sequence) - 解题思路 + Golang实现

最长连续序列 (Longest Consecutive Sequence) - LeetCode 题解 题目描述 给定一个未排序的整数数组 nums&#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。要求设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1&#x…

矿物分类系统开发笔记(一):数据预处理

目录 一、数据基础与预处理目标 二、具体预处理步骤及代码解析 2.1 数据加载与初步清洗 2.2 标签编码 2.3 缺失值处理 &#xff08;1&#xff09;删除含缺失值的样本 &#xff08;2&#xff09;按类别均值填充 &#xff08;3&#xff09;按类别中位数填充 &#xff08;…

《UE5_C++多人TPS完整教程》学习笔记43 ——《P44 奔跑混合空间(Running Blending Space)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P44 奔跑混合空间&#xff08;Running Blending Space&#xff09;》 的学习笔记&#xff0c;该系列教学视频为计算机工程师、程序员、游戏开发者、作家&#xff08;Engineer, Programmer, Game Developer, Author&…

TensorRT-LLM.V1.1.0rc1:Dockerfile.multi文件解读

一、TensorRT-LLM有三种安装方式&#xff0c;从简单到难 1.NGC上的预构建发布容器进行部署,见《tensorrt-llm0.20.0离线部署DeepSeek-R1-Distill-Qwen-32B》。 2.通过pip进行部署。 3.从源头构建再部署&#xff0c;《TensorRT-LLM.V1.1.0rc0:在无 GitHub 访问权限的服务器上编…

UniApp 实现pdf上传和预览

一、上传1、html<template><button click"takeFile">pdf上传</button> </template>2、JStakeFile() {// #ifdef H5// H5端使用input方式选择文件const input document.createElement(input);input.type file;input.accept .pdf;input.onc…

《用Proxy解构前端壁垒:跨框架状态共享库的从零到优之路》

一个项目中同时出现React的函数式组件、Vue的模板语法、Angular的依赖注入时,数据在不同框架体系间的流转便成了开发者不得不面对的难题—状态管理,这个本就复杂的命题,在跨框架场景下更显棘手。而Proxy,作为JavaScript语言赋予开发者的“元编程利器”,正为打破这道壁垒提…