上节介绍了框架下的注意力机制的主要成分 图10.1.3: 查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚; 注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。 本节将介绍注意力汇聚的更多细节, 以便从宏观上了解注意力机制在实践中的运作方式。 具体来说,1964年提出的Nadaraya-Watson核回归模型 是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。

import torch
from torch import nn
from d2l import torch as d2l

10.2.1. 生成数据集

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 0-1之间随机生成50个数然后排序后的训练样本def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出 加一点噪音
x_test = torch.arange(0, 5, 0.1)  # 测试样本,是均匀的
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test输出:50

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

10.2.2. 平均汇聚

先使用最简单的估计器来解决回归问题。 基于平均汇聚来计算所有训练样本输出值的平均值:

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

如下图所示,这个估计器确实不够聪明。 真实函数(“Truth”)和预测函数(“Pred”)相差很大。

10.2.3. 非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))#把1234变成111222333444.n_train来控制 x_test:测试输入数据,形状为 (n_test,)。
x_test.repeat_interleave(n_train):将每个测试输入重复 n_train 次,形成一个一维数组。
.reshape((-1, n_train)):将重复后的数组重新整形为二维数组,形状为 (n_test, n_train)。每一行都包含相同的测试输入,重复 n_train 次。
# x_train包含着键。attention_weights的形状:(n_test,n_train),计算测试输入与每个训练输入之间的差值。X_repeat 的形状为 (n_test, n_train),x_train 的形状为 (n_train,),广播后形状为 (n_test, n_train)
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)#y_train:训练目标值,形状为 (n_train,)
#torch.matmul(attention_weights, y_train):计算加权平均值。attention_weights 的形状为 (n_test, n_train),y_train 的形状为 (n_train,),广播后形状为 (n_train, 1)。矩阵乘法的结果 y_hat 的形状为 (n_test, 1),表示每个测试输入的预测值。

这一部分的作用就是在于,对于一个train长度的y train,你每一个元素都要加上权重,因此得现把train单行长度给广播成train列长度,这样权重相乘就可以做矩阵乘法了

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

10.2.4. 带参数注意力汇聚

10.2.4.1. 批量矩阵乘法

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape输出:torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))输出:tensor([[[ 4.5000]],[[14.5000]]])

10.2.4.2. 定义模型

基于 (10.2.7)中的 带参数的注意力汇聚,使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

w干的事情就是控制你的高斯核让自己不那么平滑,你看上面那个不带参数就是到处都很平滑

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)#和上面的算法是差不多的,无非是写了一个类且多用了一个w

10.2.4.3. 训练

接下来,将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降:

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

10.2.5. 小结

  • Nadaraya-Watson核回归是具有注意力机制的机器学习范例。

  • Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。

  • 注意力汇聚可以分为非参数型和带参数型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/92949.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/92949.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/92949.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nginx高新能web服务器

一、Nginx 概述和安装 Nginx是免费的、开源的、高性能的HTTP和反向代理服务器、邮件代理服务器、以及TCP/UDP代理服务器。 Nginx 功能介绍 静态的web资源服务器html,图片,js,css,txt等静态资源 http/https协议的反向代理 结合F…

Unity大型场景性能优化全攻略:PC与安卓端深度实践 - 场景管理、渲染优化、资源调度 C#

本文将深入探讨Unity在大型场景中的性能优化策略,涵盖场景管理、渲染优化、资源调度等核心内容,并提供针对PC和安卓平台的优化方案及实战案例。 提示:内容纯个人编写,欢迎评论点赞。 文章目录1. 大型场景性能挑战1.1 性能瓶颈定位…

Java集合框架、Collection体系的单列集合

Java集合框架、Collection1. 认识Java集合框架及结构1.1 集合框架整体结构1.2 集合框架的核心作用2. Collection的两大常用集合体系及各个系列集合的特点2.1 List系列集合(有序、可重复)2.2 Set系列集合(无序、不可重复)3. Collec…

HTML <picture> 元素:让图片根据设备 “智能切换” 的响应式方案

在响应式设计中,图片适配是一个绕不开的难题:同一张高清图片在大屏设备上清晰美观,但在小屏手机上可能加载缓慢;而适合手机的小图在桌面端又会模糊失真。传统的解决方案往往需要用JavaScript判断设备尺寸并动态替换图片源&#xf…

Spring Boot 监控与日志管理实战

在 Spring Boot 应用开发中,指标监控和日志管理是保障应用稳定运行的核心环节。指标监控能实时掌握应用健康状态、性能瓶颈,日志管理则用于问题排查和安全审计。本文基于 Spring Boot 提供的 Actuator 监控工具、Spring Boot Admin 可视化平台&#xff0…

【排序算法】②希尔排序

系列文章目录 第一篇:【排序算法】①直接插入排序-CSDN博客 第二篇:【排序算法】②希尔排序-CSDN博客 第三篇:【排序算法】③直接选择排序-CSDN博客 第四篇:【排序算法】④堆排序-CSDN博客 第五篇:【排序算法】⑤冒…

Linux Shell为文件添加BOM并自动转换为unix格式

1.添加并查看BOM添加bomvim -c "set bomb|set fileencodingutf-8|wq" ./gradlew查看bomhead -c 3 ./gradlew | hexdump -C2.安装dos2unix并转换为unix格式安装sudo apt install dos2unix转换dos2unix ./gradlew

华清远见25072班C语言学习day5

重点内容:数组:为什么有数组?为了便于存储多个数据特点:连续存储多个同种数据类型元素(连续指内存地址连续)数组名:数组中首元素的地址,是一个地址常量。一维整形数组:定义:数据类型…

安全守护,温情陪伴 — 智慧养老产品上新

- 养老智慧看护终端接入萤石开放平台 - 在2025 ECDC萤石云开发者大会,萤石产品经理已经介绍了基于萤石云服务AI能力适老化设备的养老智能能力开放。 而今天,养老智慧看护终端再升级,集成跌倒检测、物理隐私遮蔽、火柴人遮蔽、AI语音智能体…

鸿蒙flutter项目接入极光推送

推送的自分类权益 需要审核15个工作日,实际约3个工作日 项目使用极光推送flutter代码,代码端已经配置的东西(需要配置flutter端和对应各自平台原生端),我的工程是多target,所以和单target有一点不同。 一、…

2025牛客多校第八场 根号-2进制 个人题解

J.根号-2进制 #数学 #FFT 思路 赛后发现身边的同学都是通过借位来解决进位问题的,在此提供一种全程不出现减法的顺推做法 首先A,BA,BA,B可以理解为两个多项式:A0A1−2A2(−2)2…A_{0}A_{1}\sqrt{ -2 }A_{2}(\sqrt{ -2 })^2\dotsA0​A1​−2​A2​(−…

DataEase官方出品丨SQLBot:基于大模型和RAG的智能问数系统

2025年8月7日,DataEase开源项目组发布SQLBot开源项目(github.com/dataease/SQLBot)。SQLBot是一款基于大语言模型(Large Language Model,LLM)和RAG(Retrieval Augmented Generation,…

第十四节 代理模式

在代理模式(Proxy Pattern)中,一个类代表另一个类的功能。这种类型的设计模式属于结构型模式。在代理模式中,我们创建具有现有对象的对象,以便向外界提供功能接口。介绍意图:为其他对象提供一种代理以控制对…

训推一体 | 暴雨X8848 G6服务器 x Intel®Gaudi® 2E AI加速卡

近日,暴雨信息携手英特尔,针对Gaudi 2E AI加速器HL-288 PCIe卡(简称IntelGaudi 2E PCIe卡,下同)完成专项调优与适配工作,并重磅推出Intel Eagle Stream平台4U8卡解决方案。该方案通过软硬件协同优化&#x…

GB17761-2024标准与电动自行车防火安全的技术革新

随着我国电动自行车保有量突破3.5亿辆,这一便捷的交通工具已成为城市出行的重要组成。然而,伴随市场规模扩大而来的是日益突出的安全问题——2023年全国电动自行车火灾事故高达2.5万起,年均增长率约20%,火灾中塑料件加速燃烧并释放…

利用容器编排完成haproxy和nginx负载均衡架构实施

1 创建测试目录和文件[rootdocker-a ~]# mkdir lee [rootdocker-a ~]# cd lee/ [rootdocker-a lee]# touch docker-compose.yml # 容器编排工具Docker Compose 默认识别docker-compose.yml文件2 编写docker-compose.yml文件和haproxy.cfg文件2.1 核心配置说明2.1.1 服务结构共定…

WinRAR v7.13 烈火汉化稳定版,解压缩全格式专家

[软件名称]: WinRAR v7.13 烈火汉化稳定版 [软件大小]: 3.8 MB [下载通道]: 夸克盘 | 迅雷盘 软件介绍 WinRAR 压缩文件管理器,知名解压缩软件,电脑装机必备软件,国内最流行最好用的压缩文件管理器、解压缩必备软件。它提供 RAR 和 ZIP 文…

强化学习常用数据集

强化学习常用数据集数学推理数据集数值标签GSM8K(2021 OpenAI)问答数据集在LLM场景下进行强化学习训练的时候,时常会涉及到各种各样的数据集,容易记不住,因此开个帖子记录一下。可采取的分类方法有很多,这里直接按照领…

ROS2学习(1)—基础概念及环境搭建

文章目录核心框架环境搭建小乌龟机器人控制小乌龟启动键盘控制启动rqt查看ros节点关系核心框架 这里有几个比较重要的概念: 四大通信机制:话题(Topic)、服务(Service)、动作(Action&#xff09…

基于STM32单片机超声波测速测距防撞报警设计

1 系统功能介绍 本设计是一套基于 STM32F103C8T6 单片机 的超声波测速测距防撞报警系统,能够实现对目标物体的实时测距与测速,并通过 TFT 彩屏进行动态显示,同时根据用户设定的距离与速度阈值进行报警提示。该系统不仅可以用于固定场景的安全…