SPOT(Sequential Predictive Modeling of Clinical Trial Outcome with Meta-Learning)模型是用于临床试验结果预测的模型,
借鉴了模型无关元学习(MAML,Model-Agnostic Meta-Learning)的框架,将模型参数分为全局共享参数和任务特定参数,以平衡跨任务泛化与任务内适配:
一、任务定义:将每个试验主题序列视为独立任务
SPOT通过主题发现模块(Topic Discovery)将临床试验数据聚类为多个主题(topic),每个主题包含具有相似特征(如疾病类型、治疗方案、试验设计)的试验。由于同一主题的试验在时间上具有连续性(按时间戳排序),SPOT将每个主题的时序试验序列定义为一个“任务”。
- 动机:临床试验数据存在严重的不平衡性(如某些疾病或治疗方案的试验数量少,属于“小众任务”)。元学习的核心优势是“学习如何学习”,能在少量数据上快速适应新任务,因此适合处理这类不平衡场景。
- 具体操作:每个主题的时序序列(如某类肿瘤药物的I期试验按年份排列的序列)被视为一个独立任务,模型需要为每个任务学习特定的预测模式。
二、参数设计:全局参数与任务特定参数分离
SPOT借鉴了模型无关元学习(MAML,Model-Agnostic Meta-Learning)的框架,将模型参数分为全局共享参数和任务特定参数,以平衡跨任务泛化与任务内适配:
-
全局参数(θ₁和θ₂):
- θ₁:来自静态试验嵌入模块(如疾病编码器GRAM、治疗方案编码器MPNN、入排标准编码器Trial2Vec),负责提取所有试验的通用特征(如疾病本体、分子结构的共性),在所有任务中共享。
- θ₂:对应序列建模模块(RNN和序列预测网络)的基础参数,用于捕捉时序模式的通用规律(如试验设计随时间演进的共性趋势)。
-
任务特定参数(θ₂ᵏ):
- 针对每个主题任务k,θ₂ᵏ是θ₂的微调版本,通过局部更新适配该主题的独特时序模式(如某类罕见病试验的成功率波动规律)。
- 设计目的:让模型在保留全局共性的同时,为每个任务定制参数,避免“多数类任务”主导模型学习,提升对“小众任务”的预测能力。
其算法流程可以分为以下几个主要步骤:
1. 模型初始化
在初始化阶段,会对模型的各种参数进行设置,并创建模型对象和主题发现器。以下是初始化部分的代码:
class SPOT(TrialOutcomeBase):def __init__(self,num_topics=50,n_trial_projector=2,n_timestemp_projector=2,n_rnn_layer=1,criteria_column='criteria',batch_size=1,n_trial_per_batch=None,learning_rate=1e-4,weight_decay=1e-4,epochs=10,evaluation_steps=50,warmup_ratio=0,device="cuda:0",seed=42,output_dir="./checkpoints/spot",):self.config = {'num_topics': num_topics,'n_trial_projector': n_trial_projector,'n_timestemp_projector': n_timestemp_projector,'n_rnn_layer': n_rnn_layer,'criteria_column': criteria_column,'batch_size': batch_size,'n_trial_per_batch':n_trial_per_batch,'learning_rate': learning_rate,'epochs': epochs,'weight_decay': weight_decay,'evaluation_steps':eva