diff --git a/src/Http/WebSocketClient.h b/src/Http/WebSocketClient.h index c99ca7d6..c5230800 100644 --- a/src/Http/WebSocketClient.h +++ b/src/Http/WebSocketClient.h @@ -18,9 +18,9 @@ #include "HttpClientImp.h" #include "WebSocketSplitter.h" -namespace mediakit{ +namespace mediakit { -template +template class HttpWsClient; /** @@ -28,23 +28,23 @@ class HttpWsClient; * @tparam ClientType TcpClient派生类 * @tparam DataType 这里无用,为了声明友元用 */ -template +template class ClientTypeImp : public ClientType { public: friend class HttpWsClient; using onBeforeSendCB = std::function; - template - ClientTypeImp(ArgsType &&...args): ClientType(std::forward(args)...){} + template + ClientTypeImp(ArgsType &&...args) : ClientType(std::forward(args)...) {} ~ClientTypeImp() override {}; protected: /** * 发送前拦截并打包为websocket协议 */ - ssize_t send(toolkit::Buffer::Ptr buf) override{ - if(_beforeSendCB){ + ssize_t send(toolkit::Buffer::Ptr buf) override { + if (_beforeSendCB) { return _beforeSendCB(buf); } return ClientType::send(std::move(buf)); @@ -54,9 +54,7 @@ protected: * 设置发送数据截取回调函数 * @param cb 截取回调函数 */ - void setOnBeforeSendCB(const onBeforeSendCB &cb){ - _beforeSendCB = cb; - } + void setOnBeforeSendCB(const onBeforeSendCB &cb) { _beforeSendCB = cb; } private: onBeforeSendCB _beforeSendCB; @@ -67,17 +65,16 @@ private: * @tparam ClientType TcpClient派生类 * @tparam DataType websocket负载类型,是TEXT还是BINARY类型 */ -template -class HttpWsClient : public HttpClientImp , public WebSocketSplitter{ +template +class HttpWsClient : public HttpClientImp, public WebSocketSplitter { public: using Ptr = std::shared_ptr; - HttpWsClient(const std::shared_ptr > &delegate) : _weak_delegate(delegate), - _delegate(*delegate) { + HttpWsClient(const std::shared_ptr> &delegate) : _weak_delegate(delegate) { _Sec_WebSocket_Key = encodeBase64(toolkit::makeRandStr(16, false)); - setPoller(_delegate.getPoller()); + setPoller(delegate->getPoller()); } - ~HttpWsClient(){} + ~HttpWsClient() = default; /** * 发起ws握手 @@ -98,22 +95,22 @@ public: sendRequest(http_url); } - void closeWsClient(){ - if(!_onRecv){ - //未连接 + void closeWsClient() { + if (!_onRecv) { + // 未连接 return; } WebSocketHeader header; header._fin = true; header._reserved = 0; header._opcode = CLOSE; - //客户端需要加密 + // 客户端需要加密 header._mask_flag = true; WebSocketSplitter::encode(header, nullptr); } protected: - //HttpClientImp override + // HttpClientImp override /** * 收到http回复头 @@ -121,12 +118,12 @@ protected: * @param headers http头 */ void onResponseHeader(const std::string &status, const HttpHeader &headers) override { - if(status == "101"){ + if (status == "101") { auto Sec_WebSocket_Accept = encodeBase64(toolkit::SHA1::encode_bin(_Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); - if(Sec_WebSocket_Accept == const_cast(headers)["Sec-WebSocket-Accept"]){ - //success + if (Sec_WebSocket_Accept == const_cast(headers)["Sec-WebSocket-Accept"]) { + // success onWebSocketException(toolkit::SockException()); - //防止ws服务器返回Content-Length + // 防止ws服务器返回Content-Length const_cast(headers).erase("Content-Length"); return; } @@ -134,7 +131,7 @@ protected: return; } - shutdown(toolkit::SockException(toolkit::Err_shutdown,StrPrinter << "bad http status code:" << status)); + shutdown(toolkit::SockException(toolkit::Err_shutdown, StrPrinter << "bad http status code:" << status)); }; /** @@ -145,17 +142,16 @@ protected: /** * 接收websocket负载数据 */ - void onResponseBody(const char *buf,size_t size) override{ - if(_onRecv){ - //完成websocket握手后,拦截websocket数据并解析 + void onResponseBody(const char *buf, size_t size) override { + if (_onRecv) { + // 完成websocket握手后,拦截websocket数据并解析 _onRecv(buf, size); } }; - //TcpClient override + // TcpClient override void onRecv(const toolkit::Buffer::Ptr &buf) override { - auto strong_ref = _weak_delegate.lock();; HttpClientImp::onRecv(buf); } @@ -163,26 +159,45 @@ protected: * 定时触发 */ void onManager() override { - auto strong_ref = _weak_delegate.lock();; if (_onRecv) { - //websocket连接成功了 - _delegate.onManager(); + // websocket连接成功了 + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->onManager(); + } } else { - //websocket连接中... + // websocket连接中... HttpClientImp::onManager(); } + + if (!_onRecv) { + // websocket尚未链接 + return; + } + + if (_recv_ticker.elapsedTime() > 30 * 1000) { + shutdown(toolkit::SockException(toolkit::Err_timeout, "websocket timeout")); + } else if (_recv_ticker.elapsedTime() > 10 * 1000) { + // 没收到回复,每10秒发送次ping 包 + WebSocketHeader header; + header._fin = true; + header._reserved = 0; + header._opcode = PING; + header._mask_flag = true; + WebSocketSplitter::encode(header, nullptr); + } } /** * 数据全部发送完毕后回调 */ void onFlush() override { - auto strong_ref = _weak_delegate.lock();; if (_onRecv) { - //websocket连接成功了 - _delegate.onFlush(); + // websocket连接成功了 + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->onFlush(); + } } else { - //websocket连接中... + // websocket连接中... HttpClientImp::onFlush(); } } @@ -191,13 +206,12 @@ protected: * tcp连接结果 */ void onConnect(const toolkit::SockException &ex) override { - auto strong_ref = _weak_delegate.lock();; if (ex) { - //tcp连接失败,直接返回失败 + // tcp连接失败,直接返回失败 onWebSocketException(ex); return; } - //开始websocket握手 + // 开始websocket握手 HttpClientImp::onConnect(ex); } @@ -205,20 +219,17 @@ protected: * tcp连接断开 */ void onErr(const toolkit::SockException &ex) override { - auto strong_ref = _weak_delegate.lock();; - //tcp断开或者shutdown导致的断开 + // tcp断开或者shutdown导致的断开 onWebSocketException(ex); } - //WebSocketSplitter override + // WebSocketSplitter override /** * 收到一个webSocket数据包包头,后续将继续触发onWebSocketDecodePayload回调 * @param header 数据包头 */ - void onWebSocketDecodeHeader(const WebSocketHeader &header) override{ - _payload_section.clear(); - } + void onWebSocketDecodeHeader(const WebSocketHeader &header) override { _payload_section.clear(); } /** * 收到webSocket数据包负载 @@ -227,58 +238,62 @@ protected: * @param len 负载数据长度 * @param recved 已接收数据长度(包含本次数据长度),等于header._payload_len时则接受完毕 */ - void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, size_t len, size_t recved) override{ - _payload_section.append((char *)ptr,len); + void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, size_t len, size_t recved) override { + _payload_section.append((char *)ptr, len); } /** * 接收到完整的一个webSocket数据包后回调 * @param header 数据包包头 */ - void onWebSocketDecodeComplete(const WebSocketHeader &header_in) override{ - WebSocketHeader& header = const_cast(header_in); - auto flag = header._mask_flag; - //websocket客户端发送数据需要加密 + void onWebSocketDecodeComplete(const WebSocketHeader &header_in) override { + WebSocketHeader &header = const_cast(header_in); + auto flag = header._mask_flag; + // websocket客户端发送数据需要加密 header._mask_flag = true; - - switch (header._opcode){ - case WebSocketHeader::CLOSE:{ - //服务器主动关闭 - WebSocketSplitter::encode(header,nullptr); - shutdown(toolkit::SockException(toolkit::Err_eof,"websocket server close the connection")); + _recv_ticker.resetTime(); + switch (header._opcode) { + case WebSocketHeader::CLOSE: { + // 服务器主动关闭 + WebSocketSplitter::encode(header, nullptr); + shutdown(toolkit::SockException(toolkit::Err_eof, "websocket server close the connection")); break; } - case WebSocketHeader::PING:{ - //心跳包 + case WebSocketHeader::PING: { + // 心跳包 header._opcode = WebSocketHeader::PONG; - WebSocketSplitter::encode(header,std::make_shared(std::move(_payload_section))); + WebSocketSplitter::encode(header, std::make_shared(std::move(_payload_section))); break; } case WebSocketHeader::CONTINUATION: case WebSocketHeader::TEXT: - case WebSocketHeader::BINARY:{ + case WebSocketHeader::BINARY: { if (!header._fin) { - //还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出 + // 还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出 _payload_cache.append(std::move(_payload_section)); if (_payload_cache.size() < MAX_WS_PACKET) { - //还有内存容量缓存分片数据 + // 还有内存容量缓存分片数据 break; } - //分片缓存太大,需要清空 + // 分片缓存太大,需要清空 } - //最后一个包 + // 最后一个包 if (_payload_cache.empty()) { - //这个包是唯一个分片 - _delegate.onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_section))); + // 这个包是唯一个分片 + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_section))); + } break; } - //这个包由多个分片组成 + // 这个包由多个分片组成 _payload_cache.append(std::move(_payload_section)); - _delegate.onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_cache))); + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_cache))); + } _payload_cache.clear(); break; } @@ -294,61 +309,65 @@ protected: * @param ptr 数据指针 * @param len 数据指针长度 */ - void onWebSocketEncodeData(toolkit::Buffer::Ptr buffer) override{ - HttpClientImp::send(std::move(buffer)); - } + void onWebSocketEncodeData(toolkit::Buffer::Ptr buffer) override { HttpClientImp::send(std::move(buffer)); } private: - void onWebSocketException(const toolkit::SockException &ex){ - if(!ex){ - //websocket握手成功 - //此处截取TcpClient派生类发送的数据并进行websocket协议打包 + void onWebSocketException(const toolkit::SockException &ex) { + if (!ex) { + // websocket握手成功 + // 此处截取TcpClient派生类发送的数据并进行websocket协议打包 std::weak_ptr weakSelf = std::dynamic_pointer_cast(shared_from_this()); - _delegate.setOnBeforeSendCB([weakSelf](const toolkit::Buffer::Ptr &buf){ - auto strongSelf = weakSelf.lock(); - if(strongSelf){ - WebSocketHeader header; - header._fin = true; - header._reserved = 0; - header._opcode = DataType; - //客户端需要加密 - header._mask_flag = true; - strongSelf->WebSocketSplitter::encode(header,buf); - } - return buf->size(); - }); + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->setOnBeforeSendCB([weakSelf](const toolkit::Buffer::Ptr &buf) { + auto strong_self = weakSelf.lock(); + if (strong_self) { + WebSocketHeader header; + header._fin = true; + header._reserved = 0; + header._opcode = DataType; + // 客户端需要加密 + header._mask_flag = true; + strong_self->WebSocketSplitter::encode(header, buf); + } + return buf->size(); + }); + // 设置sock,否则shutdown等接口都无效 + strong_ref->setSock(HttpClientImp::getSock()); + // 触发连接成功事件 + strong_ref->onConnect(ex); + } - //设置sock,否则shutdown等接口都无效 - _delegate.setSock(HttpClientImp::getSock()); - //触发连接成功事件 - _delegate.onConnect(ex); - //拦截websocket数据接收 - _onRecv = [this](const char *data, size_t len){ - //解析websocket数据包 + // 拦截websocket数据接收 + _onRecv = [this](const char *data, size_t len) { + // 解析websocket数据包 this->WebSocketSplitter::decode((uint8_t *)data, len); }; return; } - //websocket握手失败或者tcp连接失败或者中途断开 - if(_onRecv){ - //握手成功之后的中途断开 + // websocket握手失败或者tcp连接失败或者中途断开 + if (_onRecv) { + // 握手成功之后的中途断开 _onRecv = nullptr; - _delegate.onErr(ex); + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->onErr(ex); + } return; } - //websocket握手失败或者tcp连接失败 - _delegate.onConnect(ex); + // websocket握手失败或者tcp连接失败 + if (auto strong_ref = _weak_delegate.lock()) { + strong_ref->onConnect(ex); + } } private: std::string _Sec_WebSocket_Key; std::function _onRecv; std::weak_ptr> _weak_delegate; - ClientTypeImp &_delegate; std::string _payload_section; std::string _payload_cache; + toolkit::Ticker _recv_ticker; }; /** @@ -358,17 +377,14 @@ private: * @tparam DataType websocket负载类型,是TEXT还是BINARY类型 * @tparam useWSS 是否使用ws还是wss连接 */ -template -class WebSocketClient : public ClientTypeImp{ +template +class WebSocketClient : public ClientTypeImp { public: using Ptr = std::shared_ptr; - template - WebSocketClient(ArgsType &&...args) : ClientTypeImp(std::forward(args)...){ - } - ~WebSocketClient() override { - _wsClient->closeWsClient(); - } + template + WebSocketClient(ArgsType &&...args) : ClientTypeImp(std::forward(args)...) {} + ~WebSocketClient() override { _wsClient->closeWsClient(); } /** * 重载startConnect方法, @@ -381,30 +397,26 @@ public: void startConnect(const std::string &host, uint16_t port, float timeout_sec = 3, uint16_t local_port = 0) override { std::string ws_url; if (useWSS) { - //加密的ws + // 加密的ws ws_url = StrPrinter << "wss://" + host << ":" << port << "/"; } else { - //明文ws + // 明文ws ws_url = StrPrinter << "ws://" + host << ":" << port << "/"; } startWebSocket(ws_url, timeout_sec); } void startWebSocket(const std::string &ws_url, float fTimeOutSec = 3) { - _wsClient = std::make_shared >(std::static_pointer_cast(this->shared_from_this())); - _wsClient->setOnCreateSocket([this](const toolkit::EventPoller::Ptr &){ - return this->createSocket(); - }); - _wsClient->startWsClient(ws_url,fTimeOutSec); + _wsClient = std::make_shared>(std::static_pointer_cast(this->shared_from_this())); + _wsClient->setOnCreateSocket([this](const toolkit::EventPoller::Ptr &) { return this->createSocket(); }); + _wsClient->startWsClient(ws_url, fTimeOutSec); } - HttpClient &getHttpClient() { - return *_wsClient; - } + HttpClient &getHttpClient() { return *_wsClient; } private: - typename HttpWsClient::Ptr _wsClient; + typename HttpWsClient::Ptr _wsClient; }; -}//namespace mediakit -#endif //ZLMEDIAKIT_WebSocketClient_H +} // namespace mediakit +#endif // ZLMEDIAKIT_WebSocketClient_H diff --git a/src/Http/WebSocketSession.h b/src/Http/WebSocketSession.h index c7eb9b8f..83d7da73 100644 --- a/src/Http/WebSocketSession.h +++ b/src/Http/WebSocketSession.h @@ -36,7 +36,7 @@ public: using Ptr = std::shared_ptr; SessionTypeImp(const mediakit::Parser &header, const mediakit::HttpSession &parent, const toolkit::Socket::Ptr &pSock) : - SessionType(pSock), _identifier(parent.getIdentifier()) {} + SessionType(pSock) {} ~SessionTypeImp() = default; @@ -61,12 +61,7 @@ protected: return SessionType::send(std::move(buf)); } - std::string getIdentifier() const override { - return _identifier; - } - private: - std::string _identifier; onBeforeSendCB _beforeSendCB; }; @@ -98,11 +93,26 @@ public: } //每隔一段时间触发,用来做超时管理 void onManager() override{ - if(_session){ + if (_session) { _session->onManager(); - }else{ + } else { HttpSessionType::onManager(); } + if (!_session) { + // websocket尚未链接 + return; + } + if (_recv_ticker.elapsedTime() > 30 * 1000) { + HttpSessionType::shutdown(toolkit::SockException(toolkit::Err_timeout, "websocket timeout")); + } else if (_recv_ticker.elapsedTime() > 10 * 1000) { + // 没收到回复,每10秒发送次ping 包 + mediakit::WebSocketHeader header; + header._fin = true; + header._reserved = 0; + header._opcode = mediakit::WebSocketHeader::PING; + header._mask_flag = false; + HttpSessionType::encode(header, nullptr); + } } void attachServer(const toolkit::Server &server) override{ @@ -118,13 +128,13 @@ protected: */ bool onWebSocketConnect(const mediakit::Parser &header) override{ //创建websocket session类 - _session = _creator(header, *this,HttpSessionType::getSock()); - if(!_session){ - //此url不允许创建websocket连接 + _session = _creator(header, *this, HttpSessionType::getSock()); + if (!_session) { + // 此url不允许创建websocket连接 return false; } auto strongServer = _weak_server.lock(); - if(strongServer){ + if (strongServer) { _session->attachServer(*strongServer); } @@ -170,7 +180,7 @@ protected: auto header = const_cast(header_in); auto flag = header._mask_flag; header._mask_flag = false; - + _recv_ticker.resetTime(); switch (header._opcode){ case mediakit::WebSocketHeader::CLOSE:{ HttpSessionType::encode(header,nullptr); @@ -230,6 +240,7 @@ private: std::weak_ptr _weak_server; toolkit::Session::Ptr _session; Creator _creator; + toolkit::Ticker _recv_ticker; };