- Transformers 模型设计上是可定制的。
- 每个模型的代码都包含在 Transformers 仓库的
model
子文件夹中(transformers/src/transformers/models at main · huggingface/transformers),每个模型文件夹通常包含:modeling.py
:定义模型结构与前向传播configuration.py
:定义模型的超参数配置
1 配置(Configuration)
1.1 自定义配置
- 自定义配置类的要点:
- 必须继承自
PretrainedConfig
,以继承from_pretrained()
、save_pretrained()
等功能; - 构造函数
__init__()
必须接收任意**kwargs
并传给父类; - 添加
model_type
属性,以支持 AutoClass; - 可以加入参数校验逻辑。
- 必须继承自
1.2 保存配置
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
resnet50d_config.save_pretrained("custom-resnet")
2 模型结构
- 模型类需要继承自
PreTrainedModel
,并接受配置对象作为输入 - Transformers 约定模型的所有超参数由配置对象提供
- 可以构建两种模型:
2.1 裸模型(输出隐藏状态)
2.2 带分类头的模型(支持 Trainer
,输出 logits 和 loss)
2.3 加载预训练权重
import timmresnet50d = ResnetModel(resnet50d_config)
#此时 resnet50d.model 就是一个结构为 ResNet-50d 的模型,但权重是 随机初始化的,没有训练。pretrained_model = timm.create_model("resnet50d", pretrained=True)
#从 timm 加载已经训练好的 resnet50d 模型resnet50d.model.load_state_dict(pretrained_model.state_dict())
3 启用 AutoClass 支持
AutoClass API 能自动根据配置加载模型,简化用户调用
需要:
-
在配置类中加入
model_type
; -
在模型类中加入
config_class
; -
使用
AutoConfig.register()
和AutoModel.register()
注册。
from transformers import AutoConfig, AutoModel, AutoModelForImageClassificationAutoConfig.register("resnet", ResnetConfig)
#注册自定义配置类 ResnetConfig。
#"resnet" 是 ResnetConfig.model_type,它必须和配置类中的 model_type = "resnet" 一致。
#注册后,用户可以通过 AutoConfig.from_pretrained() 自动加载这个配置类。AutoModel.register(ResnetConfig, ResnetModel)
#把裸模型类 ResnetModel 绑定到 AutoModel。
'''
这样用户就可以用如下方式加载模型:
model = AutoModel.from_pretrained("your-username/custom-resnet50d", trust_remote_code=True)
'''AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)
#注册了你带分类头的模型 ResnetModelForImageClassification 到 AutoModelForImageClassification。
'''
用户可以像这样加载:
model = AutoModelForImageClassification.from_pretrained("your-username/custom-resnet50d", trust_remote_code=True
)
'''
4 本地保存& 加载特定模型
假设已经定义和注册配置和模型,并加载了预训练权重
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
#加载自定义configresnet50d = ResnetModelForImageClassification(resnet50d_config)
#加载自定义model# 加载预训练权重
import timm
pretrained = timm.create_model("resnet50d", pretrained=True)
resnet50d.model.load_state_dict(pretrained.state_dict())
注册 AutoClass 支持,保存 AutoClass 映射信息
resnet50d_config.register_for_auto_class()
resnet50d.register_for_auto_class("AutoModelForImageClassification")
保存模型和配置到本地
resnet50d.save_pretrained("custom-resnet50d/")
resnet50d_config.save_pretrained("custom-resnet50d/")
4.1 本地重新加载
from transformers import AutoModelForImageClassification# 加载模型
model = AutoModelForImageClassification.from_pretrained("custom-resnet50d/", trust_remote_code=True
)
由于使用的是自定义模型类,加载时一定要加上trust_remote_code=True
4.2 保存后的本地目录
4.3 为什么要保存config?
config
是必须保存的,因为AutoModel
是依赖config.json
来决定加载哪个模型类。- AutoModel.from_pretrained("path_or_repo")背后的机制是
- 先加载配置文件
config.json
config = AutoConfig.from_pretrained("path_or_repo")
- 根据
config.model_type
决定使用哪个模型类"model_type": "resnet"
→ 查找注册的ResnetModel
-
再加载权重文件(.bin 或 .safetensors)到模型中
- 先加载配置文件