lstm 我有 20*6 条数据,20个样本,每个样本6条历史数据,每条数据有5个值,我送给网络输入时应该是20*6*5 还是 6*20*5
你的数据是:
20 个样本(batch size = 20)
每个样本有 6 条历史数据(sequence length = 6)
每条数据有 5 个值(input size = 5)
✅ 正确的输入形状是:
(20, 6, 5) # 即 batch_size=20, seq_len=6, input_size=5
前提是你在定义 LSTM 时设置了:
nn.LSTM(input_size=5, hidden_size=..., batch_first=True)