机器学习从入门到精通 - 强化学习初探:从 Q-Learning 到 Deep Q-Network 实战
一、开场白:推开强化学习这扇门
不知道你有没有过这种感觉 —— 盯着一个复杂的系统,既想让它达到某个目标,又苦于无法用传统规则去精确描述每一步该怎么做?比如训练一个机器人走出迷宫,或者让算法学会玩《超级马里奥》。这就是强化学习(Reinforcement Learning, RL)大展拳脚的地方了!它不要求你预先告知所有正确答案,而是让一个"智能体"(Agent)在环境中不断试错、根据反馈调整策略,最终学会达成目标。今天这篇长文,咱们就手把手从最经典的 Q-Learning 开始,一路打通关,最后用 PyTorch 实现一个解决经典控制问题的 Deep Q-Network (DQN)! 我保证,过程中你会掉进不少坑,也会看到我是怎么狼狈爬出来的 —— 这才是真实的学习过程嘛。
二、强化学习基本框架:环境、状态、动作与奖励
先说个容易踩的坑:很多人一上来就扎进算法公式里,结果连 Agent 和 Environment 怎么交互都搞不清,后面就全乱了套。必须得先理解这个核心交互循环:
- 环境(Environment):智能体存在的世界(比如一个迷宫、一个游戏画面、一个股票市场)。
- 状态(State):环境在时刻
t
的完整描述(比如迷宫里的坐标、游戏画面像素、股票价格+指标)。 - 动作(Action):智能体在状态
s_t
下能做出的选择(比如向上/下/左/右移动、买入/卖出/持有)。 - 奖励(Reward):环境在智能体执行动作
a_t
后,进入新状态s_{t+1}
时给出的即时评价信号(比如撞墙扣分,到达终点加分)。记住,智能体的终极目标就是最大化长期累积奖励!
关键概念:马尔可夫决策过程(MDP)
绝大多数强化学习问题都建模成 MDP。它要求:下一个状态 s_{t+1}
和当前奖励 r_t
只取决于当前状态 s_t
和当前动作 a_t
,与之前的历史无关。 用数学表示就是:
P(s_{t+1}, r_t | s_t, a_t, s_{t-1}, a_{t-1}, ..., s_0, a_0) = P(s_{t+1}, r_t | s_t, a_t)
这个假设是很多强化学习算法(包括 Q-Learning)的理论基石。
三、Q-Learning:价值函数的艺术
好了,现在我们请出今天的第一位主角:Q-Learning。它是一种 无模型(Model-Free)、基于价值(Value-Based) 的强化学习算法。核心思想是学习一个叫 Q-Table
的东西。
什么是 Q 值?
Q(s, a)
表示在状态 s
下选择动作 a
,并且之后一直采取最优策略所能获得的期望累积奖励。简单说,它衡量了在 s
选 a
有多“好”。
目标:找到最优 Q 函数 Q^*(s, a)
如果我能知道所有状态 s
下所有动作 a
的 Q^*(s, a)
,那么最优策略 π^*(s)
就简单了:永远选择当前状态下 Q
值最大的那个动作!π^*(s) = argmax_a Q^*(s, a)
Q-Learning 的更新魔法:时间差分(TD)
问题是,我们一开始不知道 Q^*
。Q-Learning 的核心在于通过不断尝试和更新来逼近 Q^*
。它的更新公式是重中之重(推导来了!):
Q(s_t, a_t) <-- Q(s_t, a_t) + α * [ TD_Target - Q(s_t, a_t) ]
其中:
α
(Alpha):学习率(Learning Rate),控制新信息覆盖旧信息的程度(0 < α ≤ 1)。这个值吧 —— 选大了震荡,选小了学得慢,后面实战会踩坑。TD_Target
:时间差分目标(Temporal Difference Target),代表我们对Q(s_t, a_t)
的最新估计。它是怎么来的呢?
-
贝尔曼方程(Bellman Equation) 是理解的基础。对于最优 Q 函数,它满足:
Q^*(s, a) = E [ r + γ * max_{a'} Q^*(s', a') | s, a ]
s'
:执行a
后转移到的下一个状态。r
:执行a
后得到的即时奖励。γ
(Gamma):折扣因子(0 ≤ γ < 1),表示我们有多重视未来奖励(γ 越接近 1,越重视长远收益)。max_{a'} Q^*(s', a')
:在下一个状态s'
下,采取最优动作a'
所能获得的最大 Q 值(代表了s'
状态的价值V^*(s')
)。E[...]
:期望值(因为状态转移可能有随机性)。
这个方程说明:当前状态-动作对的价值等于即时奖励加上折扣后的下一个状态的最优价值。 它揭示了 Q 值之间的递归关系。
-
从贝尔曼最优方程到 Q-Learning 更新: Q-Learning 用当前估计值去逼近贝尔曼方程定义的理想值。在
s_t
执行a_t
,我们观察到即时奖励r_{t+1}
和新状态s_{t+1}
。这时候,我们会对Q(s_t, a_t)
应该等于什么有一个新的“目标”:
TD_Target = r_{t+1} + γ * max_{a} Q(s_{t+1}, a)
注意这里用的是我们当前的 Q 表来估计s_{t+1}
状态的价值 (max_{a} Q(s_{t+1}, a)
),而不是Q^*
。 -
更新量:
TD_Target - Q(s_t, a_t)
就是当前估计值和新的目标值之间的差异,称为 TD 误差(Temporal Difference Error)。Q-Learning 做的就是:用这个误差乘以学习率 α,去调整当前的Q(s_t, a_t)
,让它更接近TD_Target
。
最终合并得到的 Q-Learning 更新公式:
# Q(s_t, a_t) 更新公式
Q[s_t, a_t] = Q[s_t, a_t] + α * ( r_{t+1} + γ * np.max(Q[s_{t+1}, :]) - Q[s_t, a_t] )
符号释义:
s_t
:当前时刻t
的状态。a_t
:在s_t
状态下选择的动作。r_{t+1}
:执行a_t
后得到的即时奖励。s_{t+1}
:执行a_t
后转移到的下一个状态。Q[s_{t+1}, :]
:在 Q 表中,状态s_{t+1}
对应的所有Q
值。np.max(Q[s_{t+1}, :])
:下一个状态s_{t+1}
下,所有可能动作的最大Q
值估计。α
:学习率。γ
:折扣因子。
Q-Learning 算法的伪代码流程:
关键点:探索与利用(ε-greedy)
- 如果每次都选当前 Q 表认为最好的动作(
argmax_a Q(s, a)
),可能永远发现不了真正更好的动作。 - ε-greedy 策略: 以
1 - ε
的概率选择当前 Q 值最大的动作(利用),以ε
的概率随机选择一个动作(探索)。ε
通常随着训练衰减(从 1.0 开始,逐渐减小到 0.01 或 0.1)。ε 衰减策略没设计好,模型可能学偏或者卡住,这是个大坑点。
四、Q-Table 的局限与 Deep Q-Network 的崛起
Q-Learning 在小规模、离散状态和动作空间下表现很好。但是,现实世界往往是连续的,状态维度极高(比如一张游戏图像有几十万个像素点)。Q-Table 的致命伤来了:它无法处理高维或连续状态空间!
- 存储问题: 状态太多(甚至是无限的),Q-Table 根本存不下。想象一下用表格存储每个可能的像素组合的 Q 值 —— 天文数字!
- 泛化问题: 即使能存储,遇到没见过的状态,Q-Table 无法给出合理的 Q 值估计。它没有泛化能力。
解决方案:用函数逼近器代替 Q-Table!
这里 —— 深度神经网络(DNN)闪亮登场。它强大的函数拟合能力,让它成为学习 Q(s, a; θ)
函数的绝佳选择,其中 θ
是神经网络的参数。这就是 Deep Q-Network (DQN)。
DQN 的核心技术(两大支柱):
-
经验回放(Experience Replay):
- 问题: 连续采集的经验
(s_t, a_t, r_{t+1}, s_{t+1}, done)
是强相关的。直接用它们训练网络会导致参数更新不稳定、振荡甚至发散。 - 解决: 建立一个固定大小的经验池(Replay Buffer)。每次与环境交互得到的经验元组先存入池中。训练时,随机抽取一小批(Mini-batch)经验进行学习。
- 好处:
- 打破样本间的时间相关性,使训练更稳定。
- 提高数据利用率,单个样本可被多次学习。
- 离线学习(Off-policy):可以重复利用过去的经验。经验池大小和采样方式的选择很关键,太小容易过时,太大训练慢,后面会踩坑。
- 问题: 连续采集的经验
-
目标网络(Target Network):
- 问题: 在计算
TD_Target = r + γ * max_a Q(s', a; θ)
时,θ
是我们正在更新的网络参数。更新θ
会导致TD_Target
本身也在快速变化(像个移动的目标),加剧训练的不稳定性。 - 解决: 引入一个结构相同但参数不同的目标网络
Q(s, a; θ⁻)
。这个网络的参数θ⁻
并不是每一步都更新,而是定期(比如每 N 步)从当前训练网络Q(s, a; θ)
复制参数(θ⁻ ← θ
)。计算TD_Target
时使用目标网络,并且需要考虑终止状态:
TD_Target = r
(如果s'
是终止状态)
TD_Target = r + γ * max_a Q(s', a; θ⁻)
(如果s'
不是终止状态) - 好处:
TD_Target
在一段时间内相对稳定,为训练网络Q(s, a; θ)
提供了一个更可靠的更新目标。更新频率N
是个超参数,需要调。
- 问题: 在计算
DQN 算法流程:
网络架构设计(以 CartPole 为例):
- 输入: 状态
s
(CartPole 中是 4 维向量[cart_position, cart_velocity, pole_angle, pole_angular_velocity]
)。 - 输出: 每个可能动作
a
的 Q 值估计(CartPole 中是 2 维向量[Q(s, left), Q(s, right)]
)。 - 隐藏层: 通常使用全连接层(FC)。对于简单问题如 CartPole,1-2 个隐藏层(如 128 或 256 个神经元)足够。激活函数通常用 ReLU。
import torch
import torch.nn as nn
import torch.optim as optimclass DQN(nn.Module):def __init__(self, state_dim, action_dim, hidden_dim=128):super(DQN, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, hidden_dim) # 可选的第二层self.fc3 = nn.Linear(hidden_dim, action_dim)self.relu = nn.ReLU()def forward(self, x):x = self.relu(self.fc1(x))x = self.relu(self.fc2(x)) # 如果只有一层隐藏层则去掉这行x = self.fc3(x)return x
五、实战:用 PyTorch 实现 DQN 解决 CartPole 问题
为什么选 CartPole? 它是 OpenAI Gym 提供的经典控制环境,状态简单(4维),动作离散(2个)。DQN 能很好地解决它,非常适合入门演示。目标: 控制小车左右移动,让上面的杆子尽可能长时间保持竖直不倒。成功标准通常是连续保持平衡 195 步(或平均奖励达到 195)。
1. 安装环境 & 导入库
!pip install gymnasium[classic_control]==0.29.1 numpy==1.26.4 torch==2.2.2 matplotlib # 强烈建议指定版本避免兼容性问题import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque # 用于实现经验回放池
import matplotlib.pyplot as plt
2. 核心组件实现
经验回放池 (Replay Buffer)
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity) # 固定大小的双端队列def add(self, state, action, reward, next_state, done):"""存储一条经验 (s, a, r, s', done)"""self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):"""随机采样一批经验"""experiences = random.sample(self.buffer, batch_size)# 拆分元组为独立的 NumPy 数组states, actions, rewards, next_states, dones = zip(*experiences)# 转换为 PyTorch Tensor (注意! 后面踩坑点)return (torch.tensor(np.array(states), dtype=torch.float32),torch.tensor(actions, dtype=torch.long).unsqueeze(1), # 增加批次维度torch.tensor(rewards, dtype=torch.float32).unsqueeze(1),torch.tensor(np.array(next_states), dtype=torch.float32),torch.tensor(dones, dtype=torch.float32).unsqueeze(1))def __len__(self):return len(self.buffer)
踩坑记录1:数据类型转换
- 环境返回的
state
,next_state
是np.ndarray
。 action
,reward
,done
是标量或布尔值。- 必须小心地转换为正确数据类型的
torch.Tensor
,并确保维度一致(特别是actions
,rewards
,dones
通常需要增加一个维度表示 batch)。 不注意这个会在计算损失函数时报各种维度不匹配的错误。
DQN Agent
class DQNAgent:def __init__(self, state_dim, action_dim, hidden_dim, lr, gamma, epsilon, target_update_freq, buffer_capacity):self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {self.device}")self.action_dim = action_dim# Networksself.policy_net = DQN(state_dim, action_dim, hidden_dim).to(self.device)self.target_net = DQN(state_dim, action_dim, hidden_dim).to(self.device)self.target_net.load_state_dict(self.policy_net.state_dict()) # 同步初始权重self.target_net.eval() # 目标网络不进行梯度计算self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)self.memory = ReplayBuffer(buffer_capacity)# Hyperparametersself.gamma = gammaself.epsilon = epsilonself.target_update_freq = target_update_freqself.learn_step_counter = 0def select_action(self, state):"""根据 ε-greedy 策略选择动作"""if random.random() < self.epsilon:return random.randrange(self.action_dim) # 探索:随机选择动作else:# 利用:选择Q值最高的动作with torch.no_grad():state_tensor = torch.tensor(np.array([state]), dtype=torch.float32).to(self.device)q_values = self.policy_net(state_tensor)return q_values.argmax().item()def learn(self, batch_size):"""从经验池中采样学习"""if len(self.memory) < batch_size:return # 经验池不够,先不学习states, actions, rewards, next_states, dones = self.memory.sample(batch_size)states = states.to(self.device)actions = actions.to(self.device)rewards = rewards.to(self.device)next_states = next_states.to(self.device)dones = dones.to(self.device)# 1. 计算当前状态的 Q 值: Q(s_t, a_t)# self.policy_net(states) 输出所有动作的Q值# .gather(1, actions) 提取出实际采取动作 a_t 对应的 Q 值current_q_values = self.policy_net(states).gather(1, actions)# 2. 计算 TD Targetwith torch.no_grad():# 用目标网络计算下一个状态的最大 Q 值next_q_values = self.target_net(next_states).max(1)[0].unsqueeze(1)# 如果 done=True (值为1), 那么未来的奖励为0td_target = rewards + self.gamma * next_q_values * (1 - dones)# 3. 计算损失loss = F.mse_loss(current_q_values, td_target)# 4. 优化模型self.optimizer.zero_grad()loss.backward()self.optimizer.step()self.learn_step_counter += 1# 5. 定期更新目标网络if self.learn_step_counter % self.target_update_freq == 0:self.target_net.load_state_dict(self.policy_net.state_dict())
3. 设置超参数与训练循环
踩坑记录2:超参数调优是门玄学!
DQN 的超参数非常敏感。LEARNING_RATE
太大,训练会不稳定;GAMMA
太小,智能体会变得短视;EPSILON
衰减太快,探索不足;TARGET_UPDATE_FREQ
太频繁或太稀疏都不好。下面的参数是针对 CartPole 调优过的一组,但并不唯一。
# --- Hyperparameters ---
EPISODES = 400
BUFFER_CAPACITY = 10000
BATCH_SIZE = 64
LEARNING_RATE = 0.001
GAMMA = 0.99# Epsilon-greedy 策略参数
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = (EPSILON_START - EPSILON_END) / (EPISODES * 0.6) # 线性衰减TARGET_UPDATE_FREQ = 100 # 每100步学习更新一次目标网络
HIDDEN_DIM = 128# --- Setup ---
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = DQNAgent(state_dim, action_dim, HIDDEN_DIM, LEARNING_RATE, GAMMA, EPSILON_START, TARGET_UPDATE_FREQ, BUFFER_CAPACITY)episode_rewards = []# --- Training Loop ---
for i_episode in range(EPISODES):state, info = env.reset()done = Falsetotal_reward = 0while not done:action = agent.select_action(state)next_state, reward, terminated, truncated, info = env.step(action)done = terminated or truncatedagent.memory.add(state, action, reward, next_state, done)state = next_statetotal_reward += rewardagent.learn(BATCH_SIZE)# 更新 Epsilonif agent.epsilon > EPSILON_END:agent.epsilon -= EPSILON_DECAYepisode_rewards.append(total_reward)if (i_episode + 1) % 20 == 0:print(f"Episode {i_episode+1}/{EPISODES}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")print("Training finished.")
env.close()
4. 结果可视化
def plot_rewards(rewards):plt.figure(figsize=(12, 6))plt.plot(rewards, label='Reward per Episode')# 计算并绘制100个episode的移动平均线,以更好地观察趋势moving_avg = np.convolve(rewards, np.ones(100)/100, mode='valid')plt.plot(np.arange(len(rewards) - 99), moving_avg, label='100-Episode Moving Average')plt.title('CartPole DQN Training Performance')plt.xlabel('Episode')plt.ylabel('Total Reward')plt.grid(True)plt.legend()plt.show()plot_rewards(episode_rewards)
运行代码后,你大概率会看到一张奖励曲线图。一开始奖励很低(智能体在随机乱撞),但随着训练的进行,曲线会逐渐攀升并最终稳定在高位(比如 200 以上,甚至达到 CartPole-v1 的上限 500),移动平均线能更清晰地展示这个趋势。这就是 DQN 学会了如何平衡杆子的证明!
六、总结与展望:DQN 之后的路
今天我们从强化学习最基础的交互框架出发,深入了 Q-Learning 的核心更新机制,然后为了克服 Q-Table 的局限性,引入了 DQN 的两大支柱——经验回放和目标网络。最后,我们用 PyTorch 从零到一地实现了一个能解决 CartPole 问题的 DQN 智能体。
回顾我们踩过的坑:
- 数据类型与维度:在
ReplayBuffer
和learn
函数中,NumPy 和 PyTorch Tensor 之间的转换、维度的匹配是 bug 高发区。 - 超参数敏感:DQN 的表现严重依赖于超参数的选择,需要耐心调优。没有一组“万能”参数。
- TD Target 的终止状态:计算 TD Target 时忘记处理
done
信号,会导致智能体错误地评估终局的价值,这是个常见的逻辑错误。
DQN 并非终点,而是起点。 它本身也存在一些问题,比如 Q 值过高估计(Overestimation Bias)。后续的研究提出了许多改进方案,构建了庞大的“彩虹 DQN”(Rainbow DQN)家族:
- Double DQN (DDQN):解耦了“选择”和“评估”下一个状态 Q 值的网络,缓解 Q 值过高估计问题。
- Dueling DQN:将 Q 值网络结构分解为“状态价值 V(s)”和“动作优势 A(s, a)”,学习更高效。
- Prioritized Experience Replay (PER):不再随机采样经验,而是优先学习那些 TD 误差大的、“更值得学习”的经验。