算法

一文通透动作分块算法ACT:斯坦福ALOHA团队推出的动作序列预测算法(Action Chunking with Transformers)
比较简单,算法题目里就写了:Action Chunking with Transformers,比较有特色的地方就是Action Chunking,核心就是不浪费之前做过的推理预测,统统拿过来加权一下,得到最终的答案。
在这里插入图片描述

源码

逐行解读ALOHA ACT的实现:机器人动作分块算法ACT的代码剖析、训练部署(含真机上的智能分拣复现)
代码写得很优雅,读起来很流畅

1.1.1 模仿学习及其挑战:Action Chunking with Transformers(ACT)

预测动作中的小误差会引起状态的大差异,加剧模仿学习的“复合误差”问题。为了解决这个问题,他们从动作分块(action chunking)中获得灵感,这是心理学中的一个概念,描述了如何将一系列动作组合在一起作为一个块,最终作为一个单元执行

他们使用Transformers实现动作分块策略,并将其训练为条件VAE (CVAE),以捕获人类数据中的可变性。他们将该方法命名为Action Chunking with Transformers(ACT),并发现它在一系列模拟和现实世界的精细操作任务上显著优于以前的模仿学习算法

2.2.2 第二步 推断z,以获得CVAE解码器输入中的风格变量z

一文通透动作分块算法ACT:斯坦福ALOHA团队推出的动作序列预测算法(Action Chunking with Transformers)的这句话啥意思?

最后
只取第一个输出,它对应于**[CLS]标记**,并使用另一个线性网络来预测z分布均值方差,将其参数化为对角高斯分布
且使用重新参数化获得z的样本,这是一种允许在采样过程中反向传播的标准方法,以便编码器和解码器可以联合优化[33]

看detr_vae.py的代码就知道了:
在DETRVAE的if is_training头上有个注释:Obtain latent z from action sequence,
意思是风格变量z就是latent_input
[CLS]标记:encoder_output = encoder_output[0] # take cls output only
均值:mu = latent_info[:, :self.latent_dim]
方差:logvar = latent_info[:, self.latent_dim:]
使用重新参数化获得z的样本:latent_sample = reparametrize(mu, logvar)
最后:latent_input = self.latent_out_proj(latent_sample)

2.3 优势特征:ACT与其他模仿学习方法的比较

一方面,transformer解码器的“query”是第一层固定的正弦位置嵌入,即如上图右下角所示的position embeddings(fixed),其维度为k ×512
二方面,transformer解码器的交叉注意力(cross-attention)层中的“keys”和“values”来自上述transformer编码器的输出

eval_bc(评估一个行为克隆(behavior cloning)模型)和train_bc(训练行为克隆BC模型)的区别

我看到train_bc里头有个eval的,但这个eval应该和eval_bc不一样,虽然两者都要用到policy.eval()
注:policy里头就会调用

model, optimizer = build_ACT_model_and_optimizer(args_override)
self.model = model

1.8.3.2 根据观察结果查询策略、获取动作

这里的train_bc的policy调用参数是(qpos_data, image_data, action_data, is_pad)
eval_bc的policy调用参数是(qpos, curr_image)
根据参数来判断是训练还是推理
在这里插入图片描述
在训练模式下,会计算出一系列的损失并返回一个包含这些损失的字典
在推理模式下,会从模型中获取预测的动作并返回

aloha act代码里头的qpos和action有什么区别?

https://metaso.cn/s/IOAGn1O

那mu, logvar是啥

https://metaso.cn/s/IOAGn1O
在变分自编码器(VAE)中,mu 和 logvar 是两个关键参数,它们分别代表潜在变量的均值和对数方差,用于生成潜在空间的样本。
这段代码是 变分自编码器(VAE) 中的 重参数化技巧(Reparameterization Trick) 的实现,其作用是 从潜在变量的分布中采样,同时保证 梯度可以连续传播,从而实现端到端的训练。

def reparametrize(mu, logvar):std = logvar.div(2).exp()eps = Variable(std.data.new(std.size()).normal_())return mu + std * eps

编码器和编码器的输入与输出

backbone + encoder 等等输入到 self.transformer,其实self.transformer就是decoder部分
核心代码是detr_vae.pyclass DETRVAE(nn.Module):def forwardif is_training:部分
前提:detr_vae.pyclass DETRVAE(nn.Module):def forward的参数:qpos, image, env_state, actions, is_pad,都来自于imitate_episodes.pydef forward_pass(data, policy)data

编码器的输入与输出

编码器的核心调用语句:self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
参数的来源:

# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim) # qpos来自于forward_pass(data, policy):的image_data, qpos_data, action_data, is_pad = data
qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token 输出形状为(bs, 2)的二维张量,里面元素全部填充为False
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only

编码器的输入与输出

编码器的的核心调用语句为
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
其中:

  1. src
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):features, pos = self.backbones[0](image[:, cam_id]) # HARDCODEDfeatures = features[0] # take the last layer featurepos = pos[0]all_cam_features.append(self.input_proj(features))all_cam_pos.append(pos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
  1. pos
for cam_id, cam_name in enumerate(self.camera_names):features, pos = self.backbones[0](image[:, cam_id]) # HARDCODEDfeatures = features[0] # take the last layer featurepos = pos[0]all_cam_features.append(self.input_proj(features))all_cam_pos.append(pos)
pos = torch.cat(all_cam_pos, axis=3)
  1. latent_input 【Obtain latent z from action sequence】里的latent z
self.latent_dim = 32
latent_info = self.latent_proj(encoder_output) # 来自于编码器的输出
mu = latent_info[:, :self.latent_dim] # 潜在变量的均值
logvar = latent_info[:, self.latent_dim:] # 潜在变量的对数方差
latent_sample = reparametrize(mu, logvar) 
latent_input = self.latent_out_proj(latent_sample)
  1. proprio_input = self.input_proj_robot_state(qpos) # qpos来自于forward_pass(data, policy):的image_data, qpos_data, action_data, is_pad = data

为什么env_max_reward 设成0 ?

可能真机不需要看模拟出来的精度?

# load environment
if real_robot:from aloha_scripts.robot_utils import move_grippers # requires alohafrom aloha_scripts.real_env import make_real_env # requires alohaenv = make_real_env(init_node=True)env_max_reward = 0 # 为什么设成0 ?
success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
avg_return = np.mean(episode_returns)
summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
for r in range(env_max_reward+1):more_or_equal_r = (np.array(highest_rewards) >= r).sum()more_or_equal_r_rate = more_or_equal_r / num_rolloutssummary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'print(summary_str)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/84247.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/84247.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/84247.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字ic后端设计从入门到精通6(含fusion compiler, tcl教学)repeater详解

Repeaters RC延迟与导线长度的关系: 导线的电阻(R)和电容(C)都会随着导线长度(l)的增加而增大。RC延迟是电阻和电容共同作用导致的信号延迟。由于RC延迟与R和C的乘积有关,因此它会随…

Data Warebase 成功押注 PostgreSQL 生态,或成 AI 时代数据底座

本文内容整理自 ProtonBase CEO 王绍翾在 AICon 的主题演讲《Data Warebase: Instant Ingest-Transform-Explore-Retrieve for AI Applications》。作者的职业经历贯穿了 AI 1.0、2.0 和 3.0 的时代,从搜索推荐,到视觉 / 语音 / NLP 智能,再到…

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …

Kubernetes (k8s)版本发布情况

Kubernetes (k8s)版本发布情况 代码放在 GitHub - kubernetes/kubernetes: Production-Grade Container Scheduling and Management https://github.com/kubernetes/kubernetes/releases 文档放在 kubernetes.io各个版本变更等: https://github.com/kubernetes/kubernet…

Python 接口:从协议到抽象基 类(Python使用register的方式)

Python使用register的方式 示例 11-14 把 Tombola.register 当作类装饰器使用。在 Python 3.3 之 前的版本中不能这样使用 register,必须在定义类之后像普通函数那 样调用,如示例 11-14 中最后那行注释所述。 虽然现在可以把 register 当作装饰器使用了…

GRU 参数梯度推导与梯度消失分析

GRU 参数梯度推导与梯度消失分析 1. GRU 前向计算回顾 GRU 单元的核心计算步骤(忽略偏置项): 更新门: z_t σ(W_z [h_{t-1}, x_t]) 重置门: r_t σ(W_r [h_{t-1}, x_t]) 候选状态: ̃h_t tanh(W_h [r_t ⊙ h_{t-1}, x_t]) 新…

【字节拥抱开源】字节团队开源视频模型 ContentV: 有限算力下的视频生成模型高效训练

本项目提出了ContentV框架,通过三项关键创新高效加速基于DiT的视频生成模型训练: 极简架构设计,最大化复用预训练图像生成模型进行视频合成系统化的多阶段训练策略,利用流匹配技术提升效率经济高效的人类反馈强化学习框架&#x…

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…

单片机0-10V电压输出电路分享

一、原理图 二、芯片介绍 GP8101是一个PWM信号转模拟信号转换器,相当于一个PWM信号输入,模拟信号输出的DAC。此 芯片可以将占空比为0%到100%的PWM信号线性转换成0-5V或者0-10V的模拟电压,并且输出电压 精度小于1%。GP8101M可以处理高频调制的…

Spring AMQP

在现代分布式系统中,消息队列是一种非常重要的通信机制,它能够实现服务之间的异步通信、负载均衡以及解耦。Spring AMQP 是 Spring 框架对 AMQP(高级消息队列协议)的支持,而 RabbitMQ 是 AMQP 协议的最流行实现之一。通…

第6章:Neo4j数据导入与导出

在实际应用中,数据的导入与导出是使用Neo4j的重要环节。无论是初始数据加载、系统迁移还是数据备份,都需要高效可靠的数据传输机制。本章将详细介绍Neo4j中的各种数据导入与导出方法,帮助读者掌握不同场景下的最佳实践。 6.1 数据导入策略 …

RKNN开发环境搭建1-基于Ubuntu 18.04系统使用Docker安装rknn-toolkit2

目录 写在最前面Docker 方式安装rknn-toolkit2写在最前面 瑞芯微在RKNN的环境搭建方面的资料很多,但是在搭建过程中发现很多问题教程中并未提及,对初学者不友好。所以博主做了这个系列的文章,从开始搭建环境到对于RKNN Model Zoo的示例进行实践,希望能对初学者有帮助。坚持…

【实施指南】Android客户端HTTPS双向认证实施指南

🔐 一、所需准备材料 证书文件(6类核心文件) 类型 格式 作用 Android端要求 CA根证书 .crt/.pem 验证服务器/客户端证书合法性 需预置到Android信任库 服务器证书 .crt 服务器身份证明 客户端需持有以验证服务器 客户端证书 .crt 客户端身份…

FPGA管脚类型,及选择

fpga的IO Type选择,如下: 具体的定义:

SELinux是什么以及如何编写SELinux策略

目录 一、SELinux 是什么? 二、SELinux 的两种模式 如何查看当前 SELinux 状态? 三、SELinux 在 Android 中的作用 四、为什么Root之后很多设备是 Permissive? 五、开发与调试场景 总结 🧩 一、什么是 SELinux 策略&#x…

MQTT示例体验(C)

1、通用依赖准备 安装编译工具‌ Linux/macOS 需安装: sudo apt update && sudo apt install build-essential cmake git # Ubuntu/Debian:ml-citation{ref"6" data"citationList"} brew install cmake # macOSWindows 需安装 CMake…

MySQL中的系统库(简介、performance_schema)

文章目录 性能监控performance_schema1、performance schema入门2、performance_schema表的分类3、performance_schema的简单配置与使用4、常用配置项的参数说明5、重要配置表的相关说明6、performance_schema实践操作 Show processlist 性能监控 每次你提交完一个 sql 语句之…

【Ftrace 专栏】Ftrace 参考博文

ftrace、perf、bcc、bpftrace、ply、simple_perf的使用Ftrace 基本用法Linux 利用 ftrace 分析内核调用如何利用ftrace精确跟踪特定进程调度信息使用 ftrace 进行追踪延迟Linux-培训笔记-ftracehttps://www.kernel.org/doc/html/v4.18/trace/events.htmlhttps://blog.csdn.net/…

bug 记录 - 使用 el-dialog 的 before-close 的坑

需求说明 弹窗中内嵌一个 form 表单 原始代码 <script setup lang"ts"> import { reactive, ref } from "vue" import type { FormRules } from element-plus const ruleFormRef ref() interface RuleForm {name: stringregion: number | null } …

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件&#xff0c;这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下&#xff0c;实现高效测试与快速迭代&#xff1f;这一命题正考验着…