功能概述

这段代码实现了一个基于TensorFlow和Keras的MNIST手写数字识别模型。主要功能包括:

  1. 加载并预处理MNIST数据集
  2. 构建一个简单的全连接神经网络模型
  3. 训练模型并评估其性能
  4. 使用训练好的模型进行预测
  5. 保存和加载模型

代码解析

1. 导入必要的库

import matplotlib
import tensorflow.keras as keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from pasta.augment import inline
  • 导入TensorFlow和Keras用于构建和训练神经网络
  • 导入NumPy用于数值计算
  • 导入Matplotlib用于数据可视化
  • 从pasta.augment导入inline用于在Jupyter Notebook中直接显示图像

2. 打印TensorFlow版本

print(tf.__version__)

输出当前使用的TensorFlow版本,用于环境检查。

3. 加载MNIST数据集

path = '../doc/mnist.npz'
with np.load(path) as data:x_train, y_train = data['x_train'], data['y_train']x_test, y_test = data['x_test'], data['y_test']
print(x_train[0])
  • 从本地文件加载MNIST数据集
  • 数据集包含训练集(x_train, y_train)和测试集(x_test, y_test)
  • 打印第一个训练样本的像素值

4. 数据可视化

%matplotlib inline
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()
  • 使用Matplotlib显示第一个训练样本的图像
  • cmap=plt.cm.binary设置为黑白显示

5. 打印第一个训练样本的标签

print(y_train[0])

输出第一个训练样本对应的数字标签。

6. 数据归一化

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
print(x_train[0])
  • 对图像数据进行归一化处理,将像素值缩放到0-1范围
  • 打印归一化后的第一个训练样本

7. 构建神经网络模型

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
  • 创建一个Sequential模型
  • 添加Flatten层将28x28的图像展平为784维向量
  • 添加两个全连接层(Dense),每层128个神经元,使用ReLU激活函数
  • 添加输出层,10个神经元对应10个数字类别,使用Softmax激活函数

8. 编译模型

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
  • 使用Adam优化器
  • 使用稀疏分类交叉熵作为损失函数
  • 使用准确率作为评估指标

9. 训练模型

model.fit(x_train, y_train, epochs=3)
  • 训练模型3个epoch
  • 使用训练数据进行拟合

10. 评估模型

val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)
  • 在测试集上评估模型性能
  • 输出测试损失和准确率

11. 使用模型进行预测

predictions = model.predict(x_test)
print(predictions)
print(np.argmax(predictions[0]))
plt.imshow(x_test[0], cmap=plt.cm.binary)
plt.show()
  • 对测试集进行预测
  • 打印预测结果(概率分布)
  • 使用argmax获取第一个测试样本的预测标签
  • 显示第一个测试样本的图像

12. 保存和加载模型

def softmax_v2(x):return tf.keras.activations.softmax(x)new_model = tf.keras.models.load_model('epic_num_reader.model.keras',custom_objects={'softmax_v2': softmax_v2}
)predictions = new_model.predict(x_test)
print(np.argmax(predictions[0]))
  • 定义一个softmax_v2函数用于兼容性
  • 加载之前保存的模型
  • 使用加载的模型进行预测

总结

这段代码实现了一个简单但有效的MNIST手写数字分类器。主要特点包括:

  1. 使用全连接神经网络结构
  2. 实现了数据预处理和归一化
  3. 达到了较高的测试准确率(约97%)
  4. 包含了模型保存和加载功能
  5. 提供了可视化工具检查数据和预测结果

demo001.ipynb

# 导入 keras 模块
import matplotlib
import tensorflow.keras as keras
# 导入 tensorflow 模块
import tensorflow as tf
# 导入 pasta 模块中的 augment 和 inline 子模块
from pasta.augment import inline# 打印 TensorFlow 的版本
print(tf.__version__)# 指定本地文件路径
path = '../doc/mnist.npz'
# 导入 numpy 模块
import numpy as np
# 从本地加载 MNIST 数据集
with np.load(path) as data:x_train, y_train = data['x_train'], data['y_train']x_test, y_test = data['x_test'], data['y_test']
# 打印训练数据集的第一个样本
print(x_train[0])# 导入 matplotlib.pyplot 模块
import matplotlib.pyplot as plt
# 使用 inline 后,图形将直接显示在 Jupyter Notebook 中
# %matplotlib inline
# 可视化训练数据集的第一个样本
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()# 打印训练标签的第一个样本
print(y_train[0])# 对训练和测试数据进行归一化处理
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)# 打印归一化后的训练数据集的第一个样本
print(x_train[0])# 可视化归一化后的训练数据集的第一个样本
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()# 创建一个 Sequential 模型
model = tf.keras.models.Sequential()
# 添加一个 Flatten 层,用于将输入数据展平
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
# 添加一个 Dense 层,包含 128 个神经元,使用 ReLU 激活函数
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
# 再添加一个 Dense 层,配置同上
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
# 添加一个 Dense 层,包含 10 个神经元,使用 Softmax 激活函数
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
# 编译模型,指定优化器、损失函数和评估指标
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=3)
# 评估模型
val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)# 使用模型进行预测
predictions = model.predict(x_test)
print(predictions)# 导入 numpy 模块
import numpy as np# 打印第一个测试样本的预测标签
print(np.argmax(predictions[0]))# 可视化第一个测试样本
plt.imshow(x_test[0], cmap=plt.cm.binary)
plt.show()# 保存模型
def softmax_v2(x):# 将 softmax_v2 映射到标准 softmaxreturn tf.keras.activations.softmax(x)# 加载之前保存的模型
new_model = tf.keras.models.load_model('epic_num_reader.model.keras',custom_objects={'softmax_v2': softmax_v2}
)# 使用加载的模型进行预测
predictions = new_model.predict(x_test)
# 打印第一个测试样本的预测标签
print(np.argmax(predictions[0]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/90275.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/90275.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/90275.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进阶系统策略

该策略主要基于价格动态分析,结合多种技术指标和数学计算来生成交易信号。其核心逻辑包括: 1. 价格极值计算:首先,策略计算给定周期(由`Var3`定义)内的最高价和最低价,分别存储在`Var12`和`Var13`中。这一步骤旨在捕捉价格的短期波动范围。 2. 相对位置计算:接着,策…

【Linux内核】Linux驱动开发

推荐书籍: 《Linux内核探秘:深入解析文件系统和设备驱动的架构与设计》 知识点 x86的IO地址空间和内存地址空间是独立的两套地址空间,并且使用不同的指令访问。MOV, IN, OUT。内存映射I/O可以将IO映射到内存。ARM等RISC采用统一编编址&#x…

MySQL用户管理(15)

文章目录前言一、用户用户信息创建用户修改密码删除用户二、数据库的权限MySQL中的权限给用户授权回收权限总结前言 其实与 Linux 操作系统类似,MySQL 中也有 超级用户 和 普通用户 之分 如果一个用户只需要访问 MySQL 中的某一个数据库,甚至数据库中的某…

react19相关问题和解答

目录 1. react19将ref放在了props中(不再需要 forwardRef),那么是不是可以通过ref获取子组件的全部变量了? 我的子组件的useImperativeHandle还需要定义吗? 1.1. ref 在 props 中的本质变化 1.2. 为什么不能访问全部变量? 2. In HTML,cannot be a descendant of. Thi…

Code Composer Studio:CCS 设置代码折叠

Code Composer Studio:设置代码折叠,可以按函数,if, 等把代码折叠起来。1.2.开启折叠选项3.开启后,如果文件已经打开,要关掉重新打开文件就可以开到折叠功能生效。

JMeter groovy 编译成.jar 文件

groovy 编译 一、windows 下手动安装Groovy 下载 Groovy 二进制包 前往官网:https://groovy.apache.org/download.html 下载 Binary release( https://groovy.jfrog.io/ui/native/dist-release-local/groovy-zips/apache-groovy-sdk-4.0.27.zip &#xf…

使用maven-shade-plugin解决依赖版本冲突

项目里引入多个版本依赖时,最后只会使用其中一个,一般可以通过排除不使用的依赖处理,但是如果需要同时使用多个版本,可以使用maven-shade-plugin解决。以最典型的poi为例,poi版本兼容性很低,如果出现找不到…

[CH582M入门第十一步]DS18B20驱动

学习目标: 1、介绍DS18B20 2、学习单总线 3、学习DS18B20程序驱动一、DS18B20介绍 DS18B20 是一款由 Maxim Integrated(原Dallas Semiconductor) 推出的 数字温度传感器,以其单总线(1-Wire)通信协议、高精度和广泛应用而闻名。以下是其核心特点和应用介绍: 主要特性 数…

SGLang + 分布式推理部署DeepSeek671B满血版

部署设备:28A100 80G,两台机器,每台机器8张A100。 模型:deepseek-671B-int8 模型下载地址:https://huggingface.co/meituan/DeepSeek-R1-Block-INT8 模型参考: 1、SGLang Docker部署 github地址&#…

PCL 间接平差拟合球

目录 一、算法原理 1、计算流程 2、参考文献 二、代码实现 三、结果展示 本文由CSDN点云侠原创,首发于2025年7月24日。博客长期更新,本文最新更新时间为:2025年7月24日。 一、算法原理 1、计算流程 空间球方程: ( x − a ) 2 + ( y − b ) 2 + ( z − c ) 2 = R 2 (1) (…

基于 HAProxy 搭建 EMQ X 集群

负载均衡器(LB)负责分发设备的 MQTT 连接与消息到 EMQ X 集群,采用 LB 可以提高 EMQ X 集群可用性、实现负载平衡以及动态扩容。 HAProxy简介 HAProxy 是一款高性能的 开源负载均衡器 和 反向代理服务器,主要用于在多个服务器之…

RISC-V基金会Datacenter SIG月会圆满举办,探讨RAS、PMU性能分析实践和经验

一直以来,龙蜥社区在 RISC-V 生态建设中持续投入,并积极贡献上游社区。多位龙蜥社区成员在 RISC-V 国际基金会担任主席/副主席角色,与来自阿里云、阿里达摩院、中兴通讯、浪潮信息、中科院软件所、字节跳动、Google、 MIT、Akeana 等企业的专…

CloudComPy使用PyInstaller打包后报错解决方案

情况描述 笔者在spec文件中,datas变量设置如下。如果你的报错类似于“找不到cloudComPy”,先尝试如下的设置。 datas[(CloudCompare,cloudComPy)], 笔者在打包完成后,打开软件发现报错: from cloudComPy import* ModuleNotFoun…

node.js中的path模块

在 Node.js 中,path 模块提供了处理和操作文件路径的功能,其中 path.join 和 path.resolve 是两个常用的方法。它们在处理路径时有不同的行为和用途: 功能概述 path.join(): 该方法主要用于将多个路径片段拼接成一个完整的路径字符串。它会正…

将Scrapy项目容器化:Docker镜像构建的工程实践

引言:爬虫容器化的战略意义在云原生与微服务架构主导的时代,​​容器化技术​​已成为爬虫项目交付的黄金标准。据2023年分布式系统调查报告显示:92%的生产爬虫系统采用容器化部署容器化使爬虫环境配置时间​​减少87%​​Docker化爬虫的故障…

Unity × RTMP × 头显设备:打造沉浸式工业远控视频系统的完整方案

结合工业现场需求,探索如何通过大牛直播SDK打造可在 Pico、Quest 等头显设备中运行的 RTMP 低延迟播放器,助力构建沉浸式远程操控系统。 一、背景:沉浸式远程操控的新趋势 随着工业自动化、5G 专网、XR 技术的发展,远程操控正在从…

HTTPS如何保障安全?详解证书体系与加密通信流程

HTTP协议本身是明文传输的,安全性较低,因此现代互联网普遍采用 HTTPS(HTTP over TLS/SSL) 来实现加密通信。HTTPS的核心是 TLS/SSL证书体系 和 加密通信流程。一、HTTPS 证书体系HTTPS依赖 公钥基础设施(PKI, Public K…

数据的评估与清洗篇---清洗数据

处理前的准备 检查索引与列名 在处理内容之前,需要先看看索引或列名是否有意义,若索引和列名都是乱七八糟的,应该对他们进行重命名或者重新排序,以便我们理解数据。 清洗数据 清洗数据原则 针对数据内容,一般先解决结构性问题,再处理内容性问题。整洁数据的特点是: …

Ubuntu apt和apt-get的区别

好的,这是一个非常经典且重要的问题。apt install 和 apt-get install 的区别是很多 Ubuntu/Debian 新手都会遇到的困惑。 简单来说,它们的功能非常相似,但设计目标和用户体验不同。 一句话总结 apt 是 apt-get 的一个更新、更友好、更现代化…

多端适配灾难现场:可视化界面在PC/平板/大屏端的响应式布局实战

摘要精心设计的可视化大屏,在平板上显示时图表挤成一团,在PC端操作按钮小到难以点击,某企业的可视化项目曾因多端适配失败沦为“灾难现场”,不仅用户差评如潮,还被竞争对手嘲讽技术落后。多端适配真的只能靠“反复试错…