在阅读这篇文章前,推荐先阅读以下内容:
- [netty5: WebSocketFrame]-源码分析
- [netty5: WebSocketFrameEncoder & WebSocketFrameDecoder]-源码解析
WebSocketClientHandshakerFactory
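`WebSocketClientHandshakerFactory` 是用于根据 URI 等参数创建 WebSocket 客户端握手器(Handshaker)的工厂类,简化客户端握手流程。注意:从下面的代码可以看到,`newHandshaker` 虽然接收协议版本参数 `version`,但并不检查它,始终返回 `WebSocketClientHandshaker13`(即 RFC 6455 / 版本 13 的实现)。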
public final class WebSocketClientHandshakerFactory {private WebSocketClientHandshakerFactory() {}// ...// new WebSocketClientProtocolHandler(config)public static WebSocketClientHandshaker newHandshaker(URI webSocketURL, WebSocketVersion version, String subprotocol,boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis,boolean absoluteUpgradeUrl, boolean generateOriginHeader) {return new WebSocketClientHandshaker13(webSocketURL, subprotocol, allowExtensions, customHeaders,maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis,absoluteUpgradeUrl, generateOriginHeader);}
}
WebSocketClientHandshaker13
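`WebSocketClientHandshaker13` 是实现 WebSocket 协议 RFC 6455(版本 13)的客户端握手器,负责构造握手请求、验证服务器响应,并提供 WebSocket 帧编解码器;协议升级时的管道替换(移除 HTTP 编解码器、加入 ws-decoder)由父类 `WebSocketClientHandshaker.finishHandshake` 完成。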
public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {private final boolean allowExtensions;private final boolean performMasking;private final boolean allowMaskMismatch;private volatile String sentNonce;// WebSocketClientHandshakerFactory.newHandshakerWebSocketClientHandshaker13(URI webSocketURL, String subprotocol,boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,boolean performMasking, boolean allowMaskMismatch,long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl,boolean generateOriginHeader) {super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength,forceCloseTimeoutMillis, absoluteUpgradeUrl, generateOriginHeader);this.allowExtensions = allowExtensions;this.performMasking = performMasking;this.allowMaskMismatch = allowMaskMismatch;}/*** /*** <p>* Sends the opening request to the server:* </p>** <pre>* GET /chat HTTP/1.1* Host: server.example.com* Upgrade: websocket* Connection: Upgrade* Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==* Sec-WebSocket-Protocol: chat, superchat* Sec-WebSocket-Version: 13* </pre>**/@Overrideprotected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {URI wsURL = uri();FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL),allocator.allocate(0));HttpHeaders headers = request.headers();if (customHeaders != null) {headers.add(customHeaders);if (!headers.contains(HttpHeaderNames.HOST)) {headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));}} else {headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));}String nonce = createNonce();headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET).set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE).set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) {headers.set(HttpHeaderNames.ORIGIN, websocketHostValue(wsURL));}sentNonce = nonce;String expectedSubprotocol 
= expectedSubprotocol();if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);}headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version().toAsciiString());return request;}/*** <p>* Process server response:* </p>** <pre>* HTTP/1.1 101 Switching Protocols* Upgrade: websocket* Connection: Upgrade* Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=* Sec-WebSocket-Protocol: chat* </pre>** @param response* HTTP response returned from the server for the request sent by beginOpeningHandshake00().* @throws WebSocketHandshakeException if handshake response is invalid.*/@Overrideprotected void verify(FullHttpResponse response) {HttpResponseStatus status = response.status();if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {throw new WebSocketClientHandshakeException("Invalid handshake response status: " + status, response);}HttpHeaders headers = response.headers();CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE);if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) {throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response);}if (!headers.containsIgnoreCase(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)) {throw new WebSocketClientHandshakeException("Invalid handshake response connection: " + headers.get(HttpHeaderNames.CONNECTION), response);}CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);if (accept == null) {throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: null", response);}String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);if (!AsciiString.contentEquals(expectedAccept, AsciiString.trim(accept))) {throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: " + accept + ", expected: " + expectedAccept, response);}}@Overrideprotected WebSocketFrameDecoder newWebsocketDecoder() {return new 
WebSocket13FrameDecoder(false, allowExtensions, maxFramePayloadLength(), allowMaskMismatch);}@Overrideprotected WebSocketFrameEncoder newWebSocketEncoder() {return new WebSocket13FrameEncoder(performMasking);}@Overridepublic WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);return this;}// 生成一个符合 WebSocket 协议要求的 16 字节 Base64 编码的随机值,用作 Sec-WebSocket-Keyprivate static String createNonce() {var nonce = WebSocketUtil.randomBytes(16);return WebSocketUtil.base64(nonce);}
}
WebSocketClientHandshaker
public abstract class WebSocketClientHandshaker {protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;// 代表握手时的目标地址, 例如 ws://example.com/chatprivate final URI uri;// 控制握手请求和数据帧的格式, 比如 RFC 6455 标准版本private final WebSocketVersion version;// 标记握手是否完成,volatile 保证多线程访问时的可见性private volatile boolean handshakeComplete;// 握手完成后,如果关闭 WebSocket 连接时等待超时,会触发强制关闭。private volatile long forceCloseTimeoutMillis;// 用于标记强制关闭流程是否初始化, 通过 AtomicIntegerFieldUpdater 原子更新private volatile int forceCloseInit;private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER = AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");// 标记强制关闭流程是否完成。private volatile boolean forceCloseComplete;// 握手时客户端希望协商的子协议(Subprotocol), 例如视频、聊天子协议名称等private final String expectedSubprotocol;// 握手后服务器协商确认的子协议,握手成功后才有值。private volatile String actualSubprotocol;// 握手请求时使用,方便传递用户自定义信息。protected final HttpHeaders customHeaders;// 最大单个 WebSocket 帧负载长度限制, 防止收到超大数据导致内存溢出。private final int maxFramePayloadLength;// 是否在握手请求中使用绝对 URI 作为 Upgrade URL, 一般用于特殊代理或协议场景private final boolean absoluteUpgradeUrl;// 是否自动生成 Origin 请求头protected final boolean generateOriginHeader;protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,HttpHeaders customHeaders, int maxFramePayloadLength,long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, boolean generateOriginHeader) {this.uri = uri;this.version = version;expectedSubprotocol = subprotocol;this.customHeaders = customHeaders;this.maxFramePayloadLength = maxFramePayloadLength;this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;this.absoluteUpgradeUrl = absoluteUpgradeUrl;this.generateOriginHeader = generateOriginHeader;}// WebSocketClientProtocolHandshakeHandler.channelActivepublic Future<Void> handshake(Channel channel) {requireNonNull(channel, "channel");ChannelPipeline pipeline = channel.pipeline();// 检查管道中解码器HttpResponseDecoder decoder = 
pipeline.get(HttpResponseDecoder.class);if (decoder == null) {HttpClientCodec codec = pipeline.get(HttpClientCodec.class);if (codec == null) {return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " + "an HttpResponseDecoder or HttpClientCodec"));}}// 检查 URI 和 Header 相关的 Host 与 Originif (uri.getHost() == null) {if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) {return channel.newFailedFuture(new IllegalArgumentException("Cannot generate the 'host' header value," + " webSocketURI should contain host or passed through customHeaders"));}if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {final String originName = HttpHeaderNames.ORIGIN.toString();return channel.newFailedFuture(new IllegalArgumentException("Cannot generate the '" + originName + "' header" + " value, webSocketURI should contain host or disable generateOriginHeader or pass value" + " through customHeaders"));}}// 创建握手请求FullHttpRequest request = newHandshakeRequest(channel.bufferAllocator());// 创建 Promise,异步写出请求Promise<Void> promise = channel.newPromise();channel.writeAndFlush(request).addListener(channel, (ch, future) -> {// 如果写操作成功if (future.isSuccess()) {ChannelPipeline p = ch.pipeline();//找出管道中 HTTP 请求编码器 HttpRequestEncoder 或者 HttpClientCodec,ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);if (ctx == null) {ctx = p.context(HttpClientCodec.class);}if (ctx == null) {promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "an HttpRequestEncoder or HttpClientCodec"));return;}// 然后在其后面动态添加 WebSocket 专用的编码器 ws-encoder(由 newWebSocketEncoder() 创建)p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());promise.setSuccess(null);} else {promise.setFailure(future.cause());}});return promise.asFuture();}// WebSocketClientProtocolHandshakeHandler.channelReadpublic final void finishHandshake(Channel channel, FullHttpResponse response) {verify(response);// 服务器返回的子协议CharSequence 
receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);receivedProtocol = receivedProtocol != null ? AsciiString.trim(receivedProtocol) : null;// 客户端期望的子协议String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";boolean protocolValid = false;// 如果客户端没指定预期协议,且服务器也没返回协议,视为通过。if (expectedProtocol.isEmpty() && receivedProtocol == null) {protocolValid = true;setActualSubprotocol(expectedSubprotocol);} else if (!expectedProtocol.isEmpty() && receivedProtocol != null && receivedProtocol.length() > 0) {// 如果客户端有期望协议且服务器返回了协议,则判断服务器返回的协议是否在客户端允许的列表中for (String protocol : expectedProtocol.split(",")) {if (AsciiString.contentEquals(protocol.trim(), receivedProtocol)) {protocolValid = true;setActualSubprotocol(receivedProtocol.toString());break;}}}// 如果子协议校验失败,抛出握手异常。if (!protocolValid) {throw new WebSocketClientHandshakeException(String.format("Invalid subprotocol. Actual: %s. Expected one of: %s",receivedProtocol, expectedSubprotocol), response);}// 标记握手完成。setHandshakeComplete();final ChannelPipeline p = channel.pipeline();// 移除 HTTP 消息解压处理器(如 gzip 解压),以及 HTTP 聚合器,WebSocket 不需要这些HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);if (decompressor != null) {p.remove(decompressor);}HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);if (aggregator != null) {p.remove(aggregator);}// 查找 HTTP 解码器上下文:// 1. 若是 HttpClientCodec,先调用 removeOutboundHandler(),然后添加 WebSocket 解码器,最后异步移除 HTTP Codec。// 2. 
若是单独的 HttpResponseDecoder,先移除对应的请求编码器,再添加 WebSocket 解码器,异步移除响应解码器。// 新加入的 ws-decoder 是 WebSocket 的解码器,处理 WebSocket 帧。ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);if (ctx == null) {ctx = p.context(HttpClientCodec.class);if (ctx == null) {throw new IllegalStateException("ChannelPipeline does not contain " +"an HttpRequestEncoder or HttpClientCodec");}final HttpClientCodec codec = (HttpClientCodec) ctx.handler();codec.removeOutboundHandler();p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());channel.executor().execute(() -> p.remove(codec));} else {if (p.get(HttpRequestEncoder.class) != null) {p.remove(HttpRequestEncoder.class);}final ChannelHandlerContext context = ctx;p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());channel.executor().execute(() -> p.remove(context.handler()));}}// ...protected abstract FullHttpRequest newHandshakeRequest(BufferAllocator allocator);protected abstract void verify(FullHttpResponse response);protected abstract WebSocketFrameDecoder newWebsocketDecoder();protected abstract WebSocketFrameEncoder newWebSocketEncoder();
}