目录

前言

1.检查GPU

2.查看数据

3.划分数据集

4.创建模型与编译训练

​​​​5.编译及训练模型 

6.结果可视化

7.模型预测 

8.总结:

前言

🍨 本文为🔗365天深度学习训练营中的学习记录博客
🍖 原作者:K同学啊

1.检查GPU

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

2.查看数据

df = pd.read_csv("DATA/alzheimers_disease_data.csv")
# 删除第一列和最后一列
df = df.iloc[:, 1:-1]
df

3.划分数据集

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1]
y = df.iloc[:,-1]# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X  = sc.fit_transform(X)X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)X_train.shape, y_train.shapefrom torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64, shuffle=False)test_dl  = DataLoader(TensorDataset(X_test, y_test),batch_size=64, shuffle=False)

4.创建模型与编译训练

class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200, num_layers=1, batch_first=True)self.fc0   = nn.Linear(200, 50)self.fc1   = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x) out    = self.fc0(out) out    = self.fc1(out) return out   model = model_rnn().to(device)
model

​​​​5.编译及训练模型 

# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0 # 初始化训练损失和正确率for X, y in dataloader: # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_lossdef test (dataloader, model, loss_fn):size = len(dataloader.dataset) # 测试集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_lossloss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))
print("="*20, 'Done', "="*20)

6.结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")# 修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)# 显示图
plt.tight_layout()  # 调整布局防止重叠
plt.show()

 

7.模型预测 

test_X = X_test[0].reshape(1, -1) # X_test[0]即我们的输入数据pred = model(test_X.to(device)).argmax(1).item()
print("模型预测结果为:",pred)
print("=="*20)
print("0:未患病")
print("1:已患病")

 

8.总结:

代码展示了如何使用PyTorch框架进行阿尔茨海默病数据集的分类任务。以下是该代码的主要步骤和功能总结:

检查GPU:首先,代码检查是否有可用的GPU,并设置相应的设备(cuda或cpu)。

查看数据:通过Pandas库加载数据集,并删除第一列和最后一列,这可能是为了去除非特征信息(如ID)或冗余信息。

划分数据集:对数据进行预处理,包括标准化以及将数据划分为训练集和测试集。接着,使用PyTorch的DataLoader创建数据加载器以便于后续模型训练时的数据批次处理。

创建模型与编译训练:定义了一个基于RNN的神经网络模型model_rnn,包含RNN层和两个全连接层。模型被移动到之前设定的设备(GPU或CPU)上。

编译及训练模型:定义了训练和测试函数,分别用于执行模型的训练过程和评估过程。采用交叉熵损失作为损失函数,Adam优化器作为优化算法。经过30个epoch的训练后,记录并打印出每个epoch的训练和测试准确率及损失值。

结果可视化:使用Matplotlib绘制训练和测试的准确率与损失的变化曲线图,直观地展示模型的学习效果。同时,还生成了混淆矩阵以进一步分析模型性能。

模型预测:最后,选取了一条测试数据进行模型预测,输出预测结果,并解释了预测结果的意义(是否患病)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/91773.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/91773.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/91773.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

equals和hashcode方法重写

在 Java 中,当你需要基于对象的内容而非引用地址来判断两个对象是否相等时,就需要重写equals和hashCode方法。以下是具体场景和实现原则:一、为什么需要同时重写这两个方法?equals方法:默认比较对象的内存地址&#xf…

Excel批量生成SQL语句 Excel批量生成SQL脚本 Excel拼接sql

Excel批量生成SQL语句 Excel批量生成SQL脚本 Excel拼接sql一、情境描述在Excel中有标准的格式化数据,如何快速导入到数据库中呢?有些工具支持Excel导入的,则可以快速导入数据---例如Navicat;如果不支持呢,如果将Excel表…

金和OA C6 DelTemp.aspx 存在XML实体注入漏洞(CVE-2025-7523)

免责声明 本文档所述漏洞详情及复现方法仅限用于合法授权的安全研究和学术教育用途。任何个人或组织不得利用本文内容从事未经许可的渗透测试、网络攻击或其他违法行为。 前言:我们建立了一个更多,更全的知识库。每日追踪最新的安全漏洞,追中25HW情报。 更多详情: http…

Android性能优化之启动优化

一、启动性能瓶颈深度分析 1. 冷启动阶段耗时分布阶段耗时占比关键阻塞点进程创建15%fork进程 加载ZygoteApplication初始化40%ContentProvider/库初始化Activity创建30%布局inflate 视图渲染首帧绘制15%VSync信号等待 GPU渲染2. 高频性能问题 初始化风暴:多个库…

中国优秀开源软件及企业调研报告

中国优秀开源软件及企业调研报告 引言 当前中国开源生态呈现蓬勃发展态势,技术创新领域尤为活跃,其中人工智能大模型成为开源动作的核心聚焦方向。2025年上半年,国内AI领域开源生态迎来密集爆发,头部科技企业相继推出重要开源举…

C++语法 匿名对象 与 命名对象 的详细区分

目录一、匿名对象的本质定义二、匿名对象的调用逻辑:即生即用的设计三、与命名对象的核心差异四、匿名对象的典型应用场景五、匿名对象的潜在风险与规避六、总结:匿名对象的价值定位在 C 类与对象的知识体系中,匿名对象是一种容易被咱们忽略&…

【Fedora 42】Linux内核升级后,鼠标滚轮失灵,libinput的锅?

解决: 最近在玩Fedora 42,升级了一次给俺鼠标滚轮干失灵了。原因可能是 libinput 升级后与Fedora升级后的某些配置有冲突?(搞不懂) sudo dnf downgrade libinput降级 libinput (1.28.901-1.fc42 -> 1.28.0-1.fc42) …

虚拟机centos服务器安装

创建虚拟机选择镜像启动 移除旧的repo文件: sudo rm -f /etc/yum.repos.d/CentOS-Base.repo下载阿里云的repo文件: 对于CentOS 7: sudo wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repo清除缓存并生…

【js(1)一文解决】var let const

var let const!在 ES6 之前,JavaScript 只有两种作用域: 全局变量 与 函数内的局部变量一、var1. 函数级作用域,有变量提升二、let(ES6新增)1. 块级作用域,不会影响外部作用域2.let 关键字在不同…

论螺旋矩阵

螺旋矩阵题型总结。我刷了几道螺旋矩阵相关的题目,这里我们介绍一下一些常见的解法。 螺旋矩阵 方形矩阵 当我们遇到n*n的方形矩阵时,可以用一种特殊的解法来遍历实现,以下面这道题为例: 59. 螺旋矩阵 II 我们可以定义几个变…

数学金融与金融工程:学科差异与选择指南

在金融领域的学习中,数学金融与金融工程常被混淆。两者虽同属 “金融 量化” 交叉方向,但在研究侧重、培养路径上有显著区别。结合学科特点与行业实践,帮大家理清两者的核心差异,以便更精准地选择方向。一、核心差异:…

包管理工具npm cnpm yarn的使用

包管理工具 1. 什么是包管理工具? 包管理工具是用于管理和安装 Node.js 项目依赖的工具。它们提供了一种结构化的方式来管理项目的依赖关系,使得项目的依赖管理变得更加便捷和可靠。 2. 常见的包管理工具有哪些? npm(Node Package Manager):是 Node.js 的默认包管理工…

网络基础13--链路聚合技术

一、链路聚合概述定义将多条物理链路捆绑为一条逻辑链路,提升带宽与可靠性。2. 应用场景交换机/路由器/服务器之间的互联,支持二层(数据链路层)和三层(网络层)聚合。二、核心作用增加带宽聚合链路的总带宽 …

一文讲清楚React性能优化

文章目录一文讲清楚React性能优化1. React性能优化概述2. React性能优化2.1 render优化2.2 较少使用内联函数2.3 使用React Fragments避免额外标记2.4 使用Immutable上代码2.5 组件懒加载2.6 服务端渲染2.7 其他优化手段一文讲清楚React性能优化 1. React性能优化概述 React通…

3.0 - 指针-序列化

一、关于Serialize的使用 可以使用该指令临时将用户程序的多个结构化数据项保存到缓冲区中(最好位于全局数据块中)。用于保存转换后数据的存储区的数据类型必需为 ARRAY of BYTE 或 ARRAY of CHAR 相当于把一个struct或其他自定义类型变成一个字节数组。 比如我有好几个结构体…

【论文精读】基于共识的分布式量子分解算法用于考虑最优传输线切换的安全约束机组组合

本次分析的论文《Consensus‐Based Distributed Quantum Decomposition Algorithm for Security‐Constrained Unit Commitment Considering Optimal Transmission Switching》于2025年6月25日在《Advanced Quantum Technologies》期刊上公开发表。本文提出了一个新的基于共识的…

MyBatis-Flex代码生成

引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId> </dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok<…

知网论文批量下载pdf格式论文,油猴脚本

任务描述 今天收到一个任务&#xff0c;在知网上&#xff0c;把一位专家所有的论文全都下载下来&#xff0c;要保存为PDF格式。 知网不支持批量导出PDF格式论文。一个一个下载PDF&#xff0c;太繁琐了。 解决方案&#xff1a;找到一个油猴脚本&#xff0c;这个脚本可以从知网…

低代码平台:驱动项目管理敏捷开发新范式

随着企业数字化转型加速&#xff0c;项目管理系统已从单一任务跟踪工具到集成流程自动化、资源调度、跨团队协作与风险监控的综合平台&#xff0c;项目管理系统的功能复杂度持续提升。然而&#xff0c;根据Gartner 2024年研究报告显示&#xff0c;约60%的项目管理系统因未能有效…

图机器学习(11)——链接预测

图机器学习&#xff08;11&#xff09;——链接预测0. 链接预测1. 基于相似性的方法1.1 基于指标的方法1.2 基于社区的方法2. 基于嵌入的方法0. 链接预测 链接预测 (link prediction)&#xff0c;也称为图补全&#xff0c;是处理图时常见的问题。具体而言&#xff0c;给定一个…