MLP神经网络的训练

知识点回顾:

  1. PyTorch和cuda的安装
  2. 查看显卡信息的命令行命令(cmd中使用)
  3. cuda的检查
  4. 简单神经网络的流程
    1. 数据预处理(归一化、转换成张量)
    2. 模型的定义
      1. 继承nn.Module类
      2. 定义每一个层
      3. 定义前向传播流程
    3. 定义损失函数和优化器
    4. 定义训练流程
    5. 可视化loss过程
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import matplotlib.pyplot as plt# 检查CUDA可用性并输出相关信息
if torch.cuda.is_available():print("CUDA可用!")device_count = torch.cuda.device_count()print(f"可用的CUDA设备数量:{device_count}")current_device = torch.cuda.current_device()print(f"当前使用的CUDA设备索引:{current_device}")device_name = torch.cuda.get_device_name(current_device)print(f"当前CUDA设备的名称:{device_name}")cuda_version = torch.version.cudaprint(f"CUDA版本:{cuda_version}")
else:print("CUDA不可用。")# 加载并准备Iris数据集
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)# 数据标准化处理
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 转换为PyTorch张量
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)# 定义多层感知机模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 初始化模型、损失函数和优化器
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 添加了学习率参数# 训练模型
num_epochs = 20000
losses = []for epoch in range(num_epochs):outputs = model.forward(X_train)loss = criterion(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 可视化训练损失
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

输出结果:

(120, 4)
(120,)  
(30, 4) 
(30,)   
Epoch [100/20000], Loss: 1.0980
Epoch [200/20000], Loss: 1.0692
Epoch [300/20000], Loss: 1.0398
Epoch [400/20000], Loss: 1.0060
Epoch [500/20000], Loss: 0.9665
Epoch [600/20000], Loss: 0.9214
Epoch [700/20000], Loss: 0.8731
Epoch [800/20000], Loss: 0.8241
Epoch [900/20000], Loss: 0.7777
Epoch [1000/20000], Loss: 0.7360
Epoch [1100/20000], Loss: 0.6991
Epoch [1200/20000], Loss: 0.6666
Epoch [1300/20000], Loss: 0.6380
Epoch [1400/20000], Loss: 0.6127
Epoch [1500/20000], Loss: 0.5903
Epoch [1600/20000], Loss: 0.5703
Epoch [1700/20000], Loss: 0.5523
Epoch [1800/20000], Loss: 0.5360
Epoch [1900/20000], Loss: 0.5211
Epoch [2000/20000], Loss: 0.5074
Epoch [2100/20000], Loss: 0.4948
Epoch [2200/20000], Loss: 0.4830
Epoch [2300/20000], Loss: 0.4720
Epoch [2400/20000], Loss: 0.4616
Epoch [2500/20000], Loss: 0.4518
Epoch [2600/20000], Loss: 0.4425
Epoch [2700/20000], Loss: 0.4336
Epoch [2800/20000], Loss: 0.4251
Epoch [2900/20000], Loss: 0.4168
Epoch [3000/20000], Loss: 0.4088
Epoch [3100/20000], Loss: 0.4010
Epoch [3200/20000], Loss: 0.3935
Epoch [3300/20000], Loss: 0.3860
Epoch [3400/20000], Loss: 0.3787
Epoch [3500/20000], Loss: 0.3716
Epoch [3600/20000], Loss: 0.3645
Epoch [3700/20000], Loss: 0.3576
Epoch [3800/20000], Loss: 0.3508
Epoch [3900/20000], Loss: 0.3440
Epoch [4000/20000], Loss: 0.3374
Epoch [4100/20000], Loss: 0.3308
Epoch [4200/20000], Loss: 0.3243
Epoch [4300/20000], Loss: 0.3179
Epoch [4400/20000], Loss: 0.3115
Epoch [4500/20000], Loss: 0.3053
Epoch [4600/20000], Loss: 0.2991
Epoch [4700/20000], Loss: 0.2930
Epoch [4800/20000], Loss: 0.2870
Epoch [4900/20000], Loss: 0.2810
Epoch [5000/20000], Loss: 0.2752
Epoch [5100/20000], Loss: 0.2694
Epoch [5200/20000], Loss: 0.2638
Epoch [5300/20000], Loss: 0.2583
Epoch [5400/20000], Loss: 0.2530
Epoch [5500/20000], Loss: 0.2478
Epoch [5600/20000], Loss: 0.2428
Epoch [5700/20000], Loss: 0.2378
Epoch [5800/20000], Loss: 0.2330
Epoch [5900/20000], Loss: 0.2284
Epoch [6000/20000], Loss: 0.2238
Epoch [6100/20000], Loss: 0.2193
Epoch [6200/20000], Loss: 0.2150
Epoch [6300/20000], Loss: 0.2108
Epoch [6400/20000], Loss: 0.2067
Epoch [6500/20000], Loss: 0.2027
Epoch [6600/20000], Loss: 0.1989
Epoch [6700/20000], Loss: 0.1951
Epoch [6800/20000], Loss: 0.1914
Epoch [6900/20000], Loss: 0.1878
Epoch [7000/20000], Loss: 0.1844
Epoch [7100/20000], Loss: 0.1810
Epoch [7200/20000], Loss: 0.1778
Epoch [7300/20000], Loss: 0.1746
Epoch [7400/20000], Loss: 0.1716
Epoch [7500/20000], Loss: 0.1686
Epoch [7600/20000], Loss: 0.1658
Epoch [7700/20000], Loss: 0.1630
Epoch [7800/20000], Loss: 0.1604
Epoch [7900/20000], Loss: 0.1578
Epoch [8000/20000], Loss: 0.1553
Epoch [8100/20000], Loss: 0.1528
Epoch [8200/20000], Loss: 0.1505
Epoch [8300/20000], Loss: 0.1483
Epoch [8400/20000], Loss: 0.1461
Epoch [8500/20000], Loss: 0.1440
Epoch [8600/20000], Loss: 0.1420
Epoch [8700/20000], Loss: 0.1400
Epoch [8800/20000], Loss: 0.1381
Epoch [8900/20000], Loss: 0.1363
Epoch [9000/20000], Loss: 0.1345
Epoch [9100/20000], Loss: 0.1328
Epoch [9200/20000], Loss: 0.1312
Epoch [9300/20000], Loss: 0.1295
Epoch [9400/20000], Loss: 0.1280
Epoch [9500/20000], Loss: 0.1265
Epoch [9600/20000], Loss: 0.1250
Epoch [9700/20000], Loss: 0.1236
Epoch [9800/20000], Loss: 0.1222
Epoch [9900/20000], Loss: 0.1208
Epoch [10000/20000], Loss: 0.1195
Epoch [10100/20000], Loss: 0.1183
Epoch [10200/20000], Loss: 0.1170
Epoch [10300/20000], Loss: 0.1159
Epoch [10400/20000], Loss: 0.1147
Epoch [10500/20000], Loss: 0.1136
Epoch [10600/20000], Loss: 0.1125
Epoch [10700/20000], Loss: 0.1114
Epoch [10800/20000], Loss: 0.1104
Epoch [10900/20000], Loss: 0.1094
Epoch [11000/20000], Loss: 0.1084
Epoch [11100/20000], Loss: 0.1075
Epoch [11200/20000], Loss: 0.1065
Epoch [11300/20000], Loss: 0.1056
Epoch [11400/20000], Loss: 0.1048
Epoch [11500/20000], Loss: 0.1039
Epoch [11600/20000], Loss: 0.1031
Epoch [11700/20000], Loss: 0.1023
Epoch [11800/20000], Loss: 0.1015
Epoch [11900/20000], Loss: 0.1007
Epoch [12000/20000], Loss: 0.0999
Epoch [12100/20000], Loss: 0.0992
Epoch [12200/20000], Loss: 0.0985
Epoch [12300/20000], Loss: 0.0978
Epoch [12400/20000], Loss: 0.0971
Epoch [12500/20000], Loss: 0.0964
Epoch [12600/20000], Loss: 0.0958
Epoch [12700/20000], Loss: 0.0951
Epoch [12800/20000], Loss: 0.0945
Epoch [12900/20000], Loss: 0.0939
Epoch [13000/20000], Loss: 0.0933
Epoch [13100/20000], Loss: 0.0927
Epoch [13200/20000], Loss: 0.0922
Epoch [13300/20000], Loss: 0.0916
Epoch [13400/20000], Loss: 0.0910
Epoch [13500/20000], Loss: 0.0905
Epoch [13600/20000], Loss: 0.0900
Epoch [13700/20000], Loss: 0.0895
Epoch [13800/20000], Loss: 0.0890
Epoch [13900/20000], Loss: 0.0885
Epoch [14000/20000], Loss: 0.0880
Epoch [14100/20000], Loss: 0.0875
Epoch [14200/20000], Loss: 0.0871
Epoch [14300/20000], Loss: 0.0866
Epoch [14400/20000], Loss: 0.0862
Epoch [14500/20000], Loss: 0.0857
Epoch [14600/20000], Loss: 0.0853
Epoch [14700/20000], Loss: 0.0849
Epoch [14800/20000], Loss: 0.0845
Epoch [14900/20000], Loss: 0.0840
Epoch [15000/20000], Loss: 0.0837
Epoch [15100/20000], Loss: 0.0833
Epoch [15200/20000], Loss: 0.0829
Epoch [15300/20000], Loss: 0.0825
Epoch [15400/20000], Loss: 0.0821
Epoch [15500/20000], Loss: 0.0818
Epoch [15600/20000], Loss: 0.0814
Epoch [15700/20000], Loss: 0.0811
Epoch [15800/20000], Loss: 0.0807
Epoch [15900/20000], Loss: 0.0804
Epoch [16000/20000], Loss: 0.0800
Epoch [16100/20000], Loss: 0.0797
Epoch [16200/20000], Loss: 0.0794
Epoch [16300/20000], Loss: 0.0791
Epoch [16400/20000], Loss: 0.0788
Epoch [16500/20000], Loss: 0.0785
Epoch [16600/20000], Loss: 0.0782
Epoch [16700/20000], Loss: 0.0779
Epoch [16800/20000], Loss: 0.0776
Epoch [16900/20000], Loss: 0.0773
Epoch [17000/20000], Loss: 0.0770
Epoch [17100/20000], Loss: 0.0767
Epoch [17200/20000], Loss: 0.0765
Epoch [17300/20000], Loss: 0.0762
Epoch [17400/20000], Loss: 0.0759
Epoch [17500/20000], Loss: 0.0757
Epoch [17600/20000], Loss: 0.0754
Epoch [17700/20000], Loss: 0.0751
Epoch [17800/20000], Loss: 0.0749
Epoch [17900/20000], Loss: 0.0747
Epoch [18000/20000], Loss: 0.0744
Epoch [18100/20000], Loss: 0.0742
Epoch [18200/20000], Loss: 0.0739
Epoch [18300/20000], Loss: 0.0737
Epoch [18400/20000], Loss: 0.0735
Epoch [18500/20000], Loss: 0.0733
Epoch [18600/20000], Loss: 0.0730
Epoch [18700/20000], Loss: 0.0728
Epoch [18800/20000], Loss: 0.0726
Epoch [18900/20000], Loss: 0.0724
Epoch [19000/20000], Loss: 0.0722
Epoch [19100/20000], Loss: 0.0720
Epoch [19200/20000], Loss: 0.0718
Epoch [19300/20000], Loss: 0.0716
Epoch [19400/20000], Loss: 0.0714
Epoch [19500/20000], Loss: 0.0712
Epoch [19600/20000], Loss: 0.0710
Epoch [19700/20000], Loss: 0.0708
Epoch [19800/20000], Loss: 0.0706
Epoch [19900/20000], Loss: 0.0704
Epoch [20000/20000], Loss: 0.0702
PS E:\shucai\py>
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
PS E:\shucai\py>
Epoch [20000/20000], Loss: 0.0702
PS E:\shucai\py>
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/906789.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/906789.shtml
英文地址,请注明出处:http://en.pswp.cn/news/906789.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JDK21深度解密 Day 1:JDK21全景图:关键特性与升级价值

【JDK21深度解密 Day 1】JDK21全景图:关键特性与升级价值 引言 欢迎来到《JDK21深度解密:从新特性到生产实践的全栈指南》系列的第一天。今天我们将探讨JDK21的关键特性和升级价值。作为近5年最重要的LTS版本,JDK21不仅带来了性能上的巨大突…

[docker]更新容器中镜像版本

从peccore-dev仓库拉取镜像 docker pull 10.12.135.238:8060/peccore-dev/configserver:v1.13.45如果报错,请参考docker拉取镜像失败,添加仓库地址 修改/etc/CET/Common/peccore-docker-compose.yml文件中容器的版本,为刚刚拉取的版本 # 配置中心confi…

LVS原理详解及LVS负载均衡工作模式

什么是虚拟服务器(LVS) 虚拟服务器是高度可扩展且高度可用的服务器 构建在真实服务器集群上。服务器集群的架构 对最终用户完全透明,并且用户与 cluster 系统,就好像它只是一个高性能的虚拟 服务器。请考虑下图。 真实服务器和负…

上位机知识篇---keil IDE操作

文章目录 前言文件操作按键新建打开保存保存所有编辑操作按键撤销恢复复制粘贴剪切全选查找书签操作按键添加书签跳转到上一个书签跳转到下一个书签清空所有书签编译操作按键编译当前文件构建目标文件重新构建调试操作按键进入调试模式复位全速运行停止运行单步调试逐行调试跳出…

前端大文件上传性能优化实战:分片上传分析与实战

前端文件分片是大文件上传场景中的重要优化手段,其必要性和优势主要体现在以下几个方面: 一、必要性分析 1. 突破浏览器/服务器限制 浏览器限制:部分浏览器对单次上传文件大小有限制(如早期IE限制4GB) 服务器限制&a…

解决react-router-dom没有支持name命名使用的问题

1. 前言 react-router-dom 并不能像 vue 的route 那样给每个路由命名 name ,导致代码不能解耦路由路径与导航逻辑。 2. react-router 为什么没有支持? 很早之前官方 issue 中就有过很多讨论: 翻译过来,就是由于以下几个重要原…

Spring AI 之结构化输出转换器

截至 2024 年 2 月 5 日,旧的 OutputParser、BeanOutputParser、ListOutputParser 和 MapOutputParser 类已被弃用,取而代之的是新的 StructuredOutputConverter、BeanOutputConverter、ListOutputConverter 和 MapOutputConverter 实现类。后者可直接替换前者,并提供相同的…

MCP与AI模型的多语言支持:让人工智能更懂世界

MCP与AI模型的多语言支持:让人工智能更懂世界 在人工智能(AI)的时代,我们追求的不仅是强大的计算能力,更是让AI能够理解并使用不同语言,真正服务全球用户。而这背后,一个至关重要的技术就是 MCP(Multi-Context Processing,多上下文处理) ——一种旨在优化 AI 模型理…

【MySQL】 数据库基础数据类型

一、数据库简介 1.什么是数据库 数据库(Database)是一种用于存储、管理和检索数据的系统化集合。它允许用户以结构化的方式存储大量数据,并通过高效的方式访问和操作这些数据。数据库通常由数据库管理系统(DBMS)管理&…

NRM:快速切换 npm 镜像源的管理工具指南

🚀 NRM:快速切换 npm 镜像源的管理工具指南 🔍 什么是 NRM? NRM(Npm Registry Manager) 是一个用于管理 npm 镜像源的命令行工具。 它能帮助开发者 ⚡快速切换 不同的 npm 源(如官方源、淘宝源…

基于Java的话剧购票小程序【附源码】

摘 要 随着文化产业的蓬勃发展,话剧艺术日益受到大众喜爱,便捷的购票方式成为观众的迫切需求。当前传统购票渠道存在购票流程繁琐、信息获取不及时等问题。本研究致力于开发一款基于 Java 的话剧购票小程序,Java 语言具有跨平台性、稳定性和…

Pr -- 耳机没有Pr输出的声音

问题 很久没更新视频号了,想用pr剪辑一下,结果使用Pr打开后发现耳机没有Pr输出的声音 解决方法 在编辑--首选项-音频硬件中设置音频硬件的输出为当前耳机设备

Leaflet根据坐标画圆形区域

在做地图应用时,有时需要根据指定的坐标来画一个圆形区域,比如签到打卡类的应用,此时我们可以使用 leaflet.Circle 来在在指定坐标上创建一个圆并添加到的地图上,其中可以通过 radius 属性可以指定区域半径,比如: con…

vue3中使用computed

在 Vue 3 中,computed 是一个非常重要的响应式 API,用于声明依赖于其他响应式状态的派生状态。以下是 computed 的详细用法: 1. 基本用法 import { ref, computed } from vueexport default {setup() {const firstName ref(张)const lastN…

【iOS】类结构分析

前言 之前我们已经探索得出对象的本质就是一个带有isa指针的结构体,这篇文章来分析一下类的结构以及类的底层原理。 类的本质 类的本质 我们在main函数中写入以上代码,然后利用clang对其进行反编译,可以得到c文件 可以看到底层使用Class接…

Vanna.AI:解锁连表查询的新境界

Vanna.AI:解锁连表查询的新境界 在当今数字化时代,数据已成为企业决策的核心驱动力。然而,从海量数据中提取有价值的信息并非易事,尤其是当数据分散在多个表中时,连表查询成为了数据分析师和开发者的日常挑战。传统的…

前端流行框架Vue3教程:24.动态组件

24.动态组件 有些场景会需要在两个组件间来回切换&#xff0c;比如 Tab 界面 我们准备好A B两个组件ComponentA ComponentA App.vue代码如下&#xff1a; <script> import ComponentA from "./components/ComponentA.vue" import ComponentB from "./…

海拔案例分享-实践活动报名测评小程序

大家好&#xff0c;今天湖南海拔科技想和大家分享一款实践活动报名测评小程序&#xff0c;客户是长沙一家专注青少年科创教育的机构&#xff0c;这家机构平时要组织各种科创比赛、培训课程&#xff0c;随着学员增多&#xff0c;管理上的问题日益凸显&#xff1a;每次组织活动&a…

【MySQL】CRUD

CRUD 简介 CRUD是对数据库中的记录进行基本的增删改查操作 Create&#xff08;创建&#xff09;Retrieve&#xff08;读取&#xff09;Update&#xff08;更新&#xff09;Delete&#xff08;删除&#xff09; 一、新增&#xff08;Create&#xff09; 语法&#xff1a; I…

【数据架构04】数据湖架构篇

✅ 10张高质量数据治理架构图 无论你是数据架构师、治理专家&#xff0c;还是数字化转型负责人&#xff0c;这份资料库都能为你提供体系化参考&#xff0c;高效解决“架构设计难、流程不清、平台搭建慢”的痛点&#xff01; &#x1f31f;限时推荐&#xff0c;速速收藏&#…