1、ws-flv直播支持客户端主动关闭请求:#430

2、兼容CONTINUATION类型的websocket包
3、修复websocket客户端在处理Content-Length时的相关bug
This commit is contained in:
xiongziliang 2020-08-08 12:17:06 +08:00
parent e7e8969b4f
commit 2fd567b8b0
5 changed files with 149 additions and 61 deletions

View File

@ -132,36 +132,37 @@ void HttpSession::onManager() {
bool HttpSession::checkWebSocket(){ bool HttpSession::checkWebSocket(){
auto Sec_WebSocket_Key = _parser["Sec-WebSocket-Key"]; auto Sec_WebSocket_Key = _parser["Sec-WebSocket-Key"];
if(Sec_WebSocket_Key.empty()){ if (Sec_WebSocket_Key.empty()) {
return false; return false;
} }
auto Sec_WebSocket_Accept = encodeBase64(SHA1::encode_bin(Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); auto Sec_WebSocket_Accept = encodeBase64(
SHA1::encode_bin(Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
KeyValue headerOut; KeyValue headerOut;
headerOut["Upgrade"] = "websocket"; headerOut["Upgrade"] = "websocket";
headerOut["Connection"] = "Upgrade"; headerOut["Connection"] = "Upgrade";
headerOut["Sec-WebSocket-Accept"] = Sec_WebSocket_Accept; headerOut["Sec-WebSocket-Accept"] = Sec_WebSocket_Accept;
if(!_parser["Sec-WebSocket-Protocol"].empty()){ if (!_parser["Sec-WebSocket-Protocol"].empty()) {
headerOut["Sec-WebSocket-Protocol"] = _parser["Sec-WebSocket-Protocol"]; headerOut["Sec-WebSocket-Protocol"] = _parser["Sec-WebSocket-Protocol"];
} }
auto res_cb = [this,headerOut](){ auto res_cb = [this, headerOut]() {
_flv_over_websocket = true; _flv_over_websocket = true;
sendResponse("101 Switching Protocols",false,nullptr,headerOut,nullptr, true); sendResponse("101 Switching Protocols", false, nullptr, headerOut, nullptr, true);
}; };
//判断是否为websocket-flv //判断是否为websocket-flv
if(checkLiveFlvStream(res_cb)){ if (checkLiveFlvStream(res_cb)) {
//这里是websocket-flv直播请求 //这里是websocket-flv直播请求
return true; return true;
} }
//如果checkLiveFlvStream返回false,则代表不是websocket-flv而是普通的websocket连接 //如果checkLiveFlvStream返回false,则代表不是websocket-flv而是普通的websocket连接
if(!onWebSocketConnect(_parser)){ if (!onWebSocketConnect(_parser)) {
sendResponse("501 Not Implemented",true, nullptr, headerOut); sendResponse("501 Not Implemented", true, nullptr, headerOut);
return true; return true;
} }
sendResponse("101 Switching Protocols",false, nullptr,headerOut); sendResponse("101 Switching Protocols", false, nullptr, headerOut, nullptr, true);
return true; return true;
} }
@ -389,7 +390,7 @@ void HttpSession::sendResponse(const char *pcStatus,
const char *pcContentType, const char *pcContentType,
const HttpSession::KeyValue &header, const HttpSession::KeyValue &header,
const HttpBody::Ptr &body, const HttpBody::Ptr &body,
bool is_http_flv ){ bool no_content_length ){
GET_CONFIG(string,charSet,Http::kCharSet); GET_CONFIG(string,charSet,Http::kCharSet);
GET_CONFIG(uint32_t,keepAliveSec,Http::kKeepAliveSecond); GET_CONFIG(uint32_t,keepAliveSec,Http::kKeepAliveSecond);
@ -400,7 +401,7 @@ void HttpSession::sendResponse(const char *pcStatus,
size = body->remainSize(); size = body->remainSize();
} }
if(is_http_flv){ if(no_content_length){
//http-flv直播是Keep-Alive类型 //http-flv直播是Keep-Alive类型
bClose = false; bClose = false;
}else if(size >= INT64_MAX){ }else if(size >= INT64_MAX){
@ -425,7 +426,7 @@ void HttpSession::sendResponse(const char *pcStatus,
headerOut.emplace(kAccessControlAllowCredentials, "true"); headerOut.emplace(kAccessControlAllowCredentials, "true");
} }
if(!is_http_flv && size >= 0 && size < INT64_MAX){ if(!no_content_length && size >= 0 && size < INT64_MAX){
//文件长度为固定值,且不是http-flv强制设置Content-Length //文件长度为固定值,且不是http-flv强制设置Content-Length
headerOut[kContentLength] = to_string(size); headerOut[kContentLength] = to_string(size);
} }
@ -645,6 +646,21 @@ void HttpSession::onWebSocketEncodeData(const Buffer::Ptr &buffer){
send(buffer); send(buffer);
} }
void HttpSession::onWebSocketDecodeComplete(const WebSocketHeader &header_in){
WebSocketHeader& header = const_cast<WebSocketHeader&>(header_in);
header._mask_flag = false;
switch (header._opcode) {
case WebSocketHeader::CLOSE: {
encode(header, nullptr);
shutdown(SockException(Err_shutdown, "recv close request from client"));
break;
}
default : break;
}
}
void HttpSession::onDetach() { void HttpSession::onDetach() {
shutdown(SockException(Err_shutdown,"rtmp ring buffer detached")); shutdown(SockException(Err_shutdown,"rtmp ring buffer detached"));
} }

View File

@ -47,6 +47,7 @@ public:
void onError(const SockException &err) override; void onError(const SockException &err) override;
void onManager() override; void onManager() override;
static string urlDecode(const string &str); static string urlDecode(const string &str);
protected: protected:
//FlvMuxer override //FlvMuxer override
void onWrite(const Buffer::Ptr &data, bool flush) override ; void onWrite(const Buffer::Ptr &data, bool flush) override ;
@ -90,6 +91,13 @@ protected:
* @param buffer websocket协议数据 * @param buffer websocket协议数据
*/ */
void onWebSocketEncodeData(const Buffer::Ptr &buffer) override; void onWebSocketEncodeData(const Buffer::Ptr &buffer) override;
/**
* webSocket数据包后回调
* @param header
*/
void onWebSocketDecodeComplete(const WebSocketHeader &header_in) override;
private: private:
void Handle_Req_GET(int64_t &content_len); void Handle_Req_GET(int64_t &content_len);
void Handle_Req_GET_l(int64_t &content_len, bool sendBody); void Handle_Req_GET_l(int64_t &content_len, bool sendBody);
@ -103,10 +111,11 @@ private:
void sendNotFound(bool bClose); void sendNotFound(bool bClose);
void sendResponse(const char *pcStatus, bool bClose, const char *pcContentType = nullptr, void sendResponse(const char *pcStatus, bool bClose, const char *pcContentType = nullptr,
const HttpSession::KeyValue &header = HttpSession::KeyValue(), const HttpSession::KeyValue &header = HttpSession::KeyValue(),
const HttpBody::Ptr &body = nullptr,bool is_http_flv = false); const HttpBody::Ptr &body = nullptr, bool no_content_length = false);
//设置socket标志 //设置socket标志
void setSocketFlags(); void setSocketFlags();
private: private:
string _origin; string _origin;
Parser _parser; Parser _parser;

View File

@ -38,11 +38,10 @@ public:
template<typename ...ArgsType> template<typename ...ArgsType>
ClientTypeImp(ArgsType &&...args): ClientType(std::forward<ArgsType>(args)...){} ClientTypeImp(ArgsType &&...args): ClientType(std::forward<ArgsType>(args)...){}
~ClientTypeImp() override {}; ~ClientTypeImp() override {};
protected: protected:
/** /**
* websocket协议 * websocket协议
* @param buf
* @return
*/ */
int send(const Buffer::Ptr &buf) override{ int send(const Buffer::Ptr &buf) override{
if(_beforeSendCB){ if(_beforeSendCB){
@ -50,6 +49,7 @@ protected:
} }
return ClientType::send(buf); return ClientType::send(buf);
} }
/** /**
* *
* @param cb * @param cb
@ -57,6 +57,7 @@ protected:
void setOnBeforeSendCB(const onBeforeSendCB &cb){ void setOnBeforeSendCB(const onBeforeSendCB &cb){
_beforeSendCB = cb; _beforeSendCB = cb;
} }
private: private:
onBeforeSendCB _beforeSendCB; onBeforeSendCB _beforeSendCB;
}; };
@ -108,6 +109,7 @@ public:
header._mask_flag = true; header._mask_flag = true;
WebSocketSplitter::encode(header, nullptr); WebSocketSplitter::encode(header, nullptr);
} }
protected: protected:
//HttpClientImp override //HttpClientImp override
@ -124,6 +126,8 @@ protected:
if(Sec_WebSocket_Accept == const_cast<HttpHeader &>(headers)["Sec-WebSocket-Accept"]){ if(Sec_WebSocket_Accept == const_cast<HttpHeader &>(headers)["Sec-WebSocket-Accept"]){
//success //success
onWebSocketException(SockException()); onWebSocketException(SockException());
//防止ws服务器返回Content-Length
const_cast<HttpHeader &>(headers).erase("Content-Length");
//后续全是websocket负载数据 //后续全是websocket负载数据
return -1; return -1;
} }
@ -180,7 +184,6 @@ protected:
/** /**
* tcp连接结果 * tcp连接结果
* @param ex
*/ */
void onConnect(const SockException &ex) override{ void onConnect(const SockException &ex) override{
if(ex){ if(ex){
@ -194,7 +197,6 @@ protected:
/** /**
* tcp连接断开 * tcp连接断开
* @param ex
*/ */
void onErr(const SockException &ex) override{ void onErr(const SockException &ex) override{
//tcp断开或者shutdown导致的断开 //tcp断开或者shutdown导致的断开
@ -208,7 +210,7 @@ protected:
* @param header * @param header
*/ */
void onWebSocketDecodeHeader(const WebSocketHeader &header) override{ void onWebSocketDecodeHeader(const WebSocketHeader &header) override{
_payload.clear(); _payload_section.clear();
} }
/** /**
@ -219,10 +221,9 @@ protected:
* @param recved ()header._payload_len时则接受完毕 * @param recved ()header._payload_len时则接受完毕
*/ */
void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, uint64_t len, uint64_t recved) override{ void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, uint64_t len, uint64_t recved) override{
_payload.append((char *)ptr,len); _payload_section.append((char *)ptr,len);
} }
/** /**
* webSocket数据包后回调 * webSocket数据包后回调
* @param header * @param header
@ -238,28 +239,46 @@ protected:
//服务器主动关闭 //服务器主动关闭
WebSocketSplitter::encode(header,nullptr); WebSocketSplitter::encode(header,nullptr);
shutdown(SockException(Err_eof,"websocket server close the connection")); shutdown(SockException(Err_eof,"websocket server close the connection"));
}
break; break;
}
case WebSocketHeader::PING:{ case WebSocketHeader::PING:{
//心跳包 //心跳包
header._opcode = WebSocketHeader::PONG; header._opcode = WebSocketHeader::PONG;
WebSocketSplitter::encode(header,std::make_shared<BufferString>(std::move(_payload))); WebSocketSplitter::encode(header,std::make_shared<BufferString>(std::move(_payload_section)));
}
break; break;
case WebSocketHeader::CONTINUATION:{ }
} case WebSocketHeader::CONTINUATION:
break;
case WebSocketHeader::TEXT: case WebSocketHeader::TEXT:
case WebSocketHeader::BINARY:{ case WebSocketHeader::BINARY:{
//接收完毕websocket数据包触发onRecv事件 if (!header._fin) {
_delegate.onRecv(std::make_shared<BufferString>(std::move(_payload))); //还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出
_payload_cache.append(std::move(_payload_section));
if (_payload_cache.size() < MAX_WS_PACKET) {
//还有内存容量缓存分片数据
break;
}
//分片缓存太大,需要清空
}
//最后一个包
if (_payload_cache.empty()) {
//这个包是唯一个分片
_delegate.onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_section)));
break;
}
//这个包由多个分片组成
_payload_cache.append(std::move(_payload_section));
_delegate.onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_cache)));
_payload_cache.clear();
break;
} }
break;
default: default: break;
break;
} }
_payload.clear(); _payload_section.clear();
header._mask_flag = flag; header._mask_flag = flag;
} }
@ -271,6 +290,7 @@ protected:
void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{ void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{
HttpClientImp::send(buffer); HttpClientImp::send(buffer);
} }
private: private:
void onWebSocketException(const SockException &ex){ void onWebSocketException(const SockException &ex){
if(!ex){ if(!ex){
@ -319,10 +339,10 @@ private:
string _Sec_WebSocket_Key; string _Sec_WebSocket_Key;
function<void(const char *data, int len)> _onRecv; function<void(const char *data, int len)> _onRecv;
ClientTypeImp<ClientType,DataType> &_delegate; ClientTypeImp<ClientType,DataType> &_delegate;
string _payload; string _payload_section;
string _payload_cache;
}; };
/** /**
* Tcp客户端转WebSocket客户端模板 * Tcp客户端转WebSocket客户端模板
* TcpClient派生类任何代码的情况下快速实现WebSocket协议的包装 * TcpClient派生类任何代码的情况下快速实现WebSocket协议的包装
@ -365,6 +385,7 @@ public:
void startWebSocket(const string &ws_url,float fTimeOutSec = 3){ void startWebSocket(const string &ws_url,float fTimeOutSec = 3){
_wsClient->startWsClient(ws_url,fTimeOutSec); _wsClient->startWsClient(ws_url,fTimeOutSec);
} }
private: private:
typename HttpWsClient<ClientType,DataType>::Ptr _wsClient; typename HttpWsClient<ClientType,DataType>::Ptr _wsClient;
}; };

View File

@ -78,7 +78,6 @@ public:
} }
}; };
/** /**
* WebSocket协议 * WebSocket协议
* WebSock协议下的具体业务协议WebSocket协议的Rtmp协议等 * WebSock协议下的具体业务协议WebSocket协议的Rtmp协议等
@ -107,8 +106,9 @@ public:
void attachServer(const TcpServer &server) override{ void attachServer(const TcpServer &server) override{
HttpSessionType::attachServer(server); HttpSessionType::attachServer(server);
_weakServer = const_cast<TcpServer &>(server).shared_from_this(); _weak_server = const_cast<TcpServer &>(server).shared_from_this();
} }
protected: protected:
/** /**
* websocket客户端连接上事件 * websocket客户端连接上事件
@ -122,7 +122,7 @@ protected:
//此url不允许创建websocket连接 //此url不允许创建websocket连接
return false; return false;
} }
auto strongServer = _weakServer.lock(); auto strongServer = _weak_server.lock();
if(strongServer){ if(strongServer){
_session->attachServer(*strongServer); _session->attachServer(*strongServer);
} }
@ -145,24 +145,20 @@ protected:
//允许websocket客户端 //允许websocket客户端
return true; return true;
} }
/** /**
* webSocket数据包 * webSocket数据包
* @param packet
*/ */
void onWebSocketDecodeHeader(const WebSocketHeader &packet) override{ void onWebSocketDecodeHeader(const WebSocketHeader &packet) override{
//新包,原来的包残余数据清空掉 //新包,原来的包残余数据清空掉
_remian_data.clear(); _payload_section.clear();
} }
/** /**
* websocket数据包负载 * websocket数据包负载
* @param packet
* @param ptr
* @param len
* @param recved
*/ */
void onWebSocketDecodePayload(const WebSocketHeader &packet,const uint8_t *ptr,uint64_t len,uint64_t recved) override { void onWebSocketDecodePayload(const WebSocketHeader &packet,const uint8_t *ptr,uint64_t len,uint64_t recved) override {
_remian_data.append((char *)ptr,len); _payload_section.append((char *)ptr,len);
} }
/** /**
@ -178,39 +174,59 @@ protected:
case WebSocketHeader::CLOSE:{ case WebSocketHeader::CLOSE:{
HttpSessionType::encode(header,nullptr); HttpSessionType::encode(header,nullptr);
HttpSessionType::shutdown(SockException(Err_shutdown, "recv close request from client")); HttpSessionType::shutdown(SockException(Err_shutdown, "recv close request from client"));
}
break; break;
}
case WebSocketHeader::PING:{ case WebSocketHeader::PING:{
header._opcode = WebSocketHeader::PONG; header._opcode = WebSocketHeader::PONG;
HttpSessionType::encode(header,std::make_shared<BufferString>(_remian_data)); HttpSessionType::encode(header,std::make_shared<BufferString>(_payload_section));
}
break; break;
case WebSocketHeader::CONTINUATION:{
} }
break;
case WebSocketHeader::CONTINUATION:
case WebSocketHeader::TEXT: case WebSocketHeader::TEXT:
case WebSocketHeader::BINARY:{ case WebSocketHeader::BINARY:{
_session->onRecv(std::make_shared<BufferString>(_remian_data)); if (!header._fin) {
//还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出
_payload_cache.append(std::move(_payload_section));
if (_payload_cache.size() < MAX_WS_PACKET) {
//还有内存容量缓存分片数据
break;
}
//分片缓存太大,需要清空
}
//最后一个包
if (_payload_cache.empty()) {
//这个包是唯一个分片
_session->onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_section)));
break;
}
//这个包由多个分片组成
_payload_cache.append(std::move(_payload_section));
_session->onRecv(std::make_shared<WebSocketBuffer>(header._opcode, header._fin, std::move(_payload_cache)));
_payload_cache.clear();
break;
} }
break;
default: default: break;
break;
} }
_remian_data.clear(); _payload_section.clear();
header._mask_flag = flag; header._mask_flag = flag;
} }
/** /**
* websocket协议打包后回调 * websocket协议打包后回调
* @param buffer
*/ */
void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{ void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{
HttpSessionType::send(buffer); HttpSessionType::send(buffer);
} }
private: private:
string _remian_data; string _payload_cache;
weak_ptr<TcpServer> _weakServer; string _payload_section;
weak_ptr<TcpServer> _weak_server;
TcpSession::Ptr _session; TcpSession::Ptr _session;
Creator _creator; Creator _creator;
}; };

View File

@ -16,10 +16,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "Network/Buffer.h" #include "Network/Buffer.h"
using namespace std; using namespace std;
using namespace toolkit; using namespace toolkit;
//websocket组合包最大不得超过4MB(防止内存爆炸)
#define MAX_WS_PACKET (4 * 1024 * 1024)
namespace mediakit { namespace mediakit {
class WebSocketHeader { class WebSocketHeader {
@ -44,6 +46,7 @@ public:
CONTROL_RSVF = 0xF CONTROL_RSVF = 0xF
} Type; } Type;
public: public:
WebSocketHeader() : _mask(4){ WebSocketHeader() : _mask(4){
//获取_mask内部buffer的内存地址该内存是malloc开辟的地址为随机 //获取_mask内部buffer的内存地址该内存是malloc开辟的地址为随机
uint64_t ptr = (uint64_t)(&_mask[0]); uint64_t ptr = (uint64_t)(&_mask[0]);
@ -51,6 +54,7 @@ public:
_mask.assign((uint8_t*)(&ptr), (uint8_t*)(&ptr) + 4); _mask.assign((uint8_t*)(&ptr), (uint8_t*)(&ptr) + 4);
} }
virtual ~WebSocketHeader(){} virtual ~WebSocketHeader(){}
public: public:
bool _fin; bool _fin;
uint8_t _reserved; uint8_t _reserved;
@ -60,6 +64,26 @@ public:
vector<uint8_t > _mask; vector<uint8_t > _mask;
}; };
//websocket协议收到的字符串类型缓存用户协议层获取该数据传输的方式
class WebSocketBuffer : public BufferString {
public:
typedef std::shared_ptr<WebSocketBuffer> Ptr;
template<typename ...ARGS>
WebSocketBuffer(WebSocketHeader::Type headType, bool fin, ARGS &&...args)
: _head_type(headType), _fin(fin), BufferString(std::forward<ARGS>(args)...) {}
~WebSocketBuffer() override {}
WebSocketHeader::Type headType() const { return _head_type; }
bool isFinished() const { return _fin; };
private:
WebSocketHeader::Type _head_type;
bool _fin;
};
class WebSocketSplitter : public WebSocketHeader{ class WebSocketSplitter : public WebSocketHeader{
public: public:
WebSocketSplitter(){} WebSocketSplitter(){}
@ -80,6 +104,7 @@ public:
* @param buffer * @param buffer
*/ */
void encode(const WebSocketHeader &header,const Buffer::Ptr &buffer); void encode(const WebSocketHeader &header,const Buffer::Ptr &buffer);
protected: protected:
/** /**
* webSocket数据包包头onWebSocketDecodePayload回调 * webSocket数据包包头onWebSocketDecodePayload回调
@ -96,7 +121,6 @@ protected:
*/ */
virtual void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, uint64_t len, uint64_t recved) {}; virtual void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, uint64_t len, uint64_t recved) {};
/** /**
* webSocket数据包后回调 * webSocket数据包后回调
* @param header * @param header
@ -109,8 +133,10 @@ protected:
* @param len * @param len
*/ */
virtual void onWebSocketEncodeData(const Buffer::Ptr &buffer){}; virtual void onWebSocketEncodeData(const Buffer::Ptr &buffer){};
private: private:
void onPayloadData(uint8_t *data, uint64_t len); void onPayloadData(uint8_t *data, uint64_t len);
private: private:
string _remain_data; string _remain_data;
int _mask_offset = 0; int _mask_offset = 0;