之前已经完整的拆解了CLIP中所用到的ResNet、ViT和Transformer三个模型(CLIP拆解-CSDN博客),这篇将讲解model.py实现中的其他细节。

1.关于ResNet模型中vision_head的设置

ResNet:

vision_heads = vision_width * 32 // 64

ViT:

vision_heads = vision_width // 64

ResNet需要乘32是因为经过前面卷积处理后输入AttentionPool2d的是width*32,所以计算head的时候要把这个考虑进去。至于这里的64是分为多头后每一个头的embed的通道数,ResNet通常取64,ViT-B常取768

2.关于conver_weights

convert_weights() 是为了节省显存、提高推理速度,将模型中适合的权重转换为 fp16。

(1)half()的作用 就是把fp32转为fp16,如果输入本身是 fp16,那将不进行任何处理。

(2)一些结构不建议转化为fp16,因为转化后会不稳定,所以选择性的处理

    def _convert_weights_to_fp16(l):if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):l.weight.data = l.weight.data.half()if l.bias is not None:l.bias.data = l.bias.data.half()if isinstance(l, nn.MultiheadAttention):for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:tensor = getattr(l, attr)if tensor is not None:tensor.data = tensor.data.half()for name in ["text_projection", "proj"]:if hasattr(l, name):attr = getattr(l, name)if attr is not None:attr.data = attr.data.half()

下面是常见的不建议使用fp16的模块:

模块/操作原因说明
LayerNorm / BatchNorm均值/方差运算容易数值下溢,精度敏感
Softmax / LogSoftmax输出接近 0 或 1,fp16 下舍入误差大
Sigmoid / Tanh对小输入不敏感,精度损失后容易失效
CrossEntropyLoss包含 log(softmax),fp16 精度不足导致数值不稳定
Attention(部分实现)scaled dot-product 会导致爆炸,尤其是大输入或长序列时
Exp, Div, Log本身不稳定,数值小容易下溢出为 0

3.模型输入也要相应的进行转化,否则会遇到类型不匹配的问题

 解决方法1:使用autocast

from torch.cuda.amp import autocastwith autocast():output = model(x)  # 自动在每一层内部管理精度转换

但autocast只针对模块的外部类型来判断是否进行类型转化(如nn.Linear, nn.Conv2d),但是自定义的模块(类)autocast不会进行类型转换(autocast只是解决了类型不匹配的问题,但是低精度产生的梯度爆炸等问题无法解决,由反向传播时gradscaler解决)

问题场景AMP 是否能处理说明
输入是 fp16,模块需要 fp32autocast() 会自动转换
自定义模块内部 +,/ 导致类型错❌ 你要自己管理,AMP 不管你自写的算子
梯度为 0 或爆炸GradScaler() 自动放大/还原
权重混用不同精度✅ 支持
推理时类型优化(加速,混用不同精度)✅ 只用 autocast() 即可

解决方法2:手动转化类型

# 例如 LayerNorm 中人为转 float32:
def forward(self, x):orig_type = x.dtyperet = super().forward(x.float())  # 保证 LayerNorm 在 float32 下执行return ret.to(orig_type)

4.关于forward的输出

# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

logit_scale是缩放因子,定义是self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

logits_per_image是图像视角下的相似度分布,用于计算图像到文本的对比损失

logits_per_text是文本视角下的相似度分布,和图像视角下对称。

5.关于权重初始化

(1)ResNet的bn3初始化为0

for resnet_block in [self.visual.layer1,self.visual.layer2,self.visual.layer3,self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)

手动初始化bn3.weight为0确保为恒等映射,从而防止残差支路输出不稳定、扰动太大的问题。

(2)CLIP中的手动初始化和自动初始化

CLIP只手动初始化了一些对训练稳定性或性能影响较大的模块,如embedding和位置编码(nanoGPT中也对这两个部分进行了手动初始化)、QKVC投影、transformer最后输出的初始化

    def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)if isinstance(self.visual, ModifiedResNet):if self.visual.attnpool is not None:std = self.visual.attnpool.c_proj.in_features ** -0.5nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)if self.text_projection is not None:nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

***与nanoGPT的_init_weights对比

    # mainself.apply(self._init_weights)# apply special scaled init to the residual projections, per GPT-2 paperfor pn, p in self.named_parameters():if pn.endswith('c_proj.weight'):torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))#init_weightdef _init_weights(self, module):if isinstance(module, nn.Linear):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)if module.bias is not None:torch.nn.init.zeros_(module.bias)elif isinstance(module, nn.Embedding):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

*GPT
GPT 结构对初始化非常敏感,GPT 使用残差连接 + LayerNorm,梯度传播对初始权重分布非常依赖。所以在初始化的时候Linear和Embedding的weight的mean都初始化为0

*CLIP

CLIP更复杂,只初始化关键敏感部件,如embedding、positional encoding、attention等。

***目前总结到的经验

*建议手动初始化:

模块类型初始化建议原因
Embedding手动正态初始化(如 std=0.01~0.02)防止稀疏索引导致偏置
Q/K/V Linear手动初始化(如 std=1/√d_k防止 attention dot-product 初始值爆炸
Positional Embedding正态初始化因为是 learnable 参数,数值不宜过大
残差 block 最后一层(如 BN3)初始化为 0初始退化为恒等映射,提高收敛性
任何“关键分支”的 projection 层建议初始化如 CLIP 的 text_projection, image_projection

 一般不主动初始化:

模块类型

理由
Conv2d, Linear默认初始化已很好,除非有论文要求
LayerNorm, BatchNorm默认 weight=1, bias=0 是最优策略
非残差中的普通线性层默认即可

(3)初始化时std的设置

① attn_std = self.transformer.width ** -0.5

标准的transformer初始化方法

②fc_std = (2 * self.transformer.width) ** -0.5

用于初始化FFN中的前向Linear层,第一层输出通道很大(通常是 4×),为了避免输出激活过大,std 要适当减小

x → Linear(width, 4*width) → GELU → Linear(4*width, width)

③proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)

用于 Residual AttentionBlock 最后投影的 Linear 层

来源:来自论文 Understanding the Difficulty of Training Transformers,特别适用于 深层 Transformer(如 GPT-3, CLIP)

核心思想是:

如果模型深度是 L 层,那每个 residual branch 叠加的方差也会增加,应该将其 std 缩小为 1/sqrt(2L)以稳定整体输出。

 6.关于build_model的参数的使用

(1)

vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]

这里使用visual.conv1.weight的第一个维度的大小作为width,conv2d的weight的形状是(out_channle, in_channel, patch_size[0], patch_size[1])。

另外这里补充一下ViT patch和传统CNN卷积核的区别:
传统CNN是使用多个小卷积堆叠构建大感受野(kernel_size较小,stride小于kernel_size允许重叠),而ViT是使用一个大kernel,把整块patch当作token(kernel_size较大,stride=kernel_size,即不重复采样)

(2)

vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])

每个 Transformer block 里会有一个 nn.MultiheadAttention 模块,对应权重名如:visual.transformer.resblocks.0.attn.in_proj_weight

(3)

grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size

这里image_resolution是因为

self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))

(4)几个易混淆的概念

名字意义举例值类似于
vision_width通道维度64、128、256 等CNN 中的输出 channels
output_width特征图尺寸7、14 等feature map 的宽度
patch_sizepatch 的边长32ViT 中的切片大小‘

(5)ResNet中image_resolution = output_width * 32

*32是因为在ResNet中总共下采样了5次

模块操作类型输出尺寸
conv1stride=2变成 H/2 × W/2
stem_poolAvgPool2d(2)变成 H/4 × W/4
layer1无下采样尺寸不变
layer2stride=2变成 H/8 × W/8
layer3stride=2变成 H/16 × W/16
layer4stride=2变成 H/32 × W/32 ✅ 最终输出
attnpool空间尺寸 = H/32 × W/32

 (6)删除state_dict中的一些辅助信息字段

    for key in ["input_resolution", "context_length", "vocab_size"]:if key in state_dict:del state_dict[key]

这些不是模型参数的一部分,加载模型权重前必须删掉,否则会引起state_dict键不匹配

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89630.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89630.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/89630.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国科大深度学习作业1-手写数字识别实验

背景介绍:单位实习,趁机摸鱼,由于电脑只安装了VSCode,所以算是从环境搭建写起。 目录 一、环境搭建 1. 安装Anaconda 2. 创建Python环境 3. 安装PyTorch 4. 安装其他必要库 二、在 VSCode 中配置环境 1. 安装Pytho…

基于Spring Boot的绿园社区团购系统的设计与实现

第1章 摘 要 本设计与实现的基于Spring Boot的绿园社区团购系统,旨在为社区居民提供一套高效、便捷的团购购物解决方案。随着电子商务的发展和社区居民对便捷购物需求的增加,传统的团购模式已无法满足用户的个性化需求。本系统通过整合现代化技术&…

【51单片机四位数码管从0循环显示到99,每0.5秒增加一个数字,打击键计数】2022-6-11

缘由 #include "REG52.h" unsigned char code smgduan[]{0x3f,0x06,0x5b,0x4f,0x66,0x6d,0x7d,0x07,0x7f,0x6f,0x77,0x7c,0x39,0x5e,0x79,0x71,0,64,15,56}; //共阴0~F消隐减号 unsigned char Js0, miao0;//中断计时 秒 分 时 毫秒 unsigned int shu0; //bit Mb0;//…

如何通过python脚本向redis和mongoDB传点位数据

向MongoDB传数据 from pymongo import MongoClient #导入库对应的库localhost "172.16.0.203" #数据库IP地址 baseName "GreenNagoya" client MongoClient(localhost, 27017, username"admin", password"zdiai123") #数…

昆仑通泰触摸屏Modbus TCP服务器工程 || TCP客户端工程

目录 一、Modbus TCP服务端 1.设备地址 2.实操及数据 二、Modbus TCP客户端 1.结果及协议解析 一、Modbus TCP服务端 1.设备地址 --单元标识符 DI输入/4个离散输入 DO输出/单个线圈输出 输入寄存器 读输入寄存器操作,写输入寄存器操作 保持寄存器 …

PyTorch 安装使用教程

一、PyTorch 简介 PyTorch 是由 Facebook AI Research 团队开发的开源深度学习框架。它以动态图机制、灵活性强、易于调试而著称,广泛应用于自然语言处理、计算机视觉和学术研究。 二、安装 PyTorch 2.1 通过官网选择安装命令(推荐) 访问官…

开源功能开关(feature flags) 和管理平台之unleash

文章目录 背景Flagsmith 和 Unleash什么是unleash架构Unleash Edge 安装和使用Unleash SDKs开放API Tokens访问**Server-side SDK (CLIENT)****查询所有 Feature Toggles****查询特定 Toggle** API token typesClient tokensFrontend tokensPersonal access tokensService acco…

细胞建模“图灵测试”:解析学习虚拟细胞挑战赛

一、AI能否预测细胞的未来? 想象一下,有一天我们不必一管管地做实验,就能在计算机中模拟细胞对基因敲除、药物处理乃至微环境变化的反应。这不再是科幻,而是“虚拟细胞”(Virtual Cell)研究的宏大目标。然…

centos9安装docker Dify

CentOS | Docker Docs yum -y install gcc gcc-c yum-utils Docker 官方的 YUM 软件仓库配置文件到系统,设置存储库 yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 也可以从阿里云下(我选择上面的) yum-config-manager --add-re…

基于Jenkins和Kubernetes构建DevOps自动化运维管理平台

目录 引言 基础概念 DevOps概述 Jenkins简介 Kubernetes简介 Jenkins与Kubernetes的关系 Jenkins与Kubernetes的集成 集成架构 安装和配置 安装Jenkins 安装Kubernetes插件 配置Kubernetes连接 配置Jenkins Agent Jenkins Pipeline与Kubernetes集成 Pipeline定义…

MySQL 8.0 OCP 1Z0-908 题目解析(18)

题目69 Choose three. A MySQL server is monitored using MySQL Enterprise Monitor’s agentless installation. Which three features are available with this installation method? □ A) MySQL Replication monitoring □ B) security-related advisor warnings □ …

【mongodb】安装和使用mongod

文章目录 前言一、如何安装?二、使用步骤1. 开启mongod服务2. 客户端连接数据库3. 数据库指令 总结 前言 Mongodb的安装可以直接安装系统默认的版本,也可以安装官网维护的版本,相对而言更推荐安装官网维护的版本,版本也相当更新。…

云效DevOps vs Gitee vs 自建GitLab的技术选型

针对「云效DevOps vs Gitee vs 自建GitLab」的技术选型,我们从核心需求、成本、运维、扩展性四个维度进行深度对比,并给出场景化决策建议: 一、核心能力对比表 能力维度云效DevOpsGitee自建GitLab(社区版/企业版)代码…

CentOS 7 安装RabbitMQ详细教程

前言:在分布式系统架构中,消息队列作为数据流转的 “高速公路”,是微服务架构不可或缺的核心组件。RabbitMQ 凭借其稳定的性能、灵活的路由机制和强大的生态支持,成为企业级消息中间件的首选之一。不过,当我们聚焦 Cen…

Python爬虫用途和介绍

目录 什么是Python爬虫 Python爬虫用途 Python爬虫可以获得那些数据 Python爬虫的用途 反爬是什么 常见的反爬措施 Python爬虫技术模块总结 获取网站的原始响应数据 获取到响应数据对响应数据进行过滤 对收集好的数据进行存储 抵御反爬机制 Python爬虫框架 Python…

uni-app开发app保持登录状态

在 uni-app 中实现用户登录一次后在 token 过期前一直免登录的功能,可以通过以下几个关键步骤实现:本地持久化存储 Token、使用请求与响应拦截器自动处理 Token 刷新、以及在 App.vue 中结合 pages.json 设置登录状态跳转逻辑。 ✅ 一、pages.json 配置说…

21、MQ常见问题梳理

目录 ⼀ 、MQ如何保证消息不丢失 1 、哪些环节可能会丢消息 2 、⽣产者发送消息如何保证不丢失 2.1、⽣产者发送消息确认机制 2.2、Rocket MQ的事务消息机制 2.3 、Broker写⼊数据如何保证不丢失 2.3.1** ⾸先需要理解操作系统是如何把消息写⼊到磁盘的**。 2.3.2然后来…

MySQL数据库--SQL DDL语句

SQL--DDL语句 1,DDL-数据库操作2,DDL-表操作-查询3,DDL-表操作-创建4,DDL-表操作-数据类型4.1,DDL-表操作-数值类型4.2,DDL-表操作-字符串类型4.3,DDL-表操作-日期时间类型4.4,实例 …

Spring Cloud 服务追踪实战:使用 Zipkin 构建分布式链路追踪

Spring Cloud 服务追踪实战:使用 Zipkin 构建分布式链路追踪 在分布式微服务架构中,一个用户请求往往需要经过多个服务协作完成,如果出现性能瓶颈或异常,排查会非常困难。此时,分布式链路追踪(Distributed…

Linux云计算基础篇(6)

一、IO重定向和管道 stdin:standard input 标准输入 stdout:standard output 标准输出 stderr: standard error 标准错误输出 举例 find /etc/ -name passwd > find.out 将正确的输出重定向在这个find.ou…