4.1 手动实现权重衰减

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w,b,num_inputs):X=torch.normal(0,1,size=(num_inputs,w.shape[0]))y=X@w+by+=torch.normal(0,0.1,size=y.shape)return X,y
def load_array(data,batch_size,is_train=True):dataset=TensorDataset(*data)return DataLoader(dataset,batch_size=batch_size,shuffle=is_train)
def init_params(num_inputs):w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)b=torch.zeros(1,requires_grad=True)return [w,b]
def l2_penalty(w):return 0.5*torch.sum(w.pow(2))def linear_reg(X,w,b):return torch.matmul(X,w)+b
def mse_loss(y_hat,y):return (y_hat-y)**2/2
def sgd(params,lr,batch_size):for params in params:params.data-=lr*params.grad/batch_sizeparams.grad.zero_()
def evaluate_loss(net, data_iter, loss):total_loss, total_samples = 0.0, 0for X, y in data_iter:l = loss(net(X), y)total_loss += l.sum().item()total_samples += y.numel()return total_loss / total_samples
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
num_epochs,lr,lambd=10,0.05,3
#animator=SimpleAnimator()
for epoch in range(num_epochs):for X,y in train_iter:l=loss(net(X),y)+lambd*l2_penalty(w)l.sum().backward()sgd([w,b],lr,batch_size)if (epoch+1)%5==0:train_loss=evaluate_loss(net,train_iter,loss)test_loss=evaluate_loss(net,test_iter,loss)#animator.add(epoch+1,train_loss,test_loss)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()

4.2 简单实现权重衰减

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w,b,num_inputs):X=torch.normal(0,1,size=(num_inputs,w.shape[0]))y=X@w+by+=torch.normal(0,0.1,size=y.shape)return X,y
def load_array(data,batch_size,is_train=True):dataset=TensorDataset(*data)return DataLoader(dataset,batch_size=batch_size,shuffle=is_train)
def init_params(num_inputs):w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)b=torch.zeros(1,requires_grad=True)return [w,b]
def l2_penalty(w):return 0.5*torch.sum(w.pow(2))
def linear_reg(X,w,b):return torch.matmul(X,w)+b
def mse_loss(y_hat,y):return ((y_hat-y)**2).sum()/2
def evaluate_loss(net, data_iter, loss):total_loss, total_samples = 0.0, 0for X, y in data_iter:l = loss(net(X), y)total_loss += l.item()*y.shape[0]total_samples += y.numel()return total_loss / total_samples
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
num_epochs,lr,lambd=100,0.001,3
optimizer=torch.optim.SGD([w,b],lr=lr,weight_decay=0.001)
#animator=SimpleAnimator()
for epoch in range(num_epochs):for X,y in train_iter:optimizer.zero_grad()l=loss(net(X),y)l.backward()#sgd([w,b],lr,batch_size)optimizer.step() if (epoch+1)%5==0:train_loss=evaluate_loss(net,train_iter,loss)test_loss=evaluate_loss(net,test_iter,loss)#animator.add(epoch+1,train_loss,test_loss)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/88189.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/88189.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/88189.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenCV开发-初始概念

第一章 OpenCV核心架构解析1.1 计算机视觉的基石OpenCV(Open Source Computer Vision Library)作为跨平台计算机视觉库,自1999年由Intel发起,已成为图像处理领域的标准工具。其核心价值体现在:跨平台性:支持…

LeetCode 930.和相同的二元子数组

给你一个二元数组 nums ,和一个整数 goal ,请你统计并返回有多少个和为 goal 的 非空 子数组。 子数组 是数组的一段连续部分。 示例 1: 输入:nums [1,0,1,0,1], goal 2 输出:4 解释: 有 4 个满足题目要求…

【论文解读】Referring Camouflaged Object Detection

论文信息 论文题目:Referring Camouflaged Object Detection 论文链接:https://arxiv.org/pdf/2306.07532 代码链接:https://github.com/zhangxuying1004/RefCOD 录用期刊:TPAMI 2025 论文单位:南开大学 ps&#xff1a…

Spring中过滤器和拦截器的区别及具体实现

在 Spring 框架中,过滤器(Filter) 和 拦截器(Interceptor) 都是用于处理 HTTP 请求的中间件,但它们在作用范围、实现方式和生命周期上有显著区别。以下是详细对比和实现方式:核心区别特性过滤器…

CANFD 数据记录仪在新能源汽车售后维修中的应用

一、前言随着新能源汽车市场如火如荼和新能源汽车电子系统的日益复杂,传统维修手段在面对复杂和偶发故障时往往捉襟见肘,CANFD 数据记录仪则凭借其独特优势,为售后维修带来新的解决方案。二、 详细介绍在新能源汽车领域,CANFD 数据…

某当CRM XlsFileUpload存在任意文件上传(CNVD-2025-10982)

免责声明 本文档所述漏洞详情及复现方法仅限用于合法授权的安全研究和学术教育用途。任何个人或组织不得利用本文内容从事未经许可的渗透测试、网络攻击或其他违法行为。使用者应确保其行为符合相关法律法规,并取得目标系统的明确授权。 前言: 我们建立了一个更多,更全的…

自然语言处理与实践

文章目录Lesson1:Introduction to NLP、NLP 基础与文本预处理1.教材2.自然语言处理概述(1)NLP 的定义、发展历程与应用场景(2)NLP 的主要任务:分词、词性标注、命名实体识别、句法分析等2.文本预处理3.文本表示方法:词向量表示/词表征Lesson2…

CSS揭秘:9.自适应的椭圆

前置知识:border-radius 用法前言 本篇目标是实现一个椭圆,半椭圆,四分之一椭圆。 一、圆形和椭圆 当我们想实现一个圆形时,通常只要指定 border-radius 为 width/height 的一半就可以了。 当我们指定的border-radius的值超过了 w…

善用关系网络:开源AI大模型、AI智能名片与S2B2C商城小程序赋能下的成功新路径

摘要:本文聚焦于关系在个人成功中的关键作用,指出关系即财富,善用关系、拓展人脉是成功的重要途径。在此基础上,引入开源AI大模型、AI智能名片以及S2B2C商城小程序等新兴技术工具,探讨它们如何助力个体在复杂的关系网络…

2025年渗透测试面试题总结-2025年HW(护网面试) 34(题目+回答)

安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 2025年HW(护网面试) 34 一、网站信息收集 核心步骤与工具 二、CDN绕过与真实IP获取 6大实战方法 三、常…

萤石全新上线企业AI对话智能体,开启IoT人机交互新体验

一、什么是萤石AI对话智能体?如何让设备听得到、听得懂?这次萤石发布的AI对话Agent,让设备能进行自然、流畅、真人感的AI对话智能体,帮助开发者打造符合业务场景的AI对话智能体能力,实现全双工、实时打断、可扩展、对话…

智绅科技:以科技为翼,构建养老安全守护网

随着我国老龄化进程加速,2025年60岁以上人口突破3.2亿,养老安全问题成为社会关注的焦点。智绅科技作为智慧养老领域的领军企业,以“科技赋能健康,智慧守护晚年”为核心理念,通过人工智能、物联网、大数据等技术融合&am…

矩阵系统源码部署实操指南:搭建全解析,支持OEM

矩阵系统源码部署指南矩阵系统是一种高效的数据处理框架,适用于大规模分布式计算。以下为详细部署步骤,包含OEM支持方案。环境准备确保服务器满足以下要求:操作系统:Linux(推荐Ubuntu 18.04/CentOS 7)硬件配…

基于python的个人财务记账系统

博主介绍:java高级开发,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的毕业设计程序开发,开发过上千套毕业设计程序,没有什么华丽的语言&#xff0…

从 CODING 停服到极狐 GitLab “接棒”,软件研发工具市场风云再起

CODING DevOps 产品即将停服的消息,如同一颗重磅炸弹,在软件研发工具市场炸开了锅。从今年 9 月开始,CODING 将陆续下线其 DevOps 产品,直至 2028 年 9 月 30 日完全停服。这一变动让众多依赖 CODING 平台的企业和个人开发者陷入了…

#渗透测试#批量漏洞挖掘#HSC Mailinspector 任意文件读取漏洞(CVE-2024-34470)

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…

深入解析C++驱动开发实战:优化高效稳定的驱动应用

深入解析C驱动开发实战:优化高效稳定的驱动应用 在现代计算机系统中,驱动程序(Driver)扮演着至关重要的角色,作为操作系统与硬件设备之间的桥梁,驱动程序负责管理和控制硬件资源,确保系统的稳定…

SNIProxy 轻量级匿名CDN代理架构与实现

🌐 SNIProxy 轻量级匿名CDN代理架构与实现 🏗️ 1. 整体架构设计 🔹 1.1 系统架构概览 #mermaid-svg-S4n74I2nPLGityDB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-S4n74I2nP…

Qt的信号与槽(一)

Qt的信号与槽(一)1.信号和槽的基本认识2.connect3.关闭窗口的按钮4.函数的根源5.形参和实参的类型🌟hello,各位读者大大们你们好呀🌟🌟 🚀🚀系列专栏:【Qt的学习】 &…

springMVC02-视图解析器、RESTful设计风格,静态资源访问配置

一、SpringMVC 的视图在 SpringMVC 中,视图的作用渲染数据,将模型 Model (将控制器(Controller))中的数据展示给用户。在 Java 代码中,视图由接口 org.springframework.web.servlet.View 表示SpringMVC 视图的种类很多…