在这里插入图片描述

本节课你将学到

  • 理解什么是TensorFlow,为什么要用它
  • 掌握TensorFlow安装和基本操作
  • 学会搭建第一个神经网络
  • 完成手写数字识别项目

开始之前

环境要求

  • Python 3.8+
  • 至少4GB内存
  • 网络连接(用于下载数据集)

前置知识

  • 第1-8讲:Python基础和开发环境
  • 基本的数学概念(加减乘除即可)

什么是TensorFlow?

用最简单的话解释

想象你要盖房子:

  • 传统编程:你需要自己制作每一块砖头、每一根钢筋
  • TensorFlow:就像一个预制构件工厂,砖头、钢筋、水泥都给你准备好了,你只需要按图纸组装

TensorFlow就是Google开发的"AI积木工厂",它提供了:

  • 🧱 基础积木:各种数学运算函数
  • 🔧 组装工具:神经网络层、优化器
  • 📏 测量工具:损失函数、评估指标
  • 🏭 生产线:自动训练和优化

为什么选择TensorFlow?

  1. 简单易用:像搭积木一样构建神经网络
  2. 功能强大:支持从简单分类到复杂的图像识别
  3. 社区庞大:遇到问题容易找到解决方案
  4. 工业级:Google、Netflix等大公司都在用

TensorFlow安装

安装步骤

# 方法1:使用pip安装(推荐)
pip install tensorflow# 方法2:如果上面很慢,使用国内镜像
pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple/# 验证安装
python -c "import tensorflow as tf; print('TensorFlow版本:', tf.__version__)"

验证安装

import tensorflow as tf# 检查版本
print("TensorFlow版本:", tf.__version__)# 检查是否支持GPU(有GPU会显示GPU信息,没有也正常)
print("GPU可用:", len(tf.config.list_physical_devices('GPU')) > 0)# 简单测试
hello = tf.constant("Hello, TensorFlow!")
print(hello.numpy().decode())

预期输出:

TensorFlow版本: 2.x.x
GPU可用: False  # 没有GPU也没关系
Hello, TensorFlow!

TensorFlow核心概念

1. 张量(Tensor)- 数据容器

张量就是多维数组,就像不同形状的盒子:

import tensorflow as tf
import numpy as np# 0维张量(标量)- 一个数字
scalar = tf.constant(42)
print("标量:", scalar)# 1维张量(向量)- 一行数字
vector = tf.constant([1, 2, 3, 4])
print("向量:", vector)# 2维张量(矩阵)- 表格
matrix = tf.constant([[1, 2], [3, 4]])
print("矩阵:")
print(matrix)# 3维张量 - 立体数据(比如彩色图片)
tensor_3d = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print("3维张量形状:", tensor_3d.shape)

2. 计算图 - 操作流程

TensorFlow会自动记录你的操作,就像记录菜谱步骤:

# 定义变量(可以改变的数)
x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")# 定义计算(TensorFlow会记录这些步骤)
z = x * x + y * y  # z = x² + y²print("x =", x.numpy())
print("y =", y.numpy()) 
print("z = x² + y² =", z.numpy())

3. 自动微分 - 神经网络的关键

神经网络需要不断调整参数,TensorFlow可以自动计算如何调整:

# 使用GradientTape记录操作
x = tf.Variable(2.0)with tf.GradientTape() as tape:y = x * x * x  # y = x³# 自动计算导数(斜率)
dy_dx = tape.gradient(y, x)
print(f"当x={x.numpy()}时,y=x³的导数是:{dy_dx.numpy()}")
print("手工计算:3*2²=12,验证正确!")

第一个神经网络

问题:预测房价

假设我们要根据房屋面积预测房价,这是一个最简单的神经网络:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 1. 准备数据
# 假设房价 = 面积 * 2 + 一些随机噪声
np.random.seed(42)  # 固定随机种子,确保结果可重现
areas = np.random.uniform(50, 200, 100)  # 100个面积数据,50-200平米
prices = areas * 2 + np.random.normal(0, 10, 100)  # 价格=面积*2+噪声# 数据标准化(重要!神经网络喜欢小数字)
areas_norm = (areas - areas.mean()) / areas.std()
prices_norm = (prices - prices.mean()) / prices.std()print("数据准备完成!")
print(f"面积范围:{areas.min():.1f} - {areas.max():.1f} 平米")
print(f"价格范围:{prices.min():.1f} - {prices.max():.1f} 万元")
# 2. 构建神经网络
# 最简单的神经网络:只有一层,一个神经元
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1], name='price_predictor')
])# 编译模型(设置学习规则)
model.compile(optimizer='adam',      # 优化器:adam是最常用的loss='mse',           # 损失函数:均方误差metrics=['mae']       # 评估指标:平均绝对误差
)# 查看模型结构
print("模型结构:")
model.summary()
# 3. 训练模型
print("开始训练...")
history = model.fit(areas_norm, prices_norm,  # 训练数据epochs=100,               # 训练轮数verbose=0                 # 不显示训练过程(避免刷屏)
)print("训练完成!")# 4. 评估效果
test_area = np.array([100])  # 测试:100平米的房子
test_area_norm = (test_area - areas.mean()) / areas.std()
predicted_price_norm = model.predict(test_area_norm, verbose=0)# 反标准化得到实际价格
predicted_price = predicted_price_norm * prices.std() + prices.mean()print(f"预测结果:100平米的房子价格约为 {predicted_price[0][0]:.1f} 万元")
print(f"理论价格:100 * 2 = 200万元")
print(f"预测误差:{abs(predicted_price[0][0] - 200):.1f} 万元")

可视化结果

# 绘制训练过程
plt.figure(figsize=(12, 4))# 损失变化
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.title('训练损失变化')
plt.xlabel('训练轮数')
plt.ylabel('损失值')
plt.grid(True)# 预测效果
plt.subplot(1, 2, 2)
plt.scatter(areas, prices, alpha=0.6, label='真实数据')# 画预测线
test_areas = np.linspace(50, 200, 100)
test_areas_norm = (test_areas - areas.mean()) / areas.std()
predicted_prices_norm = model.predict(test_areas_norm, verbose=0)
predicted_prices = predicted_prices_norm * prices.std() + prices.mean()plt.plot(test_areas, predicted_prices, 'r-', linewidth=2, label='神经网络预测')
plt.xlabel('面积 (平米)')
plt.ylabel('价格 (万元)')
plt.title('房价预测效果')
plt.legend()
plt.grid(True)plt.tight_layout()
plt.show()print("图表已显示!红线是神经网络学到的规律")

完整项目:手写数字识别

现在我们来做一个更有趣的项目:让计算机识别手写数字!

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 1. 加载MNIST数据集(手写数字数据)
print("正在下载MNIST数据集...")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()print("数据加载完成!")
print(f"训练图片数量: {len(x_train)}")
print(f"测试图片数量: {len(x_test)}")
print(f"图片尺寸: {x_train[0].shape}")# 查看几个样本
plt.figure(figsize=(10, 2))
for i in range(5):plt.subplot(1, 5, i+1)plt.imshow(x_train[i], cmap='gray')plt.title(f'标签: {y_train[i]}')plt.axis('off')
plt.suptitle('手写数字样本')
plt.show()
# 2. 数据预处理
# 标准化像素值到0-1范围
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0# 将28x28的图片展平成784维向量
x_train_flat = x_train.reshape(60000, 784)
x_test_flat = x_test.reshape(10000, 784)print("数据预处理完成!")
print(f"训练数据形状: {x_train_flat.shape}")
print(f"测试数据形状: {x_test_flat.shape}")
# 3. 构建神经网络
model = tf.keras.Sequential([# 输入层:784个神经元(对应784个像素)tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),# 隐藏层:128个神经元,使用ReLU激活函数tf.keras.layers.Dense(64, activation='relu'),# 输出层:10个神经元(对应0-9十个数字)tf.keras.layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 多分类问题的损失函数metrics=['accuracy']
)# 查看模型结构
print("神经网络结构:")
model.summary()
# 4. 训练模型
print("开始训练神经网络...")
history = model.fit(x_train_flat, y_train,epochs=10,                    # 训练10轮batch_size=128,              # 每次处理128个样本validation_split=0.1,        # 10%的数据用于验证verbose=1                    # 显示训练进度
)print("训练完成!")
# 5. 评估模型
test_loss, test_accuracy = model.evaluate(x_test_flat, y_test, verbose=0)
print(f"测试准确率: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")# 预测几个测试样本
predictions = model.predict(x_test_flat[:5], verbose=0)
predicted_labels = np.argmax(predictions, axis=1)print("\n预测结果:")
for i in range(5):print(f"图片{i+1}: 真实标签={y_test[i]}, 预测标签={predicted_labels[i]}, "f"置信度={predictions[i][predicted_labels[i]]:.4f}")
# 6. 可视化结果
plt.figure(figsize=(15, 5))# 训练历史
plt.subplot(1, 3, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.xlabel('训练轮数')
plt.ylabel('准确率')
plt.legend()
plt.grid(True)plt.subplot(1, 3, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.xlabel('训练轮数')
plt.ylabel('损失值')
plt.legend()
plt.grid(True)# 预测结果展示
plt.subplot(1, 3, 3)
# 显示一个预测示例
sample_idx = 0
plt.imshow(x_test[sample_idx], cmap='gray')
plt.title(f'真实: {y_test[sample_idx]}, 预测: {predicted_labels[sample_idx]}')
plt.axis('off')plt.tight_layout()
plt.show()print("🎉 恭喜!你已经成功训练了一个手写数字识别神经网络!")

运行效果

预期输出

数据加载完成!
训练图片数量: 60000
测试图片数量: 10000
图片尺寸: (28, 28)神经网络结构:
Model: "sequential_1"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense_1 (Dense)             (None, 128)               100480    dense_2 (Dense)             (None, 64)                8256      dense_3 (Dense)             (None, 10)                650       
=================================================================
Total params: 109,386
Trainable params: 109,386
Non-trainable params: 0训练完成!
测试准确率: 0.9751 (97.51%)预测结果:
图片1: 真实标签=7, 预测标签=7, 置信度=0.9999
图片2: 真实标签=2, 预测标签=2, 置信度=0.9995
...

生成的文件

  • 模型训练过程可视化图表
  • 手写数字样本展示
  • 预测结果对比

常见问题解答

Q1: 安装TensorFlow时出错

错误信息: ERROR: Failed building wheel for tensorflow

解决方法:

# 方法1:升级pip
pip install --upgrade pip# 方法2:使用conda安装
conda install tensorflow# 方法3:安装CPU版本
pip install tensorflow-cpu

Q2: 训练很慢怎么办?

解决方法:

  • 减少训练轮数(epochs):从10改为5
  • 减少数据量:只用前1000个样本训练
  • 使用更小的网络:减少神经元数量

Q3: 准确率不高怎么办?

可能原因和解决方法:

  • 训练轮数太少:增加epochs
  • 网络太简单:增加更多层或神经元
  • 学习率不合适:尝试不同的优化器

Q4: 内存不够怎么办?

解决方法:

# 减少batch_size
model.fit(x_train, y_train, batch_size=32)  # 从128改为32# 或者使用更少的数据
x_train_small = x_train[:10000]  # 只用前10000个样本

课后练习

基础练习

  • 修改神经网络结构,尝试不同数量的神经元
  • 改变训练轮数,观察准确率变化
  • 使用自己手写的数字测试模型

进阶练习

  • 尝试识别时装图片(Fashion-MNIST数据集)
  • 添加更多隐藏层,观察效果变化
  • 使用不同的激活函数(如tanh、sigmoid)

挑战练习

  • 实现一个简单的绘图界面,让用户画数字并识别
  • 比较不同优化器的效果(SGD vs Adam)
  • 分析模型预测错误的样本,找出共同特点

总结

这节课我们学会了:

  1. TensorFlow基础概念:理解张量、计算图、自动微分
  2. 神经网络构建:使用Sequential模型搭建网络
  3. 模型训练流程:编译→训练→评估→预测
  4. 实际项目经验:完成了手写数字识别

下节课预告: 我们将学习PyTorch,对比两个主流深度学习框架的差异,并用PyTorch实现图像分类器。

技术支持

如遇问题,请检查:

  1. Python版本是否3.8+
  2. TensorFlow是否正确安装
  3. 网络连接是否正常(下载数据集需要)
  4. 内存是否足够(建议4GB+)

记住:每个AI专家都是从第一个神经网络开始的!你已经迈出了重要的一步! 🚀

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/87641.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/87641.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/87641.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32 串口USART通讯驱动

前言 本篇文章对串口Usart进行讲解,为后面的esp8266和语音模块控制打好基础。 1.串口USART USART(Universal Synchronous/Asynchronous Receiver/Transmitter,通用同步 / 异步收发器) 是一种常见的串行通信接口,广泛应…

pytorch版本densenet代码讲解

DenseNet 模型代码详解 下面是 DenseNet 模型代码的逐部分详细解析: 1. 导入模块 import re from collections import OrderedDict from functools import partial from typing import Any, Optionalimport torch import torch.nn as nn import torch.nn.functional…

前端常见设计模式深度解析

# 前端常见设计模式深度解析一、设计模式概述 设计模式是解决特定问题的经验总结,前端开发中常用的设计模式可分为三大类: 创建型模式:处理对象创建机制(单例、工厂等)结构型模式:处理对象组合(…

React 学习(3)

核心API——React.creatElement()方法优点:将创建元素、添加属性和事件、添加内容和子元素等使用原生dom需要进行复杂操作才能实现的功能集成在一个API中。1.该方法接收三个参数第一个是要创建的元素的名称(小写是因为如果,大写开头会被react…

倾斜摄影无人机飞行航线规划流程详解

在倾斜摄影测量项目中,航线规划的严谨性直接决定了最终三维模型的质量与完整性。照片覆盖不全、模型空洞、纹理模糊或分辨率不达标等问题,往往源于规划阶段对关键细节的疏忽。本文将系统梳理倾斜摄影无人机航线规划的核心流程与关键要点,旨在…

Minio大文件分片上传

一、引入依赖 <dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.3.3</version></dependency> 二、自定义Minio客户端 package com.gstanzer.video.controller;import com.google.common.c…

Jenkins 插件深度应用:让你的CI/CD流水线如虎添翼 [特殊字符]

Jenkins 插件深度应用&#xff1a;让你的CI/CD流水线如虎添翼 &#x1f680; 嘿&#xff0c;各位开发小伙伴&#xff01;今天咱们来聊聊Jenkins的插件生态系统。如果说Jenkins是一台强大的引擎&#xff0c;那插件就是让这台引擎发挥最大威力的各种零部件。准备好了吗&#xff1…

密码学(斯坦福)

密码学笔记 \huge{密码学笔记} 密码学笔记 斯坦福大学密码学的课程笔记 课程网址&#xff1a;https://www.bilibili.com/video/BV1Rf421o79E/?spm_id_from333.337.search-card.all.click&vd_source5cc05a038b81f6faca188e7cf00484f6 概述 密码学的使用背景 安全信息保护…

代码随想录算法训练营第四十六天|动态规划part13

647. 回文子串 题目链接&#xff1a;647. 回文子串 - 力扣&#xff08;LeetCode&#xff09; 文章讲解&#xff1a;代码随想录 思路&#xff1a; 以dp【i】表示以s【i】结尾的回文子串的个数&#xff0c;发现递推公式推导不出来此路不通 以dp【i】【j】表示s【i】到s【j】的回…

基于四种机器学习算法的球队数据分析预测系统的设计与实现

文章目录 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主项目介绍项目展示随机森林模型XGBoost模型逻辑回归模型catboost模型每文一语 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主 项目介绍 本项目旨在设计与实现…

http、SSL、TLS、https、证书

一、基础概念 1.HTTP HTTP (超文本传输协议) 是一种用于客户端和服务器之间传输超媒体文档的应用层协议&#xff0c;是万维网的基础。 简而言之&#xff1a;一种获取和发送信息的标准协议 2.SSL 安全套接字层&#xff08;SSL&#xff09;是一种通信协议或一组规则&#xf…

在 C++ 中,判断 `std::string` 是否为空字符串

在 C 中&#xff0c;判断 std::string 是否为空字符串有多种方法&#xff0c;以下是最常用的几种方式及其区别&#xff1a; 1. 使用 empty() 方法&#xff08;推荐&#xff09; #include <string>std::string s; if (s.empty()) {// s 是空字符串 }特性&#xff1a; 时间…

【Harmony】鸿蒙企业应用详解

【HarmonyOS】鸿蒙企业应用详解 一、前言 1、应用类型定义速览&#xff1a; HarmonyOS目前针对应用分为三种类型&#xff1a;普通应用&#xff0c;游戏应用&#xff0c;企业应用。 而企业应用又分为&#xff0c;企业普通应用和设备管理应用MDM&#xff08;Mobile Device Man…

Linux云计算基础篇(8)

VIM 高级特性插入模式按 i 进入插入模式。按 o 在当前行下方插入空行并进入插入模式。按 O 在当前行上方插入空行并进入插入模式。命令模式:set nu 显示行号。:set nonu 取消显示行号。:100 光标跳转到第 100 行。G 光标跳转到文件最后一行。gg 光标跳转到文件第一行。30G 跳转…

Linux进程单例模式运行

Linux进程单例模式运行 #include <iostream> #include <stdlib.h> #include <unistd.h> #include <string.h> #include <stdio.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h>int write_pid(const cha…

【Web 后端】部署服务到服务器

文章目录 前言一、如何启动服务二、挂载和开机启动服务1. 配置systemctl 服务2. 创建server用户3. 启动服务 总结 前言 如果你的后端服务写好了如果部署到你的服务器呢&#xff0c;本次通过fastapi写的服务实例&#xff0c;示范如何部署到服务器&#xff0c;并做服务管理。 一…

国产MCU学习Day5——CW32F030C8T6:窗口看门狗功能全解析

每日更新教程&#xff0c;评论区答疑解惑&#xff0c;小白也能变大神&#xff01;" 目录 一.窗口看门狗&#xff08;WWDG&#xff09;简介 二.窗口看门狗寄存器列表 三.窗口看门狗复位案例 一.窗口看门狗&#xff08;WWDG&#xff09;简介 CW32F030C8T6 内部集成窗口看…

2025年文件加密软件分享:守护数字世界的核心防线

在数字化时代&#xff0c;数据已成为个人与企业的宝贵资产&#xff0c;文件加密软件通过复杂的算法&#xff0c;确保信息在存储、传输与共享过程中的保密性、完整性与可用性。一、文件加密软件的核心原理文件加密软件算法以其高效性与安全性广泛应用&#xff0c;通过对文件数据…

node.js下载教程

1.项目环境文档 语雀 2.nvm安装教程与nvm常见命令,超详细!-阿里云开发者社区 C:\Windows\System32>nvm -v 1.2.2 C:\Windows\System32>nvm list available Error retrieving "http://npm.taobao.org/mirrors/node/index.json": HTTP Status 404 C:\Window…

(AI如何解决问题)在一个项目,跳转到外部html页面,页面布局

问题描述目前&#xff0c;ERP后台有很多跳转外部链接的地方&#xff0c;会直接打开一个tab显示。因为有些页面是适配手机屏幕显示&#xff0c;放在浏览器会超级大。不美观&#xff0c;因此提出优化。修改前&#xff1a;修改后&#xff1a;思考过程1、先看下代码&#xff1a;log…