这一节我们用TensorFlow定义简单的分类器。首先考虑分类器的方程式是什么是值得的。数学习的技巧是使用sigmoid函数。sigmoid函数绘制如图3-40, 通常标记为σ, 是实数域里的函数取值(0, 1)。这个特征很便利,因为我们可以将sigmoid的输出解释为事件发现的概率。  (转换离散事件到连续值是机器学习里反复出现的主题)

图3-40. 绘制sigmoid 函数.

预测离散事件的概率的方程式如下。这些方程式定义了简单的逻辑回归模型:

y0 = σ( wx + b

y1 = 1 σ (wx + b

TensorFlow提供了工具函数来计算sigmoidal值的交叉熵损失。最简单的函数是tf.nn.sigmoid_cross_entropy_with_logits. ( logitsigmoid的逆。实际上,这意味着传递参数到 sigmoid, wx + b, 而不是sigmoidal value σ wx + b 本身)。我们推荐使用 TensorFlow的实现而不是手工定义交叉熵,因为计算交叉熵损失有许多复杂的数值问题。

#List3-44

import numpy as np

np.random.seed(456)

import tensorflow as tf

#tf.set_random_seed(456)

import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score

from scipy.special import logit

# Generate synthetic data

N = 100

# Zeros form a Gaussian centered at (-1, -1)

x_zeros = np.random.multivariate_normal(mean=np.array((-1, -1)), cov=.1*np.eye(2), size=(N//2,))

y_zeros = np.zeros((N//2,))

# Ones form a Gaussian centered at (1, 1)

x_ones = np.random.multivariate_normal(mean=np.array((1, 1)), cov=.1*np.eye(2), size=(N//2,))

y_ones = np.ones((N//2,))

x_np = np.vstack([x_zeros, x_ones])

y_np = np.concatenate([y_zeros, y_ones])

# Save image of the data distribution

plt.xlabel(r"$x_1$")

plt.ylabel(r"$x_2$")

plt.title("Toy Logistic Regression Data")

# Plot Zeros

plt.scatter(x_zeros[:, 0], x_zeros[:, 1], color="blue")

plt.scatter(x_ones[:, 0], x_ones[:, 1], color="red")

plt.savefig("logistic_data.png")

x_np,y_np

模型的训练代码见List3-45 ,与线性回归模型的代码相同。

#List3-45

W = tf.Variable(tf.random.normal((2, 1)))

b = tf.Variable(tf.random.normal((1,)))

W,b

x=tf.cast(x_np,tf.float32)

y=tf.cast(y_np,tf.float32)

learning_r=0.01

optimizer = tf.optimizers.SGD(learning_r)

n_steps = 100

# Train model

for i in range(n_steps):   

    with tf.GradientTape() as tape:

        #_, summary, loss = sess.run([train_op, merged, l], feed_dict=feed_dict)

        y_logit = tf.squeeze(tf.matmul(x, W) + b)

        # the sigmoid gives the class probability of 1

        y_one_prob = tf.sigmoid(y_logit)

        # Rounding P(y=1) will give the correct prediction.

        y_pred = tf.round(y_one_prob)

        entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=y_logit, labels=y)

        # Sum all contributions

        l = tf.reduce_sum(entropy)       

        gradients=tape.gradient(l,[W,b])

    optimizer.apply_gradients(zip(gradients, [W, b]))

    #W=W-tf.Variable(learning_r,tf.float32)*W

    #b=b-tf.Variable(learning_r,tf.float32)*b

    print("loss: %f" % l)

    #train_writer.add_summary(summary, i)

    # Get weights

    w_final=W

    b_final=b

# Make Predictions

    #y_pred_np = sess.run(y_pred, feed_dict={x: x_np})

#score = accuracy_score(y_np, y_pred_np)

#print("Classification Accuracy: %f" % score)

plt.clf()

# Save image of the data distribution

plt.xlabel(r"$x_1$")

plt.ylabel(r"$x_2$")

plt.title("Learned Model (Classification Accuracy: 1.00)")

plt.xlim(-2, 2)

plt.ylim(-2, 2)

# Plot Zeros

plt.scatter(x_zeros[:, 0], x_zeros[:, 1], color="blue")

plt.scatter(x_ones[:, 0], x_ones[:, 1], color="red")

x_left = -2

y_left = (1./w_final[1]) * (-b_final + logit(.5) - w_final[0]*x_left)

x_right = 2

y_right = (1./w_final[1]) * (-b_final + logit(.5) - w_final[0]*x_right)

plt.plot([x_left, x_right], [y_left, y_right], color='k')

plt.savefig("logistic_pred.png")

图3-41

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89000.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89000.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/89000.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java代码审计(2)】MyBatis XML 注入审计

代码背景:某公司使用 MyBatis 作为持久层框架,登录功能如下: Java 接口代码: public interface UserMapper {User findByUsernameAndPassword(Param("username") String username,Param("password") String p…

Spring Boot目录变文件夹?3步解决!

在 Spring Boot 项目中,当你在 src/main/java 下看到目录结构而不是包结构时,这通常是 IDE(如 IntelliJ IDEA)的显示问题或项目配置问题。以下是原因和解决方案:问题原因IDE 未正确识别 Java 源代码根目录 src/main/ja…

Appium源码深度解析:从驱动到架构

Appium源码深度解析:从驱动到架构 Appium 源码概览 Appium 是一个开源的移动自动化测试框架,支持跨平台(iOS、Android)和多种编程语言(Java、Python 等)。其源码托管在 GitHub 上,主要由 JavaScript 和 Node.js 实现,核心逻辑围绕客户端-服务器架构设计。 GitHub 仓库…

给 Excel 整列空格文字内容加上前缀:像给文字穿衣服一样简单!

目录 步骤一:选中目标列 打开Excel表格并定位列点击列标题选中整列 步骤二:输入公式,变身“魔法” 在公式编辑栏输入公式按下回车键查看效果 步骤三:向下填充,批量处理 鼠标定位到单元格右下角按住鼠标左键向下拖动填充…

Spring Boot 启动原理揭秘:从 main 方法到自动装配

Spring Boot 启动原理揭秘:从 main 方法到自动装配 引言 Spring Boot 作为 Java 领域最流行的开发框架之一,凭借其“开箱即用”的特性极大地简化了 Spring 应用的搭建和部署。然而,尽管开发者在日常工作中频繁使用 Spring Boot 的启动类&…

OpenCV 与深度学习:从图像分类到目标检测技术

一、深度学习:从 “人工设计” 到 “自动学习”1.1 深度学习的定位:AI、机器学习与深度学习的关系人工智能(AI):是一个宽泛的领域,目标是构建能模拟人类智能的系统,涵盖推理、感知、决策等能力。…

Docker 镜像推送至 Coding 制品仓库超时问题排查与解决

Docker 镜像推送至 Coding 制品仓库超时问题排查与解决 在将 Docker 镜像推送至 Coding 的制品仓库时,可能会遇到 docker push 命令超时失败的问题。但使用 curl -i http://xxx.coding.xxxx.xx 测试时,连接却能成功建立。以下是排查过程及解决方案。 问题…

https交互原理

Https 交互时序图:HTTPS 通信中结合 RSA 和 AES 加密的流程,本质是利用 RSA 的安全特性交换 AES 密钥,再用高效的 AES 加密实际数据传输。HTTPS 交互核心流程(TLS/SSL 握手) 1. 建立 TCP 连接 客户端通过 TCP 三次握手…

LSTM入门案例(时间序列预测)| pytorch实现

需求 假如我有一个时间序列,例如是前113天的价格数据(训练集),然后我希望借此预测后30天的数据(测试集),实际上这143天的价格数据都已经有了。这里为了简单,每一天的数据只有一个价…

WPS、Word加载项开发流程(免费最简版本)

文章目录1 加载项对比2 WPS 加载项2.1 本地开发2.1.1 准备开发环境2.1.2 新建 WPS 加载项项目2.1.3 运行项目2.2 在线部署2.2.1 编译项目2.2.2 部署项目2.2.3 生成分发文件2.2.4 部署分发文件2.3 安装加载项2.4 取消发布3 Word 加载项3.1 本地开发3.1.1 准备开发环境3.1.2 新建…

Flink SQL 性能优化实战

最近我们组在大规模上线Flink SQL作业。首先,在进行跑批量初始化完历史数据后,剩下的就是消费Kafka历史数据进行追数了。但是发现某些作业的追数过程十分缓慢,要运行一晚上甚至三四天才能追上最新数据。由于是实时数仓指标计算上线初期&#…

HTML 树结构(DOM)深入讲解教程

一、HTML 树结构的核心概念 1.1 DOM(文档对象模型)的定义 DOM(Document Object Model)是 W3C 制定的标准接口,允许程序或脚本(如 JavaScript)动态访问和更新 HTML/XML 文档的内容、结构和样式。…

用鼠标点击终端窗口的时候出现:0;61;50M0;61;50M0;62;50M0

在做aws webrtc viewer拉流压测的过程中,我本地打开了多个终端,用于连接EC2实例: 一个终端用于启动 ‘并发master脚本’、监控master端的cpu、mem;一个终端用于监控master端的带宽情况;一个终端用于监控viewer端的cpu、…

C++-linux 5.gdb调试工具

GDB调试工具 在C/C开发中,程序运行时的错误往往比编译错误更难定位。GDB(GNU Debugger)是Linux环境下最强大的程序调试工具,能够帮助开发者追踪程序执行流程、查看变量状态、定位内存错误等。本章将从基础到进阶,全面讲…

Update~Read PLC for Chart ~ Log By Shift To be... Alarm AI Machine Learning

上图~ 持续迭代 1、增加报警弹窗,具体到哪个值,双边规格具体是多少 2、实时显示当前值的统计特征,Max Min AVG ... import tkinter as tk from tkinter import simpledialog import time import threading import queue import logging from datetime import datet…

es的自定义词典和停用词

在 Elasticsearch 中,自定义词典是优化分词效果的核心手段,尤其适用于中文或专业领域的文本处理。以下是关于 ES 自定义词典的完整指南: 为什么需要自定义词典? 默认分词不足: ES 自带的分词器(如 Standard…

微算法科技技术突破:用于前馈神经网络的量子算法技术助力神经网络变革

随着量子计算和机器学习的迅猛发展,企业界正逐步迈向融合这两大领域的新时代。在这一背景下,微算法科技(NASDAQ:MLGO)成功研发出一套用于前馈神经网络的量子算法,突破了传统神经网络在训练和评估中的性能瓶颈。这一创新…

一文读懂循环神经网络(RNN)—语言模型+读取长序列数据(2)

目录 读取长序列数据 为什么需要 “读取长序列数据”? 读取长序列数据的核心方法 1. 滑动窗口(Sliding Window) 2. 分段截取(Segmentation) 3. 滚动生成(Rolling Generation) 4. 关键信息…

Oracle Virtualbox 虚拟机配置静态IP

Oracle Virtualbox 虚拟机配置静态IP VirtualBox的网卡,默认都是第一个不能自定义,后续新建的可以自定义。 新建NAT网卡、host主机模式网卡 依次点击:管理->工具->网络管理器新建host主机模式网卡 这个网卡的网段自定义,创建…

Linux RAID1 创建与配置实战指南(mdadm)

Linux RAID1 创建与配置实战指南(mdadm)一、RAID1 核心价值与实战目标RAID1(磁盘镜像) 通过数据冗余提供高可靠性:当单块硬盘损坏时,数据不丢失支持快速阵列重建读写性能略低于单盘(镜像写入开销…