RNN小练习

要求:

假设有 4 个字 吃 了 没 ?,请使用 torch.nn.RNN 完成以下任务

  • 将每个进行 one-hot 编码
  • 请使用 吃 了 没 作为输入序列,了 没 ? 作为输出序列
  • RNN 的 hidden_size = 64
  • 请将 RNN 的输出使用全连接转换成 4 个特征,并使用 CrossEntropyLoss 训练模型
  • 训练模型并验证

1、准备数据集

import torch.nn.functional
from torch.utils.data import Datasetclass mydataset(Dataset):def __init__(self):super().__init__()texts = '吃 了 没 ?'self.words = texts.split()self.input = self.words[:3]self.label = self.words[1:]def __len__(self):return 1def __getitem__(self, idx):# 对输入进行 one_hot 编码inp = torch.nn.functional.one_hot(torch.tensor([self.words.index(word) for word in self.input]),len(self.words)).float()# 对标签进行编码,返回文字的索引label = torch.tensor([self.words.index(word) for word in self.label])return inp, label

2、创建模型

import torch.nn as nnclass mymodel(nn.Module):def __init__(self):super().__init__()self.rnn = nn.RNN(4,64,nonlinearity='relu')self.fc1 = nn.Linear(64,4)def forward(self, x,h=None):x,h = self.rnn(x,h)x = self.fc1(x)return x,h

3、训练模型以及预测

import torch.nn as nn
from torch import optimfrom myset import mydataset
from mymodel import mymodelEPOCH = 1000
LR = 1e-2ds = mydataset()
inputs,lables = ds[0]model = mymodel()loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=LR)for epoch in range(EPOCH):optimizer.zero_grad()y,h = model(inputs)loss = loss_fn(y,lables)print(loss)loss.backward()optimizer.step()model.eval()y,h = model(inputs)y = y.softmax(-1)
maxarg = y.argmax(-1)print([ds.words[indx] for indx in maxarg.tolist()])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/96165.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/96165.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/96165.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ESPIDF官方文档,启用dhcp会禁用对应的STA或AP的静态IP,我测试STA确实是,但是AP不是,为什么

1. STA 模式下的 DHCP(客户端角色)ESP32 当 Station(STA) 时,它的行为就跟你的手机/笔记本连 Wi-Fi 一样:DHCP 客户端 → 去路由器(DHCP 服务器)要一个 IP。特点启用 DHCP&#xff0…

cocos2d. 3.17.2 c++如何实现下载断点续传zip压缩包带进度条

新建类CurlDown #include “curl/curl.h” #include using namespace std; USING_NS_CC; /** 资源下载curl */ class CurlDown { public: CurlDown(); ~CurlDown(); void StartDownResZip(string downLoadUrl, int64_t totalSize); //下载控制 void downloadControler(); //下…

MySQL 整型数据类型:选对数字类型,让存储效率翻倍

MySQL 整型数据类型:选对数字类型,让存储效率翻倍 在 MySQL 中,整型(整数类型)是最常用的数据类型之一,从用户 ID 到商品数量,几乎所有涉及数字的场景都离不开它。但你知道吗?选对整…

公司电脑监控软件有哪些?公司电脑监控软件应该怎么选择

大家好呀,电竞直播运营团队常常面临 “直播脚本被抄袭、用户付费数据篡改、主播话术外泄” 的问题!尤其是独家直播流程脚本、用户充值记录、主播互动话术库、赛事解说手稿,一旦泄露可能导致竞品跟风、用户信任下降、直播竞争力减弱&#xff5…

ARM裸机开发:链接脚本、进阶Makefile(bsp)、编译过程、beep实验

一、链接脚本的作用?各个段存放什么数据类型(一)链接脚本内容SECTIONS {. 0x87800000;.text : {obj/start.o*(.text)}.rodata ALIGN(4) : {*(.rodata*)}.data ALIGN(4) : {*(.data)}__bss_start .;.bss ALIGN(4) : {*(.bss) *(COMMON)}__bs…

Linux驱动开发(1)概念、环境与代码框架

一、驱动概念驱动与底层硬件直接打交道,充当了硬件与应用软件中间的桥梁。1、具体任务(1)读写设备寄存器(实现控制的方式)(2)完成设备的轮询、中断处理、DMA通信(CPU与外设通信的方式…

计算机视觉(十):ROI

什么是感兴趣区域(ROI)? 在计算机视觉中,**感兴趣区域(ROI)**指的是图像中包含我们想要分析、处理或识别的目标或特征的特定子集。就像我们在阅读一本书时会聚焦于某个重要的段落,计算机视觉系统…

Jenkins 构建 Node 项目报错解析与解决——pnpm lockfile 问题实战

在使用 Jenkins 自动化构建 Node.js 项目时,经常会遇到类似报错: ERR_PNPM_OUTDATED_LOCKFILE  Cannot install with "frozen-lockfile" because pnpm-lock.yaml is not up to date with package.json Error: Cannot find module node_module…

Kafka在多环境中安全管理敏感

1. 配置提供者是什么? 配置提供者(ConfigProvider)是一类按需“拉取配置”的组件:应用读取配置时,按约定的占位符语法去外部来源(目录、环境变量、单一 properties 文件、你自定义的来源……)取…

编程工具的演进逻辑:从Python IDLE到Arduino IDE的深度剖析

引言:工具进化的本质 在编程学习与开发的道路上,我们总会与各种各样的工具相遇。一个有趣的现象是,无论是初学者的第一款工具Python IDLE,还是硬件爱好者常用的Thonny和Arduino IDE,它们都自称“集成开发环境”(IDE)。这背后隐藏着怎样的逻辑? 本文将带你深入分析这三…

p10k configure执行报错: ~/powerlevel10k/config/p10k-lean.zsh is not readable

[ERROR] p10k configure: ~/powerlevel10k/config/p10k-lean.zsh is not readable 背景 我移动了Powerlevel10k文件夹的位置,导致p10k configure命令找不到powerlevel10k文件夹的位置。 原来Powerlevel10k的位置:~/powerlevel10k 移动后Powerlevel10k的位…

Java 学习笔记(进阶篇3)

1. 美化界面关键逻辑 1:// 相对路径:直接从项目的 src 目录开始写,不包含 D:\ 和个人名字 ImageIcon bg new ImageIcon("src/image/background.png"); JLabel background new JLabel(bg);这两行代码是 Swing 中加载并显示图片的经…

BFD 概述

BFD简介1.BFD:Bidirectional Forwarding Detection,双向转发检查概述:毫秒级链路故障检查,通常结合三层协议(如静态路由、vrrp、 ospf、 BGP等)实现链路故障快速切换。作用:① 检测二层非直连故障② 加快三层协议收敛底…

【嵌入式DIY实例-ESP32篇】-Flappy Bird游戏

Flappy Bird游戏 文章目录 Flappy Bird游戏 1、游戏介绍 2、硬件准备与接线 3、代码实现 《Flappy Bird》游戏以其引人入胜的玩法和简约的设计风靡全球。本文将探讨如何使用 OLED SSD1306 显示屏和 ESP32 微控制器重现这款经典游戏。这个 DIY 项目不仅充满乐趣,也是学习编程和…

[数据结构——lesson2.顺序表]

目录 学习目标 引言 1.什么是线性表? 2.什么是顺序表? 2.1概念及结构 2.2 接口实现 2.2.1顺序表的功能 1.顺序表的初始化 2.打印数据 3.尾插数据 (1)检查空间 (2)插入数据 4.尾删数据 5.头插数据 6.头删数据 7.数据查找 8.指定位置数据…

ChatGPT大模型训练指南:如何借助动态代理IP提高训练效率

随着人工智能技术的飞速发展,ChatGPT等大型语言模型(LLM)已成为科技界和产业界关注的焦点。模型的训练过程耗时、耗资源且对网络环境要求极高。尤其是在需要模拟真实用户行为、进行大规模数据爬取或分布式训练的场景下,单一IP地址…

Docker 学习笔记(六):多容器管理与集群部署实践

Docker Docker-compose 单个 Dockerfile 可定义单容器应用,但日常工作中,Web 项目等常需 Web 服务、数据库、负载均衡等多容器配合,手动按序启停容器会导致维护量大、效率低。 Docker Compose 是高效的多容器管理工具,通过单个 do…

C++类和对象初识

面向过程 1.1 面向过程特点 1.2 通俗解释:煮方便面 1.3 面向过程实现代码 1.4 特点总结面向对象 2.1 面向对象特点 2.2 通俗解释:对象协作思维 2.3 面向对象实现代码 2.4 特点总结面向对象和面向过程总结C 面向对象介绍 4.1 面向对象三大基本特征封装&am…

C++ Int128 —— 128位有符号整数类实现剖析

🧠 C Int128 —— 128位有符号整数类实现剖析 引用:openppp2/ppp/Int128.h 🏗️ 1. 存储结构设计 #mermaid-svg-2JDFsdz6MTbX253D {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-sv…

【C 语言生成指定范围随机数(整数 + 小数):原理、实现与避坑指南】

概述 在 C 语言开发中,生成指定范围的随机数是高频需求(如游戏随机道具、数据模拟、测试用例生成等)。但很多新手会卡在 “范围控制”“随机数重复”“小数生成” 等问题上。本文结合实战场景,从原理到代码详细讲解如何生成 1100、…