深度学习训练配置参数详解
1. 启动初始化参数 说明 CUDA_VISIBLE_DEVICES
指定使用的GPU设备编号("0"表示单卡) seed
随机种子(1777777),保证实验可复现性 cuda
是否启用GPU加速(True) benchmark
是否启用cudnn基准测试(False),输入尺寸固定时可设为True加速 deterministic
是否强制确定性算法(True),保证可复现性但可能降低性能
2. 数据预处理参数 说明 resample_spacing
体数据重采样间距([0.5,0.5,0.5]毫米) clip_lower_bound
灰度值截断下限(-1412) clip_upper_bound
灰度值截断上限(17943) samples_train
每张训练图像的采样点数(2048) crop_size
训练裁剪尺寸(160×160×96体素) crop_threshold
有效裁剪的最小前景占比(0.5)
3. 数据增强参数 说明 augmentation_probability
数据增强应用概率(30%) augmentation_method
增强策略("Choice"表示随机选择一种) open_elastic_transform
是否启用弹性形变(True) elastic_transform_sigma
弹性形变强度(20) elastic_transform_alpha
弹性形变缩放系数(1) open_gaussian_noise
是否添加高斯噪声(True) gaussian_noise_mean
噪声均值(0) gaussian_noise_std
噪声标准差(0.01) open_random_flip
是否启用随机翻转(True) open_random_rescale
是否启用随机缩放(True) random_rescale_min_percentage
最小缩放比例(0.5倍) random_rescale_max_percentage
最大缩放比例(1.5倍) open_random_rotate
是否启用随机旋转(True) random_rotate_min_angle
最小旋转角度(-50°) random_rotate_max_angle
最大旋转角度(50°) normalize_mean
数据标准化均值(0.050) normalize_std
数据标准化标准差(0.028)
4. 数据加载参数 说明 dataset_name
数据集名称(“3D-CBCT-Tooth”) dataset_path
数据集存储路径 create_data
是否重新生成预处理数据(False) batch_size
批大小(1) num_workers
数据加载线程数(4)
5. 模型配置参数 说明 model_name
模型名称(“KanNet”) in_channels
输入通道数(1表示灰度图像) classes
分类数量(2类:背景/前景) index_to_class_dict
类别索引映射字典 resume
断点续训模型路径(None表示不启用) pretrain
预训练权重路径(None表示不启用) high_frequency
高频成分权重(0.9) low_frequency
低频成分权重(0.1)
6. 优化器参数 说明 optimizer_name
优化器类型(“AdamW”) learning_rate
初始学习率(0.0005) weight_decay
L2正则化系数(0.00005) momentum
动量参数(0.8)
7. 学习率调度参数 说明 lr_scheduler_name
学习率调度器类型(“ReduceLROnPlateau”) mode
监控指标方向("max"表示越大越好) factor
学习率衰减系数(0.5) patience
等待epoch数(1轮不提升后衰减) milestones
多步学习率调整时机([1,3,5,7,8,9]epoch)
8. 损失函数与评估参数 说明 metric_names
评估指标列表([“DSC”]) loss_function_name
损失函数(“DiceLoss”) class_weight
类别权重(背景0.005,前景0.995) dice_loss_mode
Dice损失变体(“extension”) sigmoid_normalization
是否使用Sigmoid归一化(False)
9. 训练设置参数 说明 optimize_params
是否优化超参数(False) use_amp
是否使用混合精度(False) run_dir
实验日志保存目录 start_epoch
起始epoch(0) end_epoch
终止epoch(20) best_dice
初始最佳Dice分数(0.60) save_epoch_freq
模型保存频率(每4个epoch) crop_stride
预测时的滑动窗口步长([32,32,32])
关键说明:
GPU相关参数:需根据实际硬件调整CUDA_VISIBLE_DEVICES
数据增强:所有open_*参数控制是否启用对应增强方法
类别不平衡:通过class_weight参数显著提高前景权重(牙科结构)
训练控制:deterministic=True保证可复现性,但会禁用benchmark优化
注:实际使用时需根据数据集特性和硬件条件调整参数值。对于医学图像分割任务,建议优先保证deterministic和精细的数据预处理。