对象分割任务的目标是找到图像中目标对象的边界。实际应用例如自动驾驶汽车和医学成像分析。这里将使用PyTorch开发一个深度学习模型来完成多对象分割任务。多对象分割的主要目标是自动勾勒出图像中多个目标对象的边界。

对象的边界通常由与图像大小相同的分割掩码定义,在分割掩码中属于目标对象的所有像素基于预定义的标记被标记为相同。

目录

创建数据集

创建数据加载器

创建模型

部署模型

定义损失函数和优化器

训练和验证模型


创建数据集

from torchvision.datasets import VOCSegmentation
from PIL import Image   
from torchvision.transforms.functional import to_tensor, to_pil_imageclass myVOCSegmentation(VOCSegmentation):def __getitem__(self, index):img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:augmented= self.transforms(image=np.array(img), mask=np.array(target))img = augmented['image']target = augmented['mask']                  target[target>20]=0img= to_tensor(img)            target= torch.from_numpy(target).type(torch.long)return img, targetfrom albumentations import (HorizontalFlip,Compose,Resize,Normalize)mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
h,w=520,520transform_train = Compose([ Resize(h,w),HorizontalFlip(p=0.5), Normalize(mean=mean,std=std)])transform_val = Compose( [ Resize(h,w),Normalize(mean=mean,std=std)])            path2data="./data/"    
train_ds=myVOCSegmentation(path2data, year='2012', image_set='train', download=False, transforms=transform_train) 
print(len(train_ds)) val_ds=myVOCSegmentation(path2data, year='2012', image_set='val', download=False, transforms=transform_val)
print(len(val_ds)) 
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline
np.random.seed(0)
num_classes=21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")def show_img_target(img, target):if torch.is_tensor(img):img=to_pil_image(img)target=target.numpy()for ll in range(num_classes):mask=(target==ll)img=mark_boundaries(np.array(img) , mask,outline_color=COLORS[ll],color=COLORS[ll])plt.imshow(img)def re_normalize (x, mean = mean, std= std):x_r= x.clone()for c, (mean_c, std_c) in enumerate(zip(mean, std)):x_r [c] *= std_cx_r [c] += mean_creturn x_r

 展示训练数据集示例图像

img, mask = train_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

展示验证数据集示例图像

img, mask = val_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

创建数据加载器

 通过torch.utils.data针对训练和验证集分别创建Dataloader,打印示例观察效果

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False) for img_b, mask_b in train_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breakfor img_b, mask_b in val_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)break

创建模型

创建并打印deeplab_resnet模型结构,使用预训练权重

from torchvision.models.segmentation import deeplabv3_resnet101
import torchmodel=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
print(model)

部署模型

在验证数据集的数据批次上部署模型观察效果 

from torch import nnmodel.eval()
with torch.no_grad():for xb, yb in val_dl:yb_pred = model(xb.to(device))yb_pred = yb_pred["out"].cpu()print(yb_pred.shape)    yb_pred = torch.argmax(yb_pred,axis=1)break
print(yb_pred.shape)plt.figure(figsize=(20,20))n=2
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

可见勾勒对象方面效果很好 

定义损失函数和优化器

from torch import nn
criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)def loss_batch(loss_func, output, target, opt=None):   loss = loss_func(output, target)if opt is not None:opt.zero_grad()loss.backward()opt.step()return loss.item(), Nonefrom torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)def get_lr(opt):for param_group in opt.param_groups:return param_group['lr']current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

训练和验证模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):running_loss=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.to(device)yb=yb.to(device)output=model(xb)["out"]loss_b, _ = loss_batch(loss_func, output, yb, opt)running_loss += loss_bif sanity_check is True:breakloss=running_loss/float(len_data)return loss, Noneimport copy
def train_val(model, params):num_epochs=params["num_epochs"]loss_func=params["loss_func"]opt=params["optimizer"]train_dl=params["train_dl"]val_dl=params["val_dl"]sanity_check=params["sanity_check"]lr_scheduler=params["lr_scheduler"]path2weights=params["path2weights"]loss_history={"train": [],"val": []}metric_history={"train": [],"val": []}    best_model_wts = copy.deepcopy(model.state_dict())best_loss=float('inf')    for epoch in range(num_epochs):current_lr=get_lr(opt)print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   model.train()train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)loss_history["train"].append(train_loss)metric_history["train"].append(train_metric)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)loss_history["val"].append(val_loss)metric_history["val"].append(val_metric)   if val_loss < best_loss:best_loss = val_lossbest_model_wts = copy.deepcopy(model.state_dict())torch.save(model.state_dict(), path2weights)print("Copied best model weights!")lr_scheduler.step(val_loss)if current_lr != get_lr(opt):print("Loading best model weights!")model.load_state_dict(best_model_wts) print("train loss: %.6f" %(train_loss))print("val loss: %.6f" %(val_loss))print("-"*10) model.load_state_dict(best_model_wts)return model, loss_history, metric_history        
import os
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)path2models= "./models/"
if not os.path.exists(path2models):os.mkdir(path2models)params_train={"num_epochs": 10,"optimizer": opt,"loss_func": criterion,"train_dl": train_dl,"val_dl": val_dl,"sanity_check": True,"lr_scheduler": lr_scheduler,"path2weights": path2models+"sanity_weights.pt",
}model, loss_hist, _ = train_val(model, params_train)

绘制了训练和验证损失曲线 

num_epochs=params_train["num_epochs"]plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/93687.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/93687.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/93687.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SSH 使用密钥登录服务器

用这种方法远程登陆服务器的时候无需手动输入密码 具体步骤 客户端通过 ssh-keygen 生成公钥和私钥 ssh-keygen -t rsa 生成的时候会有一系列问题&#xff0c;根据自己的需要选择就行。生成的结果为两个文件&#xff1a; 上传公钥至服务器&#xff0c;上述两个文件一般在客户…

MySQL 8.4 企业版启用TDE功能和表加密

一、系统环境操作系统&#xff1a;Ubuntu 24.04 数据库:8.4.4-commercial for Linux on x86_64 (MySQL Enterprise Server - Commercial)二、安装TDE组件前提&#xff1a;检查组件文件是否存在ls /usr/lib/mysql/plugin/component_keyring_encrypted_file.so1.配置全局清单文件…

【Altium designer】导出的原理图PDF乱码异常的解决方法

一、有些电源名字无法显示或器件丢失 解决办法 (1)首先AD18以及以上的新版本AD不存在该问题。 (2)其次AD17以及更旧版本的AD很可能遇到该问题,参考如下博客笔记进行操作即可: 大致的操作如下:DXP → Preferences → Schematic → Options里面“Render Text with GDI+”…

4.Ansible自动化之-部署文件到主机

4 - 部署文件到受管主机 实验环境 先通过以下命令搭建基础环境&#xff08;创建工作目录、配置 Ansible 环境和主机清单&#xff09;&#xff1a; # 在控制节点&#xff08;controller&#xff09;上创建web目录并进入&#xff0c;作为工作目录 [bqcontroller ~]$ mkdir web &a…

Vuex的使用

Vuex 超详细使用教程&#xff08;从入门到精通&#xff09;一、Vuex 是什么&#xff1f;Vuex 是专门为 Vue.js 设计的状态管理库&#xff0c;它采用集中式存储管理应用的所有组件的状态。简单来说&#xff0c;Vuex 就是一个"全局变量仓库"&#xff0c;所有组件都可以…

pytorch 数据预处理,加载,训练,可视化流程

流程定义自定义数据集类定义训练和验证的数据增强定义模型、损失函数和优化器训练循环&#xff0c;包括验证训练可视化整个流程模型评估高级功能扩展混合精度训练​分布式训练​{:width“50%” height“50%”} 定义自定义数据集类 # #1. 自定义数据集类 # class CustomImageD…

Prompt工程:OCR+LLM文档处理的精准制导系统

在PDF OCR与大模型结合的实际应用中&#xff0c;很多团队会发现一个现象&#xff1a;同样的OCR文本&#xff0c;不同的Prompt设计会产生截然不同的提取效果。有时候准确率能达到95%&#xff0c;有时候却只有60%。这背后的关键就在于Prompt工程的精细化程度。 &#x1f3af; 为什…

RecSys:粗排模型和精排特征体系

粗排 在推荐系统链路中&#xff0c;排序阶段至关重要&#xff0c;通常分为召回、粗排和精排三个环节。粗排作为精排前的预处理阶段&#xff0c;需要在效果和性能之间取得平衡。 双塔模型 后期融合&#xff1a;把用户、物品特征分别输入不同的神经网络&#xff0c;不对用户、…

spring声明式事务,finally 中return对事务回滚的影响

finally 块中使用 return 是一个常见的编程错误&#xff0c;它会&#xff1a; 跳过正常的事务提交流程。吞掉异常&#xff0c;使错误处理失效 导致不可预测的事务行为Java 中 finally 和 return 的执行机制&#xff1a;1. finally 块的基本特性 在 Java 中&#xff0c;finally …

WPF 打印报告图片大小的自适应(含完整示例与详解)

目标&#xff1a;在 FlowDocument 报告里&#xff0c;根据 1~6 张图片的数量&#xff0c; 自动选择 2 行 3 列 的最佳布局&#xff1b;在只有 1、2、4 张时保持“占满感”&#xff0c;打印清晰且不变形。规则一览&#xff1a;1 张 → 占满 23&#xff08;大图居中&#xff09;…

【AI大模型前沿】百度飞桨PaddleOCR 3.0开源发布,支持多语言、手写体识别,赋能智能文档处理

系列篇章&#x1f4a5; No.文章1【AI大模型前沿】深度剖析瑞智病理大模型 RuiPath&#xff1a;如何革新癌症病理诊断技术2【AI大模型前沿】清华大学 CLAMP-3&#xff1a;多模态技术引领音乐检索新潮流3【AI大模型前沿】浙大携手阿里推出HealthGPT&#xff1a;医学视觉语言大模…

迅为RK3588开发板Android12 制作使用系统签名

在 Android 源码 build/make/target/product/security/下存放着签名文件&#xff0c;如下所示&#xff1a;将北京迅为提供的 keytool 工具拷贝到 ubuntu 中&#xff0c;然后将 Android11 或 Android12 源码build/make/target/product/security/下的 platform.pk8 platform.x509…

Day08 Go语言学习

1.安装Go和Goland 2.新建demo项目实践语法并使用git实践版本控制操作 2.1 Goland配置 路径**&#xff1a;** GOPATH workspace GOROOT golang 文件夹&#xff1a; bin 编译后的可执行文件 pkg 编译后的包文件 src 源文件 遇到问题1&#xff1a;运行 ‘go build awesomeProject…

Linux-文件创建拷贝删除剪切

文章目录Linux文件相关命令ls通配符含义touch 创建文件命令示例cp 拷贝文件rm 删除文件mv剪切文件Linux文件相关命令 ls ls是英文单词list的简写&#xff0c;其功能为列出目录的内容&#xff0c;是用户最常用的命令之一&#xff0c;它类似于DOS下的dir命令。 Linux文件或者目…

RabbitMQ:交换机(Exchange)

目录一、概述二、Direct Exchange &#xff08;直连型交换机&#xff09;三、Fanout Exchange&#xff08;扇型交换机&#xff09;四、Topic Exchange&#xff08;主题交换机&#xff09;五、Header Exchange&#xff08;头交换机&#xff09;六、Default Exchange&#xff08;…

【实时Linux实战系列】基于实时Linux的物联网系统设计

随着物联网&#xff08;IoT&#xff09;技术的飞速发展&#xff0c;越来越多的设备被连接到互联网&#xff0c;形成了一个庞大而复杂的网络。这些设备从简单的传感器到复杂的工业控制系统&#xff0c;都在实时地产生和交换数据。实时Linux作为一种强大的操作系统&#xff0c;为…

第五天~提取Arxml中描述信息New_CanCluster--Expert

🔍 ARXML描述信息提取:挖掘汽车电子设计的"知识宝藏" 在AUTOSAR工程中,描述信息如同埋藏在ARXML文件中的金矿,而New_CanCluster--Expert正是打开这座宝藏的密钥。本文将带您深度探索ARXML描述信息的提取艺术,解锁汽车电子设计的核心知识资产! 💎 为什么描述…

开源 C++ QT Widget 开发(一)工程文件结构

文章的目的为了记录使用C 进行QT Widget 开发学习的经历。临时学习&#xff0c;完成app的开发。开发流程和要点有些记忆模糊&#xff0c;赶紧记录&#xff0c;防止忘记。 相关链接&#xff1a; 开源 C QT Widget 开发&#xff08;一&#xff09;工程文件结构-CSDN博客 开源 C…

手写C++ string类实现详解

类定义cppnamespace ym {class string {private:char* _str; // 字符串数据size_t _size; // 当前字符串长度size_t _capacity; // 当前分配的内存容量static const size_t npos -1; // 特殊值&#xff0c;表示最大可能位置public:// 构造函数和析构函数string(…

C++信息学奥赛一本通-第一部分-基础一-第3章-第2节

C信息学奥赛一本通-第一部分-基础一-第3章-第2节 2057 星期几 #include <iostream>using namespace std;int main() {int day; cin >> day;switch (day) {case 1:cout << "Monday";break;case 2:cout << "Tuesday";break;case 3:c…