修复WebSocket线程安全问题,同时新增内置客户端、服务端心跳机制。

This commit is contained in:
xia-chu 2023-03-11 11:08:14 +08:00
parent c2a8d46a64
commit 1bab0b8e31
2 changed files with 163 additions and 140 deletions

View File

@ -18,9 +18,9 @@
#include "HttpClientImp.h" #include "HttpClientImp.h"
#include "WebSocketSplitter.h" #include "WebSocketSplitter.h"
namespace mediakit{ namespace mediakit {
template <typename ClientType,WebSocketHeader::Type DataType> template <typename ClientType, WebSocketHeader::Type DataType>
class HttpWsClient; class HttpWsClient;
/** /**
@ -28,23 +28,23 @@ class HttpWsClient;
* @tparam ClientType TcpClient派生类 * @tparam ClientType TcpClient派生类
* @tparam DataType , * @tparam DataType ,
*/ */
template <typename ClientType,WebSocketHeader::Type DataType> template <typename ClientType, WebSocketHeader::Type DataType>
class ClientTypeImp : public ClientType { class ClientTypeImp : public ClientType {
public: public:
friend class HttpWsClient<ClientType, DataType>; friend class HttpWsClient<ClientType, DataType>;
using onBeforeSendCB = std::function<ssize_t(const toolkit::Buffer::Ptr &buf)>; using onBeforeSendCB = std::function<ssize_t(const toolkit::Buffer::Ptr &buf)>;
template<typename ...ArgsType> template <typename... ArgsType>
ClientTypeImp(ArgsType &&...args): ClientType(std::forward<ArgsType>(args)...){} ClientTypeImp(ArgsType &&...args) : ClientType(std::forward<ArgsType>(args)...) {}
~ClientTypeImp() override {}; ~ClientTypeImp() override {};
protected: protected:
/** /**
* websocket协议 * websocket协议
*/ */
ssize_t send(toolkit::Buffer::Ptr buf) override{ ssize_t send(toolkit::Buffer::Ptr buf) override {
if(_beforeSendCB){ if (_beforeSendCB) {
return _beforeSendCB(buf); return _beforeSendCB(buf);
} }
return ClientType::send(std::move(buf)); return ClientType::send(std::move(buf));
@ -54,9 +54,7 @@ protected:
* *
* @param cb * @param cb
*/ */
void setOnBeforeSendCB(const onBeforeSendCB &cb){ void setOnBeforeSendCB(const onBeforeSendCB &cb) { _beforeSendCB = cb; }
_beforeSendCB = cb;
}
private: private:
onBeforeSendCB _beforeSendCB; onBeforeSendCB _beforeSendCB;
@ -67,17 +65,16 @@ private:
* @tparam ClientType TcpClient派生类 * @tparam ClientType TcpClient派生类
* @tparam DataType websocket负载类型TEXT还是BINARY类型 * @tparam DataType websocket负载类型TEXT还是BINARY类型
*/ */
template <typename ClientType,WebSocketHeader::Type DataType = WebSocketHeader::TEXT> template <typename ClientType, WebSocketHeader::Type DataType = WebSocketHeader::TEXT>
class HttpWsClient : public HttpClientImp , public WebSocketSplitter{ class HttpWsClient : public HttpClientImp, public WebSocketSplitter {
public: public:
using Ptr = std::shared_ptr<HttpWsClient>; using Ptr = std::shared_ptr<HttpWsClient>;
HttpWsClient(const std::shared_ptr<ClientTypeImp<ClientType, DataType> > &delegate) : _weak_delegate(delegate), HttpWsClient(const std::shared_ptr<ClientTypeImp<ClientType, DataType>> &delegate) : _weak_delegate(delegate) {
_delegate(*delegate) {
_Sec_WebSocket_Key = encodeBase64(toolkit::makeRandStr(16, false)); _Sec_WebSocket_Key = encodeBase64(toolkit::makeRandStr(16, false));
setPoller(_delegate.getPoller()); setPoller(delegate->getPoller());
} }
~HttpWsClient(){} ~HttpWsClient() = default;
/** /**
* ws握手 * ws握手
@ -98,22 +95,22 @@ public:
sendRequest(http_url); sendRequest(http_url);
} }
void closeWsClient(){ void closeWsClient() {
if(!_onRecv){ if (!_onRecv) {
//未连接 // 未连接
return; return;
} }
WebSocketHeader header; WebSocketHeader header;
header._fin = true; header._fin = true;
header._reserved = 0; header._reserved = 0;
header._opcode = CLOSE; header._opcode = CLOSE;
//客户端需要加密 // 客户端需要加密
header._mask_flag = true; header._mask_flag = true;
WebSocketSplitter::encode(header, nullptr); WebSocketSplitter::encode(header, nullptr);
} }
protected: protected:
//HttpClientImp override // HttpClientImp override
/** /**
* http回复头 * http回复头
@ -121,12 +118,12 @@ protected:
* @param headers http头 * @param headers http头
*/ */
void onResponseHeader(const std::string &status, const HttpHeader &headers) override { void onResponseHeader(const std::string &status, const HttpHeader &headers) override {
if(status == "101"){ if (status == "101") {
auto Sec_WebSocket_Accept = encodeBase64(toolkit::SHA1::encode_bin(_Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); auto Sec_WebSocket_Accept = encodeBase64(toolkit::SHA1::encode_bin(_Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
if(Sec_WebSocket_Accept == const_cast<HttpHeader &>(headers)["Sec-WebSocket-Accept"]){ if (Sec_WebSocket_Accept == const_cast<HttpHeader &>(headers)["Sec-WebSocket-Accept"]) {
//success // success
onWebSocketException(toolkit::SockException()); onWebSocketException(toolkit::SockException());
//防止ws服务器返回Content-Length // 防止ws服务器返回Content-Length
const_cast<HttpHeader &>(headers).erase("Content-Length"); const_cast<HttpHeader &>(headers).erase("Content-Length");
return; return;
} }
@ -134,7 +131,7 @@ protected:
return; return;
} }
shutdown(toolkit::SockException(toolkit::Err_shutdown,StrPrinter << "bad http status code:" << status)); shutdown(toolkit::SockException(toolkit::Err_shutdown, StrPrinter << "bad http status code:" << status));
}; };
/** /**
@ -145,17 +142,16 @@ protected:
/** /**
* websocket负载数据 * websocket负载数据
*/ */
void onResponseBody(const char *buf,size_t size) override{ void onResponseBody(const char *buf, size_t size) override {
if(_onRecv){ if (_onRecv) {
//完成websocket握手后拦截websocket数据并解析 // 完成websocket握手后拦截websocket数据并解析
_onRecv(buf, size); _onRecv(buf, size);
} }
}; };
//TcpClient override // TcpClient override
void onRecv(const toolkit::Buffer::Ptr &buf) override { void onRecv(const toolkit::Buffer::Ptr &buf) override {
auto strong_ref = _weak_delegate.lock();;
HttpClientImp::onRecv(buf); HttpClientImp::onRecv(buf);
} }
@ -163,26 +159,45 @@ protected:
* *
*/ */
void onManager() override { void onManager() override {
auto strong_ref = _weak_delegate.lock();;
if (_onRecv) { if (_onRecv) {
//websocket连接成功了 // websocket连接成功了
_delegate.onManager(); if (auto strong_ref = _weak_delegate.lock()) {
strong_ref->onManager();
}
} else { } else {
//websocket连接中... // websocket连接中...
HttpClientImp::onManager(); HttpClientImp::onManager();
} }
if (!_onRecv) {
// websocket尚未链接
return;
}
if (_recv_ticker.elapsedTime() > 30 * 1000) {
shutdown(toolkit::SockException(toolkit::Err_timeout, "websocket timeout"));
} else if (_recv_ticker.elapsedTime() > 10 * 1000) {
// 没收到回复每10秒发送次ping 包
WebSocketHeader header;
header._fin = true;
header._reserved = 0;
header._opcode = PING;
header._mask_flag = true;
WebSocketSplitter::encode(header, nullptr);
}
} }
/** /**
* *
*/ */
void onFlush() override { void onFlush() override {
auto strong_ref = _weak_delegate.lock();;
if (_onRecv) { if (_onRecv) {
//websocket连接成功了 // websocket连接成功了
_delegate.onFlush(); if (auto strong_ref = _weak_delegate.lock()) {
strong_ref->onFlush();
}
} else { } else {
//websocket连接中... // websocket连接中...
HttpClientImp::onFlush(); HttpClientImp::onFlush();
} }
} }
@ -191,13 +206,12 @@ protected:
* tcp连接结果 * tcp连接结果
*/ */
void onConnect(const toolkit::SockException &ex) override { void onConnect(const toolkit::SockException &ex) override {
auto strong_ref = _weak_delegate.lock();;
if (ex) { if (ex) {
//tcp连接失败直接返回失败 // tcp连接失败直接返回失败
onWebSocketException(ex); onWebSocketException(ex);
return; return;
} }
//开始websocket握手 // 开始websocket握手
HttpClientImp::onConnect(ex); HttpClientImp::onConnect(ex);
} }
@ -205,20 +219,17 @@ protected:
* tcp连接断开 * tcp连接断开
*/ */
void onErr(const toolkit::SockException &ex) override { void onErr(const toolkit::SockException &ex) override {
auto strong_ref = _weak_delegate.lock();; // tcp断开或者shutdown导致的断开
//tcp断开或者shutdown导致的断开
onWebSocketException(ex); onWebSocketException(ex);
} }
//WebSocketSplitter override // WebSocketSplitter override
/** /**
* webSocket数据包包头onWebSocketDecodePayload回调 * webSocket数据包包头onWebSocketDecodePayload回调
* @param header * @param header
*/ */
void onWebSocketDecodeHeader(const WebSocketHeader &header) override{ void onWebSocketDecodeHeader(const WebSocketHeader &header) override { _payload_section.clear(); }
_payload_section.clear();
}
/** /**
* webSocket数据包负载 * webSocket数据包负载
@ -227,58 +238,62 @@ protected:
* @param len * @param len
* @param recved ()header._payload_len时则接受完毕 * @param recved ()header._payload_len时则接受完毕
*/ */
void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, size_t len, size_t recved) override{ void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, size_t len, size_t recved) override {
_payload_section.append((char *)ptr,len); _payload_section.append((char *)ptr, len);
} }
/** /**
* webSocket数据包后回调 * webSocket数据包后回调
* @param header * @param header
*/ */
void onWebSocketDecodeComplete(const WebSocketHeader &header_in) override{ void onWebSocketDecodeComplete(const WebSocketHeader &header_in) override {
WebSocketHeader& header = const_cast<WebSocketHeader&>(header_in); WebSocketHeader &header = const_cast<WebSocketHeader &>(header_in);
auto flag = header._mask_flag; auto flag = header._mask_flag;
//websocket客户端发送数据需要加密 // websocket客户端发送数据需要加密
header._mask_flag = true; header._mask_flag = true;
_recv_ticker.resetTime();
switch (header._opcode){ switch (header._opcode) {
case WebSocketHeader::CLOSE:{ case WebSocketHeader::CLOSE: {
//服务器主动关闭 // 服务器主动关闭
WebSocketSplitter::encode(header,nullptr); WebSocketSplitter::encode(header, nullptr);
shutdown(toolkit::SockException(toolkit::Err_eof,"websocket server close the connection")); shutdown(toolkit::SockException(toolkit::Err_eof, "websocket server close the connection"));
break; break;
} }
case WebSocketHeader::PING:{ case WebSocketHeader::PING: {
//心跳包 // 心跳包
header._opcode = WebSocketHeader::PONG; header._opcode = WebSocketHeader::PONG;
WebSocketSplitter::encode(header,std::make_shared<toolkit::BufferString>(std::move(_payload_section))); WebSocketSplitter::encode(header, std::make_shared<toolkit::BufferString>(std::move(_payload_section)));
break; break;
} }
case WebSocketHeader::CONTINUATION: case WebSocketHeader::CONTINUATION:
case WebSocketHeader::TEXT: case WebSocketHeader::TEXT:
case WebSocketHeader::BINARY:{ case WebSocketHeader::BINARY: {
if (!header._fin) { if (!header._fin) {
//还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出 // 还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出
_payload_cache.append(std::move(_payload_section)); _payload_cache.append(std::move(_payload_section));
if (_payload_cache.size() < MAX_WS_PACKET) { if (_payload_cache.size() < MAX_WS_PACKET) {
//还有内存容量缓存分片数据 // 还有内存容量缓存分片数据
break; break;
} }
//分片缓存太大,需要清空 // 分片缓存太大,需要清空
} }
//最后一个包 // 最后一个包
if (_payload_cache.empty()) { if (_payload_cache.empty()) {
//这个包是唯一个分片 // 这个包是唯一个分片
_delegate.onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_section))); if (auto strong_ref = _weak_delegate.lock()) {
strong_ref->onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_section)));
}
break; break;
} }
//这个包由多个分片组成 // 这个包由多个分片组成
_payload_cache.append(std::move(_payload_section)); _payload_cache.append(std::move(_payload_section));
_delegate.onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_cache))); if (auto strong_ref = _weak_delegate.lock()) {
strong_ref->onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_cache)));
}
_payload_cache.clear(); _payload_cache.clear();
break; break;
} }
@ -294,61 +309,65 @@ protected:
* @param ptr * @param ptr
* @param len * @param len
*/ */
void onWebSocketEncodeData(toolkit::Buffer::Ptr buffer) override{ void onWebSocketEncodeData(toolkit::Buffer::Ptr buffer) override { HttpClientImp::send(std::move(buffer)); }
HttpClientImp::send(std::move(buffer));
}
private: private:
void onWebSocketException(const toolkit::SockException &ex){ void onWebSocketException(const toolkit::SockException &ex) {
if(!ex){ if (!ex) {
//websocket握手成功 // websocket握手成功
//此处截取TcpClient派生类发送的数据并进行websocket协议打包 // 此处截取TcpClient派生类发送的数据并进行websocket协议打包
std::weak_ptr<HttpWsClient> weakSelf = std::dynamic_pointer_cast<HttpWsClient>(shared_from_this()); std::weak_ptr<HttpWsClient> weakSelf = std::dynamic_pointer_cast<HttpWsClient>(shared_from_this());
_delegate.setOnBeforeSendCB([weakSelf](const toolkit::Buffer::Ptr &buf){ if (auto strong_ref = _weak_delegate.lock()) {
auto strongSelf = weakSelf.lock(); strong_ref->setOnBeforeSendCB([weakSelf](const toolkit::Buffer::Ptr &buf) {
if(strongSelf){ auto strong_self = weakSelf.lock();
WebSocketHeader header; if (strong_self) {
header._fin = true; WebSocketHeader header;
header._reserved = 0; header._fin = true;
header._opcode = DataType; header._reserved = 0;
//客户端需要加密 header._opcode = DataType;
header._mask_flag = true; // 客户端需要加密
strongSelf->WebSocketSplitter::encode(header,buf); header._mask_flag = true;
} strong_self->WebSocketSplitter::encode(header, buf);
return buf->size(); }
}); return buf->size();
});
// 设置sock否则shutdown等接口都无效
strong_ref->setSock(HttpClientImp::getSock());
// 触发连接成功事件
strong_ref->onConnect(ex);
}
//设置sock否则shutdown等接口都无效 // 拦截websocket数据接收
_delegate.setSock(HttpClientImp::getSock()); _onRecv = [this](const char *data, size_t len) {
//触发连接成功事件 // 解析websocket数据包
_delegate.onConnect(ex);
//拦截websocket数据接收
_onRecv = [this](const char *data, size_t len){
//解析websocket数据包
this->WebSocketSplitter::decode((uint8_t *)data, len); this->WebSocketSplitter::decode((uint8_t *)data, len);
}; };
return; return;
} }
//websocket握手失败或者tcp连接失败或者中途断开 // websocket握手失败或者tcp连接失败或者中途断开
if(_onRecv){ if (_onRecv) {
//握手成功之后的中途断开 // 握手成功之后的中途断开
_onRecv = nullptr; _onRecv = nullptr;
_delegate.onErr(ex); if (auto strong_ref = _weak_delegate.lock()) {
strong_ref->onErr(ex);
}
return; return;
} }
//websocket握手失败或者tcp连接失败 // websocket握手失败或者tcp连接失败
_delegate.onConnect(ex); if (auto strong_ref = _weak_delegate.lock()) {
strong_ref->onConnect(ex);
}
} }
private: private:
std::string _Sec_WebSocket_Key; std::string _Sec_WebSocket_Key;
std::function<void(const char *data, size_t len)> _onRecv; std::function<void(const char *data, size_t len)> _onRecv;
std::weak_ptr<ClientTypeImp<ClientType, DataType>> _weak_delegate; std::weak_ptr<ClientTypeImp<ClientType, DataType>> _weak_delegate;
ClientTypeImp<ClientType, DataType> &_delegate;
std::string _payload_section; std::string _payload_section;
std::string _payload_cache; std::string _payload_cache;
toolkit::Ticker _recv_ticker;
}; };
/** /**
@ -358,17 +377,14 @@ private:
* @tparam DataType websocket负载类型TEXT还是BINARY类型 * @tparam DataType websocket负载类型TEXT还是BINARY类型
* @tparam useWSS 使ws还是wss连接 * @tparam useWSS 使ws还是wss连接
*/ */
template <typename ClientType,WebSocketHeader::Type DataType = WebSocketHeader::TEXT,bool useWSS = false > template <typename ClientType, WebSocketHeader::Type DataType = WebSocketHeader::TEXT, bool useWSS = false>
class WebSocketClient : public ClientTypeImp<ClientType,DataType>{ class WebSocketClient : public ClientTypeImp<ClientType, DataType> {
public: public:
using Ptr = std::shared_ptr<WebSocketClient>; using Ptr = std::shared_ptr<WebSocketClient>;
template<typename ...ArgsType> template <typename... ArgsType>
WebSocketClient(ArgsType &&...args) : ClientTypeImp<ClientType,DataType>(std::forward<ArgsType>(args)...){ WebSocketClient(ArgsType &&...args) : ClientTypeImp<ClientType, DataType>(std::forward<ArgsType>(args)...) {}
} ~WebSocketClient() override { _wsClient->closeWsClient(); }
~WebSocketClient() override {
_wsClient->closeWsClient();
}
/** /**
* startConnect方法 * startConnect方法
@ -381,30 +397,26 @@ public:
void startConnect(const std::string &host, uint16_t port, float timeout_sec = 3, uint16_t local_port = 0) override { void startConnect(const std::string &host, uint16_t port, float timeout_sec = 3, uint16_t local_port = 0) override {
std::string ws_url; std::string ws_url;
if (useWSS) { if (useWSS) {
//加密的ws // 加密的ws
ws_url = StrPrinter << "wss://" + host << ":" << port << "/"; ws_url = StrPrinter << "wss://" + host << ":" << port << "/";
} else { } else {
//明文ws // 明文ws
ws_url = StrPrinter << "ws://" + host << ":" << port << "/"; ws_url = StrPrinter << "ws://" + host << ":" << port << "/";
} }
startWebSocket(ws_url, timeout_sec); startWebSocket(ws_url, timeout_sec);
} }
void startWebSocket(const std::string &ws_url, float fTimeOutSec = 3) { void startWebSocket(const std::string &ws_url, float fTimeOutSec = 3) {
_wsClient = std::make_shared<HttpWsClient<ClientType, DataType> >(std::static_pointer_cast<WebSocketClient>(this->shared_from_this())); _wsClient = std::make_shared<HttpWsClient<ClientType, DataType>>(std::static_pointer_cast<WebSocketClient>(this->shared_from_this()));
_wsClient->setOnCreateSocket([this](const toolkit::EventPoller::Ptr &){ _wsClient->setOnCreateSocket([this](const toolkit::EventPoller::Ptr &) { return this->createSocket(); });
return this->createSocket(); _wsClient->startWsClient(ws_url, fTimeOutSec);
});
_wsClient->startWsClient(ws_url,fTimeOutSec);
} }
HttpClient &getHttpClient() { HttpClient &getHttpClient() { return *_wsClient; }
return *_wsClient;
}
private: private:
typename HttpWsClient<ClientType,DataType>::Ptr _wsClient; typename HttpWsClient<ClientType, DataType>::Ptr _wsClient;
}; };
}//namespace mediakit } // namespace mediakit
#endif //ZLMEDIAKIT_WebSocketClient_H #endif // ZLMEDIAKIT_WebSocketClient_H

View File

@ -36,7 +36,7 @@ public:
using Ptr = std::shared_ptr<SessionTypeImp>; using Ptr = std::shared_ptr<SessionTypeImp>;
SessionTypeImp(const mediakit::Parser &header, const mediakit::HttpSession &parent, const toolkit::Socket::Ptr &pSock) : SessionTypeImp(const mediakit::Parser &header, const mediakit::HttpSession &parent, const toolkit::Socket::Ptr &pSock) :
SessionType(pSock), _identifier(parent.getIdentifier()) {} SessionType(pSock) {}
~SessionTypeImp() = default; ~SessionTypeImp() = default;
@ -61,12 +61,7 @@ protected:
return SessionType::send(std::move(buf)); return SessionType::send(std::move(buf));
} }
std::string getIdentifier() const override {
return _identifier;
}
private: private:
std::string _identifier;
onBeforeSendCB _beforeSendCB; onBeforeSendCB _beforeSendCB;
}; };
@ -98,11 +93,26 @@ public:
} }
//每隔一段时间触发,用来做超时管理 //每隔一段时间触发,用来做超时管理
void onManager() override{ void onManager() override{
if(_session){ if (_session) {
_session->onManager(); _session->onManager();
}else{ } else {
HttpSessionType::onManager(); HttpSessionType::onManager();
} }
if (!_session) {
// websocket尚未链接
return;
}
if (_recv_ticker.elapsedTime() > 30 * 1000) {
HttpSessionType::shutdown(toolkit::SockException(toolkit::Err_timeout, "websocket timeout"));
} else if (_recv_ticker.elapsedTime() > 10 * 1000) {
// 没收到回复每10秒发送次ping 包
mediakit::WebSocketHeader header;
header._fin = true;
header._reserved = 0;
header._opcode = mediakit::WebSocketHeader::PING;
header._mask_flag = false;
HttpSessionType::encode(header, nullptr);
}
} }
void attachServer(const toolkit::Server &server) override{ void attachServer(const toolkit::Server &server) override{
@ -118,13 +128,13 @@ protected:
*/ */
bool onWebSocketConnect(const mediakit::Parser &header) override{ bool onWebSocketConnect(const mediakit::Parser &header) override{
//创建websocket session类 //创建websocket session类
_session = _creator(header, *this,HttpSessionType::getSock()); _session = _creator(header, *this, HttpSessionType::getSock());
if(!_session){ if (!_session) {
//此url不允许创建websocket连接 // 此url不允许创建websocket连接
return false; return false;
} }
auto strongServer = _weak_server.lock(); auto strongServer = _weak_server.lock();
if(strongServer){ if (strongServer) {
_session->attachServer(*strongServer); _session->attachServer(*strongServer);
} }
@ -170,7 +180,7 @@ protected:
auto header = const_cast<mediakit::WebSocketHeader&>(header_in); auto header = const_cast<mediakit::WebSocketHeader&>(header_in);
auto flag = header._mask_flag; auto flag = header._mask_flag;
header._mask_flag = false; header._mask_flag = false;
_recv_ticker.resetTime();
switch (header._opcode){ switch (header._opcode){
case mediakit::WebSocketHeader::CLOSE:{ case mediakit::WebSocketHeader::CLOSE:{
HttpSessionType::encode(header,nullptr); HttpSessionType::encode(header,nullptr);
@ -230,6 +240,7 @@ private:
std::weak_ptr<toolkit::Server> _weak_server; std::weak_ptr<toolkit::Server> _weak_server;
toolkit::Session::Ptr _session; toolkit::Session::Ptr _session;
Creator _creator; Creator _creator;
toolkit::Ticker _recv_ticker;
}; };