diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ffaf764..dd155379 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ option(ENABLE_FFMPEG "Enable FFmpeg" true) option(ENABLE_MSVC_MT "Enable MSVC Mt/Mtd lib" true) option(ENABLE_API_STATIC_LIB "Enable mk_api static lib" false) option(USE_SOLUTION_FOLDERS "Enable solution dir supported" ON) +option(ENABLE_SRT "Enable SRT" true) # ---------------------------------------------------------------------------- # Solution folders: # ---------------------------------------------------------------------------- @@ -486,6 +487,15 @@ if (ENABLE_WEBRTC) endif () endif () +if (ENABLE_SRT) + add_definitions(-DENABLE_SRT) + include_directories(./srt) + file(GLOB SRC_SRT_LIST ./srt/*.cpp ./srt/*.h ./srt/*.hpp) + add_library(srt ${SRC_SRT_LIST}) + list(APPEND LINK_LIB_LIST srt) + message(STATUS "srt 功能已开启") +endif() + #添加c库 if (ENABLE_API) add_subdirectory(api) diff --git a/conf/config.ini b/conf/config.ini index 2c36c22b..7a199215 100644 --- a/conf/config.ini +++ b/conf/config.ini @@ -284,6 +284,14 @@ preferredCodecA=PCMU,PCMA,opus,mpeg4-generic #以下范例为所有支持的视频codec preferredCodecV=H264,H265,AV1X,VP9,VP8 +[srt] +#srt播放推流、播放超时时间,单位秒 +timeoutSec=5 +#srt udp服务器监听端口号,所有srt客户端将通过该端口传输srt数据, +#该端口是多线程的,同时支持客户端网络切换导致的连接迁移 +port=9000 + + [rtsp] #rtsp专有鉴权方式是采用base64还是md5方式 authBasic=0 diff --git a/server/main.cpp b/server/main.cpp index f2dd1f61..f98e9640 100644 --- a/server/main.cpp +++ b/server/main.cpp @@ -32,6 +32,11 @@ #include "../webrtc/WebRtcSession.h" #endif +#if defined(ENABLE_SRT) +#include "../srt/SrtSession.hpp" +#include "../srt/SrtTransport.hpp" +#endif + #if defined(ENABLE_VERSION) #include "version.h" #endif @@ -284,6 +289,24 @@ int start_main(int argc,char *argv[]) { uint16_t rtcPort = mINI::Instance()[RTC::kPort]; #endif//defined(ENABLE_WEBRTC) + +#if defined(ENABLE_SRT) + auto srtSrv = std::make_shared(); + srtSrv->setOnCreateSocket([](const EventPoller::Ptr &poller, const Buffer::Ptr &buf, struct sockaddr *, int) { + if (!buf) { + return Socket::createSocket(poller, false); + } + auto new_poller = SRT::SrtSession::queryPoller(buf); + if (!new_poller) { + //握手第一阶段 + return Socket::createSocket(poller, false); + } + return Socket::createSocket(new_poller, false); + }); + + uint16_t srtPort = mINI::Instance()[SRT::kPort]; +#endif //defined(ENABLE_SRT) + try { //rtsp服务器,端口默认554 if (rtspPort) { rtspSrv->start(rtspPort); } @@ -313,6 +336,14 @@ int start_main(int argc,char *argv[]) { if (rtcPort) { rtcSrv->start(rtcPort); } #endif//defined(ENABLE_WEBRTC) + +#if defined(ENABLE_SRT) + // srt udp服务器 + if(srtPort){ + srtSrv->start(srtPort); + } +#endif//defined(ENABLE_SRT) + } catch (std::exception &ex) { WarnL << "端口占用或无权限:" << ex.what() << endl; ErrorL << "程序启动失败,请修改配置文件中端口号后重试!" << endl; diff --git a/srt/Ack.cpp b/srt/Ack.cpp new file mode 100644 index 00000000..7a61971e --- /dev/null +++ b/srt/Ack.cpp @@ -0,0 +1,74 @@ +#include "Ack.hpp" +#include "Common.hpp" + +namespace SRT { +bool ACKPacket::loadFromData(uint8_t *buf, size_t len) { + if(len < ACK_CIF_SIZE + ControlPacket::HEADER_SIZE){ + return false; + } + + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + ControlPacket::loadHeader(); + ack_number = loadUint32(type_specific_info); + uint8_t* ptr = (uint8_t*)_data->data()+ControlPacket::HEADER_SIZE; + + last_ack_pkt_seq_number = loadUint32(ptr); + ptr += 4; + + rtt = loadUint32(ptr); + ptr += 4; + + rtt_variance = loadUint32(ptr); + ptr += 4; + + available_buf_size = loadUint32(ptr); + ptr += 4; + + pkt_recv_rate = loadUint32(ptr); + ptr += 4; + + estimated_link_capacity = loadUint32(ptr); + ptr += 4; + + recv_rate = loadUint32(ptr); + ptr += 4; + + return true; +} +bool ACKPacket::storeToData() { + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE + ACK_CIF_SIZE); + _data->setSize(HEADER_SIZE + ACK_CIF_SIZE); + control_type = ControlPacket::ACK; + sub_type = 0; + + storeUint32(type_specific_info,ack_number); + storeToHeader(); + + uint8_t* ptr = (uint8_t*)_data->data()+ControlPacket::HEADER_SIZE; + + storeUint32(ptr,last_ack_pkt_seq_number); + ptr += 4; + + storeUint32(ptr,rtt); + ptr += 4; + + storeUint32(ptr,rtt_variance); + ptr += 4; + + storeUint32(ptr,pkt_recv_rate); + ptr += 4; + + storeUint32(ptr,available_buf_size); + ptr += 4; + + storeUint32(ptr,estimated_link_capacity); + ptr += 4; + + storeUint32(ptr,recv_rate); + ptr += 4; + + return true; +} +} // namespace \ No newline at end of file diff --git a/srt/Ack.hpp b/srt/Ack.hpp new file mode 100644 index 00000000..15a026db --- /dev/null +++ b/srt/Ack.hpp @@ -0,0 +1,96 @@ +#ifndef ZLMEDIAKIT_SRT_ACK_H +#define ZLMEDIAKIT_SRT_ACK_H +#include "Packet.hpp" + + +namespace SRT{ +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Acknowledgement Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+- CIF -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Last Acknowledged Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| RTT | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| RTT Variance | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Available Buffer Size | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Packets Receiving Rate | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Estimated Link Capacity | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Receiving Rate | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 13: ACK control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-ack-acknowledgment +*/ +class ACKPacket : public ControlPacket +{ +public: + using Ptr = std::shared_ptr; + ACKPacket() = default; + ~ACKPacket() = default; + + enum{ + ACK_CIF_SIZE = 7*4 + }; + + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + uint32_t ack_number; + + uint32_t last_ack_pkt_seq_number; + uint32_t rtt; + uint32_t rtt_variance; + uint32_t available_buf_size; + uint32_t pkt_recv_rate; + uint32_t estimated_link_capacity; + uint32_t recv_rate; +}; + + +class ACKACKPacket : public ControlPacket{ +public: + using Ptr = std::shared_ptr; + ACKACKPacket() = default; + ~ACKACKPacket() = default; + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override{ + if(len < ControlPacket::HEADER_SIZE){ + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + ControlPacket::loadHeader(); + ack_number = loadUint32(type_specific_info); + return true; + } + bool storeToData() override{ + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE); + _data->setSize(HEADER_SIZE ); + control_type = ControlPacket::ACKACK; + sub_type = 0; + + storeUint32(type_specific_info,ack_number); + storeToHeader(); + return true; + } + + uint32_t ack_number; + +}; + +} //namespace SRT +#endif // ZLMEDIAKIT_SRT_ACK_H \ No newline at end of file diff --git a/srt/Common.hpp b/srt/Common.hpp new file mode 100644 index 00000000..cfb80240 --- /dev/null +++ b/srt/Common.hpp @@ -0,0 +1,54 @@ +#ifndef ZLMEDIAKIT_SRT_COMMON_H +#define ZLMEDIAKIT_SRT_COMMON_H +#include + +namespace SRT +{ +using SteadyClock = std::chrono::steady_clock; +using TimePoint = std::chrono::time_point; + +using Microseconds = std::chrono::microseconds; +using Milliseconds = std::chrono::milliseconds; + +inline int64_t DurationCountMicroseconds( SteadyClock::duration dur){ + return std::chrono::duration_cast(dur).count(); +} + +inline uint32_t loadUint32(uint8_t *ptr) { + return ptr[0] << 24 | ptr[1] << 16 | ptr[2] << 8 | ptr[3]; +} +inline uint16_t loadUint16(uint8_t *ptr) { + return ptr[0] << 8 | ptr[1]; +} + +inline void storeUint32(uint8_t *buf, uint32_t val) { + buf[0] = val >> 24; + buf[1] = (val >> 16) & 0xff; + buf[2] = (val >> 8) & 0xff; + buf[3] = val & 0xff; +} + +inline void storeUint16(uint8_t *buf, uint16_t val) { + buf[0] = (val >> 8) & 0xff; + buf[1] = val & 0xff; +} + +inline void storeUint32LE(uint8_t *buf, uint32_t val) { + buf[0] = val & 0xff; + buf[1] = (val >> 8) & 0xff; + buf[2] = (val >> 16) & 0xff; + buf[3] = (val >>24) & 0xff; +} + +inline void storeUint16LE(uint8_t *buf, uint16_t val) { + buf[0] = val & 0xff; + buf[1] = (val>>8) & 0xff; +} + +inline uint32_t srtVersion(int major, int minor, int patch) +{ + return patch + minor*0x100 + major*0x10000; +} +} // namespace SRT + +#endif //ZLMEDIAKIT_SRT_COMMON_H \ No newline at end of file diff --git a/srt/HSExt.cpp b/srt/HSExt.cpp new file mode 100644 index 00000000..5398e48e --- /dev/null +++ b/srt/HSExt.cpp @@ -0,0 +1,127 @@ +#include "HSExt.hpp" + +namespace SRT { +bool HSExtMessage::loadFromData(uint8_t *buf, size_t len) { + if(buf == NULL || len != HSEXT_MSG_SIZE){ + return false; + } + + _data = BufferRaw::create(); + _data->assign((char*)buf,len); + extension_length = 3; + HSExt::loadHeader(); + + assert(extension_type == SRT_CMD_HSREQ || extension_type == SRT_CMD_HSRSP); + + uint8_t* ptr = (uint8_t*)_data->data()+4; + srt_version = loadUint32(ptr); + ptr += 4; + + srt_flag = loadUint32(ptr); + ptr += 4; + + recv_tsbpd_delay = loadUint16(ptr); + ptr += 2; + + send_tsbpd_delay = loadUint16(ptr); + ptr += 2; + + return true; + + } + std::string HSExtMessage::dump(){ + _StrPrinter printer; + printer << "srt version : "<; + uint16_t extension_type; + uint16_t extension_length; + virtual bool loadFromData(uint8_t *buf, size_t len) = 0; + virtual bool storeToData() = 0; + virtual std::string dump() = 0; + ///////Buffer override/////// + char *data() const override { + if (_data) { + return _data->data(); + } + return nullptr; + }; + size_t size() const override { + if (_data) { + return _data->size(); + } + return 0; + }; + +protected: + void loadHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + extension_type = loadUint16(ptr); + ptr += 2; + extension_length = loadUint16(ptr); + ptr += 2; + } + void storeHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + SRT::storeUint16(ptr, extension_type); + ptr += 2; + storeUint16(ptr, extension_length); + } + +protected: + BufferRaw::Ptr _data; +}; + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SRT Version | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SRT Flags | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Receiver TSBPD Delay | Sender TSBPD Delay | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 6: Handshake Extension Message structure + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake-extension-message + +*/ +class HSExtMessage : public HSExt { +public: + using Ptr = std::shared_ptr; + enum { + HS_EXT_MSG_TSBPDSND = 0x00000001, + HS_EXT_MSG_TSBPDRCV = 0x00000002, + HS_EXT_MSG_CRYPT = 0x00000004, + HS_EXT_MSG_TLPKTDROP = 0x00000008, + HS_EXT_MSG_PERIODICNAK = 0x00000010, + HS_EXT_MSG_REXMITFLG = 0x00000020, + HS_EXT_MSG_STREAM = 0x00000040, + HS_EXT_MSG_PACKET_FILTER = 0x00000080 + }; + enum { HSEXT_MSG_SIZE = 16 }; + HSExtMessage() = default; + ~HSExtMessage() = default; + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + std::string dump() override; + uint32_t srt_version; + uint32_t srt_flag; + uint16_t recv_tsbpd_delay; + uint16_t send_tsbpd_delay; +}; + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Stream ID | + ... +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 7: Stream ID Extension Message + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-stream-id-extension-message +*/ +class HSExtStreamID : public HSExt { +public: + using Ptr = std::shared_ptr; + HSExtStreamID() = default; + ~HSExtStreamID() = default; + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + std::string dump() override; + std::string streamid; +}; +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_HS_EXT_H \ No newline at end of file diff --git a/srt/Packet.cpp b/srt/Packet.cpp new file mode 100644 index 00000000..af4b3772 --- /dev/null +++ b/srt/Packet.cpp @@ -0,0 +1,574 @@ + +#include "sys/socket.h" +#include "netdb.h" + +#include +#include "Util/logger.h" +#include "Util/MD5.h" + +#include "Packet.hpp" + + + +namespace SRT { + + +const size_t DataPacket::HEADER_SIZE; +const size_t ControlPacket::HEADER_SIZE; +const size_t HandshakePacket::HS_CONTENT_MIN_SIZE; + +bool DataPacket::isDataPacket(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + if (!(buf[0] & 0x80)) { + return true; + } + return false; +} + +uint32_t DataPacket::getSocketID(uint8_t *buf, size_t len){ + uint8_t *ptr = buf; + ptr += 12; + return loadUint32(ptr); +} + +bool DataPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + uint8_t *ptr = buf; + f = ptr[0] >> 7; + packet_seq_number = loadUint32(ptr)&0x7fffffff; + ptr += 4; + + PP = ptr[0] >> 6; + O = (ptr[0] & 0x20) >> 5; + KK = (ptr[0] & 0x18) >> 3; + R = (ptr[0] & 0x04) >> 2; + msg_number = (ptr[0] & 0x03) << 24 | ptr[1] << 12 | ptr[2] << 8 | ptr[3]; + ptr += 4; + + timestamp = loadUint32(ptr); + ptr += 4; + + dst_socket_id = loadUint32(ptr); + ptr += 4; + + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + return true; +} + +bool DataPacket::storeToData(uint8_t *buf, size_t len) { + _data = BufferRaw::create(); + _data->setCapacity(len + HEADER_SIZE); + _data->setSize(len + HEADER_SIZE); + + uint8_t *ptr = (uint8_t *)_data->data(); + + ptr[0] = packet_seq_number >> 24; + ptr[1] = (packet_seq_number >> 16) & 0xff; + ptr[2] = (packet_seq_number >> 8) & 0xff; + ptr[3] = packet_seq_number & 0xff; + ptr += 4; + + ptr[0] = PP << 6; + ptr[0] |= O << 5; + ptr[0] |= KK << 3; + ptr[0] |= R << 2; + ptr[0] |= (msg_number & 0xff000000) >> 24; + ptr[1] = (msg_number & 0xff0000) >> 16; + ptr[2] = (msg_number & 0xff00) >> 8; + ptr[3] = msg_number & 0xff; + ptr += 4; + + storeUint32(ptr, timestamp); + ptr += 4; + + storeUint32(ptr, dst_socket_id); + ptr += 4; + + memcpy(ptr, buf, len); + return true; +} + +char *DataPacket::data() const { + if (!_data) + return nullptr; + return _data->data(); +} +size_t DataPacket::size() const { + if (!_data) { + return 0; + } + return _data->size(); +} + +char *DataPacket::payloadData() { + if (!_data) + return nullptr; + return _data->data() + HEADER_SIZE; +} +size_t DataPacket::payloadSize() { + if (!_data) { + return 0; + } + return _data->size() - HEADER_SIZE; +} + + + +bool ControlPacket::isControlPacket(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + if (buf[0] & 0x80) { + return true; + } + return false; +} + +uint16_t ControlPacket::getControlType(uint8_t *buf, size_t len) { + uint8_t *ptr = buf; + uint16_t control_type = (ptr[0] & 0x7f) << 8 | ptr[1]; + return control_type; +} + +bool ControlPacket::loadHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + f = ptr[0] >> 7; + control_type = (ptr[0] & 0x7f) << 8 | ptr[1]; + ptr += 2; + + sub_type = loadUint16(ptr); + ptr += 2; + + type_specific_info[0] = ptr[0]; + type_specific_info[1] = ptr[1]; + type_specific_info[2] = ptr[2]; + type_specific_info[3] = ptr[3]; + ptr += 4; + + timestamp = loadUint32(ptr); + ptr += 4; + + dst_socket_id = loadUint32(ptr); + ptr += 4; + return true; +} +bool ControlPacket::storeToHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + ptr[0] = 0x80; + ptr[0] |= control_type >> 8; + ptr[1] = control_type & 0xff; + ptr += 2; + + storeUint16(ptr, sub_type); + ptr += 2; + + ptr[0] = type_specific_info[0]; + ptr[1] = type_specific_info[1]; + ptr[2] = type_specific_info[2]; + ptr[3] = type_specific_info[3]; + ptr += 4; + + storeUint32(ptr, timestamp); + ptr += 4; + + storeUint32(ptr, dst_socket_id); + ptr += 4; + return true; +} + +char *ControlPacket::data() const { + if (!_data) + return nullptr; + return _data->data(); +} +size_t ControlPacket::size() const { + if (!_data) { + return 0; + } + return _data->size(); +} +uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len){ + return loadUint32(buf+12); +} +bool HandshakePacket::loadFromData(uint8_t *buf, size_t len) { + if(HEADER_SIZE+HS_CONTENT_MIN_SIZE > len){ + ErrorL << "size too smalle " << encryption_field; + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + ControlPacket::loadHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + // parse CIF + version = loadUint32(ptr); + ptr += 4; + + encryption_field = loadUint16(ptr); + ptr += 2; + + extension_field = loadUint16(ptr); + ptr += 2; + + initial_packet_sequence_number = loadUint32(ptr); + ptr += 4; + + mtu = loadUint32(ptr); + ptr += 4; + + max_flow_window_size = loadUint32(ptr); + ptr += 4; + + handshake_type = loadUint32(ptr); + ptr += 4; + + srt_socket_id = loadUint32(ptr); + ptr += 4; + + syn_cookie = loadUint32(ptr); + ptr += 4; + + memcpy(peer_ip_addr, ptr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0])); + ptr += sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]); + + if (encryption_field != NO_ENCRYPTION) { + ErrorL << "not support encryption " << encryption_field; + } + + if(extension_field == 0){ + return true; + } + + if(len == HEADER_SIZE+HS_CONTENT_MIN_SIZE){ + //ErrorL << "extension filed not exist " << extension_field; + return true; + } + + return loadExtMessage(ptr,len-HS_CONTENT_MIN_SIZE-HEADER_SIZE); +} +bool HandshakePacket::loadExtMessage(uint8_t *buf,size_t len){ + uint8_t* ptr = buf; + ext_list.clear(); + uint16_t type; + uint16_t length; + HSExt::Ptr ext; + while(ptr(); + break; + case HSExt::SRT_CMD_SID: + ext = std::make_shared(); + break; + default: + WarnL<<"not support ext "<loadFromData(ptr,length*4+4)){ + ext_list.push_back(std::move(ext)); + }else{ + WarnL<<"parse HS EXT failed type="<ss_family == AF_INET){ + struct sockaddr_in * ipv4 = (struct sockaddr_in *)addr; + //抓包 奇怪好像是小头端??? + storeUint32LE(peer_ip_addr,ipv4->sin_addr.s_addr); + }else{ + const sockaddr_in6* ipv6 = (struct sockaddr_in6 *)addr; + memcpy(peer_ip_addr,ipv6->sin6_addr.s6_addr,sizeof(peer_ip_addr)*sizeof(peer_ip_addr[0])); + } +} +uint32_t HandshakePacket::generateSynCookie(struct sockaddr_storage* addr,TimePoint ts,uint32_t current_cookie, int correction ){ + + static std::atomic distractor{0}; + uint32_t rollover = distractor.load() + 10; + + for (;;) + { + // SYN cookie + char clienthost[NI_MAXHOST]; + char clientport[NI_MAXSERV]; + getnameinfo((struct sockaddr*)addr, + sizeof(struct sockaddr_storage), + clienthost, + sizeof(clienthost), + clientport, + sizeof(clientport), + NI_NUMERICHOST | NI_NUMERICSERV); + int64_t timestamp = (DurationCountMicroseconds(SteadyClock::now() - ts) / 60000000) + distractor.load() + + correction; // secret changes every one minute + std::stringstream cookiestr; + cookiestr << clienthost << ":" << clientport << ":" << timestamp; + union { + unsigned char cookie[16]; + uint32_t cookie_val; + }; + MD5 md5(cookiestr.str()); + memcpy(cookie,md5.rawdigest().c_str(),16); + + if (cookie_val != current_cookie) + return cookie_val; + + ++distractor; + + // This is just to make the loop formally breakable, + // but this is virtually impossible to happen. + if (distractor == rollover) + return cookie_val; + } +} + +bool KeepLivePacket::loadFromData(uint8_t *buf, size_t len){ + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + _data = BufferRaw::create(); + _data->assign((char*)buf,len); + + return loadHeader(); +} +bool KeepLivePacket::storeToData(){ + control_type = ControlPacket::KEEPALIVE; + sub_type = 0; + + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE); + _data->setSize(HEADER_SIZE); + return storeToHeader(); +} + +bool NAKPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + _data = BufferRaw::create(); + _data->assign((char*)buf,len); + loadHeader(); + + uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + uint8_t* end = (uint8_t*)_data->data()+_data->size(); + LostPair lost; + while (ptrsetCapacity(HEADER_SIZE+cif_size); + _data->setSize(HEADER_SIZE+cif_size); + + storeToHeader(); + + uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + + for(auto it : lost_list){ + if(it.first+1 ==it.second){ + storeUint32(ptr,it.first); + ptr[0] = ptr[0]&0x7f; + ptr += 4; + }else{ + storeUint32(ptr,it.first); + ptr[0] |= 0x80; + + storeUint32(ptr+4,it.second-1); + //ptr[4] = ptr[4]&0x7f; + + ptr += 8; + } + } + + return true; +} + +size_t NAKPacket::getCIFSize(){ + size_t size = 0; + for(auto it : lost_list){ + if(it.first+1 ==it.second){ + size += 4; + }else{ + size += 8; + } + } + return size; +} + +std::string NAKPacket::dump(){ + _StrPrinter printer; + for (auto it : lost_list) { + printer<<"[ "<assign((char*)buf,len); + loadHeader(); + + uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + + first_pkt_seq_num = loadUint32(ptr); + ptr += 4; + + last_pkt_seq_num = loadUint32(ptr); + ptr += 4; + return true; +} +bool MsgDropReqPacket::storeToData() { + control_type = DROPREQ; + sub_type = 0; + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE+8); + _data->setSize(HEADER_SIZE+8); + + storeToHeader(); + + uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + + storeUint32(ptr,first_pkt_seq_num); + ptr += 4; + + storeUint32(ptr,last_pkt_seq_num); + ptr += 4; + return true; +} +} // namespace SRT \ No newline at end of file diff --git a/srt/Packet.hpp b/srt/Packet.hpp new file mode 100644 index 00000000..f2fbc98a --- /dev/null +++ b/srt/Packet.hpp @@ -0,0 +1,317 @@ +#ifndef ZLMEDIAKIT_SRT_PACKET_H +#define ZLMEDIAKIT_SRT_PACKET_H + +#include +#include + +#include "Network/Buffer.h" + +#include "Common.hpp" +#include "HSExt.hpp" + +namespace SRT { + +using namespace toolkit; +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0| Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|P P|O|K K|R| Message Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ Data + +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 3: Data packet structure + reference https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-packet-structure +*/ +class DataPacket : public Buffer { +public: + using Ptr = std::shared_ptr; + DataPacket() = default; + ~DataPacket() = default; + + static const size_t HEADER_SIZE = 16; + static bool isDataPacket(uint8_t *buf, size_t len); + static uint32_t getSocketID(uint8_t *buf, size_t len); + bool loadFromData(uint8_t *buf, size_t len); + bool storeToData(uint8_t *buf, size_t len); + + ///////Buffer override/////// + char *data() const override; + size_t size() const override; + + char *payloadData(); + size_t payloadSize(); + + uint8_t f; + uint32_t packet_seq_number; + uint8_t PP; + uint8_t O; + uint8_t KK; + uint8_t R; + uint32_t msg_number; + uint32_t timestamp; + uint32_t dst_socket_id; + +private: + BufferRaw::Ptr _data; +}; +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Subtype | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+- CIF -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ Control Information Field + +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 4: Control packet structure + reference https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-control-packets +*/ +class ControlPacket : public Buffer { +public: + using Ptr = std::shared_ptr; + static const size_t HEADER_SIZE = 16; + static bool isControlPacket(uint8_t *buf, size_t len); + static uint16_t getControlType(uint8_t *buf, size_t len); + static uint32_t getSocketID(uint8_t *buf, size_t len); + + ControlPacket() = default; + virtual ~ControlPacket() = default; + virtual bool loadFromData(uint8_t *buf, size_t len) = 0; + virtual bool storeToData() = 0; + + bool loadHeader(); + bool storeToHeader(); + + ///////Buffer override/////// + char *data() const override; + size_t size() const override; + + enum { + HANDSHAKE = 0x0000, + KEEPALIVE = 0x0001, + ACK = 0x0002, + NAK = 0x0003, + CONGESTIONWARNING = 0x0004, + SHUTDOWN = 0x0005, + ACKACK = 0x0006, + DROPREQ = 0x0007, + PEERERROR = 0x0008, + USERDEFINEDTYPE = 0x7FFF + }; + + uint32_t sub_type : 16; + uint32_t control_type : 15; + uint32_t f : 1; + uint8_t type_specific_info[4]; + uint32_t timestamp; + uint32_t dst_socket_id; + +protected: + BufferRaw::Ptr _data; +}; + +/** + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Version | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Encryption Field | Extension Field | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Initial Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Maximum Transmission Unit Size | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Maximum Flow Window Size | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Handshake Type | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SRT Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SYN Cookie | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ + +| | ++ Peer IP Address + +| | ++ + +| | ++=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ +| Extension Type | Extension Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ Extension Contents + +| | ++=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + Figure 5: Handshake packet structure + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake + */ +class HandshakePacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + enum { NO_ENCRYPTION = 0, AES_128 = 1, AES_196 = 2, AES_256 = 3 }; + static const size_t HS_CONTENT_MIN_SIZE = 48; + enum { + HS_TYPE_DONE = 0xFFFFFFFD, + HS_TYPE_AGREEMENT = 0xFFFFFFFE, + HS_TYPE_CONCLUSION = 0xFFFFFFFF, + HS_TYPE_WAVEHAND = 0x00000000, + HS_TYPE_INDUCTION = 0x00000001 + }; + + enum { HS_EXT_FILED_HSREQ = 0x00000001, HS_EXT_FILED_KMREQ = 0x00000002, HS_EXT_FILED_CONFIG = 0x00000004 }; + + + + HandshakePacket() = default; + ~HandshakePacket() = default; + + static bool isHandshakePacket(uint8_t *buf, size_t len); + static uint32_t getHandshakeType(uint8_t *buf, size_t len); + static uint32_t getSynCookie(uint8_t *buf, size_t len); + static uint32_t generateSynCookie(struct sockaddr_storage* addr,TimePoint ts,uint32_t current_cookie = 0, int correction = 0); + + void assignPeerIP(struct sockaddr_storage* addr); + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + uint32_t version; + uint16_t encryption_field; + uint16_t extension_field; + uint32_t initial_packet_sequence_number; + uint32_t mtu; + uint32_t max_flow_window_size; + uint32_t handshake_type; + uint32_t srt_socket_id; + uint32_t syn_cookie; + uint8_t peer_ip_addr[16]; + + std::vector ext_list; +private: + bool loadExtMessage(uint8_t *buf,size_t len); + bool storeExtMessage(); + size_t getExtSize(); +}; +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 12: Keep-Alive control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-keep-alive +*/ +class KeepLivePacket : public ControlPacket +{ +public: + using Ptr = std::shared_ptr; + KeepLivePacket() = default; + ~KeepLivePacket() = default; + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; +}; + +/* +An SRT NAK packet is formatted as follows: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+- CIF (Loss List) -+-+-+-+-+-+-+-+-+-+-+-+ +|0| Lost packet sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Range of lost packets from sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0| Up to sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0| Lost packet sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 14: NAK control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-nak-control-packet +*/ +class NAKPacket : public ControlPacket +{ +public: + using Ptr = std::shared_ptr; + using LostPair = std::pair; + NAKPacket() = default; + ~NAKPacket() = default; + std::string dump(); + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + std::list lost_list; +private: + size_t getCIFSize(); +}; + + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type = 7 | Reserved = 0 | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| First Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Last Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 18: Drop Request control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-message-drop-request +*/ +class MsgDropReqPacket : public ControlPacket +{ + public: + using Ptr = std::shared_ptr; + MsgDropReqPacket() = default; + ~MsgDropReqPacket() = default; + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + uint32_t first_pkt_seq_num; + uint32_t last_pkt_seq_num; +}; + +} // namespace SRT + +#endif //ZLMEDIAKIT_SRT_PACKET_H \ No newline at end of file diff --git a/srt/PacketQueue.cpp b/srt/PacketQueue.cpp new file mode 100644 index 00000000..ea342277 --- /dev/null +++ b/srt/PacketQueue.cpp @@ -0,0 +1,126 @@ +#include "PacketQueue.hpp" + +namespace SRT { +PacketQueue::PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t lantency) + : _pkt_expected_seq(init_seq) + , _pkt_cap(max_size) + , _pkt_lantency(lantency) { + } + +bool PacketQueue::inputPacket(DataPacket::Ptr pkt) { + if (pkt->packet_seq_number < _pkt_expected_seq) { + // TOO later drop this packet + return false; + } + + _pkt_map[pkt->packet_seq_number] = pkt; + + return true; +} + +std::list PacketQueue::tryGetPacket() { + std::list re; + while (_pkt_map.find(_pkt_expected_seq) != _pkt_map.end()) { + re.push_back(_pkt_map[_pkt_expected_seq]); + _pkt_map.erase(_pkt_expected_seq); + _pkt_expected_seq++; + } + + while (_pkt_map.size() > _pkt_cap) { + // force pop some packet + auto it = _pkt_map.begin(); + re.push_back(it->second); + _pkt_expected_seq = it->second->packet_seq_number + 1; + _pkt_map.erase(it); + } + + while (timeLantency() > _pkt_lantency) { + auto it = _pkt_map.begin(); + re.push_back(it->second); + _pkt_expected_seq = it->second->packet_seq_number + 1; + _pkt_map.erase(it); + } + + return std::move(re); +} + + +bool PacketQueue::dropForRecv(uint32_t first,uint32_t last){ + if(first >= last){ + return false; + } + + if(_pkt_expected_seq <= last){ + _pkt_expected_seq = last+1; + return true; + } + + return false; +} +uint32_t PacketQueue::timeLantency() { + if (_pkt_map.empty()) { + return 0; + } + + auto first = _pkt_map.begin()->second; + auto last = _pkt_map.rbegin()->second; + + return last->timestamp - first->timestamp; +} + +std::list PacketQueue::getLostSeq() { + std::list re; + if(_pkt_map.empty()){ + return re; + } + + if(getExpectedSize() == getSize()){ + return re; + } + + PacketQueue::LostPair lost; + lost.first = 0; + lost.second = 0; + + uint32_t i = _pkt_expected_seq; + bool finish = true; + for(i = _pkt_expected_seq;i<=_pkt_map.rbegin()->first;++i){ + if(_pkt_map.find(i) == _pkt_map.end()){ + if(finish){ + finish = false; + lost.first = i; + lost.second = i+1; + }else{ + lost.second = i+1; + } + }else{ + + if(!finish){ + finish = true; + re.push_back(lost); + } + } + } + + return re; +} + +size_t PacketQueue::getSize(){ + return _pkt_map.size(); +} + +size_t PacketQueue::getExpectedSize() { + if(_pkt_map.empty()){ + return 0; + } + return _pkt_map.rbegin()->first - _pkt_expected_seq+1; +} + +size_t PacketQueue::getAvailableBufferSize(){ + return _pkt_cap - getExpectedSize(); +} + +uint32_t PacketQueue::getExpectedSeq(){ + return _pkt_expected_seq; +} +} // namespace SRT \ No newline at end of file diff --git a/srt/PacketQueue.hpp b/srt/PacketQueue.hpp new file mode 100644 index 00000000..05996647 --- /dev/null +++ b/srt/PacketQueue.hpp @@ -0,0 +1,44 @@ +#ifndef ZLMEDIAKIT_SRT_PACKET_QUEUE_H +#define ZLMEDIAKIT_SRT_PACKET_QUEUE_H +#include +#include +#include +#include +#include + +#include "Packet.hpp" + +namespace SRT{ + +class PacketQueue +{ +public: + using Ptr = std::shared_ptr; + using LostPair = std::pair; + + PacketQueue(uint32_t max_size,uint32_t init_seq,uint32_t lantency); + ~PacketQueue() = default; + bool inputPacket(DataPacket::Ptr pkt); + std::list tryGetPacket(); + uint32_t timeLantency(); + std::list getLostSeq(); + + size_t getSize(); + size_t getExpectedSize(); + size_t getAvailableBufferSize(); + uint32_t getExpectedSeq(); + + bool dropForRecv(uint32_t first,uint32_t last); + + +private: + std::map _pkt_map; + + uint32_t _pkt_expected_seq = 0; + uint32_t _pkt_cap; + uint32_t _pkt_lantency; +}; + +} + +#endif //ZLMEDIAKIT_SRT_PACKET_QUEUE_H \ No newline at end of file diff --git a/srt/SrtSession.cpp b/srt/SrtSession.cpp new file mode 100644 index 00000000..a20246d4 --- /dev/null +++ b/srt/SrtSession.cpp @@ -0,0 +1,137 @@ +#include "SrtSession.hpp" +#include "Packet.hpp" +#include "SrtTransport.hpp" + +#include "Common/config.h" + +namespace SRT { +using namespace mediakit; + +SrtSession::SrtSession(const Socket::Ptr &sock) + : UdpSession(sock) { + socklen_t addr_len = sizeof(_peer_addr); + getpeername(sock->rawFD(), (struct sockaddr *)&_peer_addr, &addr_len); +} + +SrtSession::~SrtSession() { + InfoP(this); +} + +EventPoller::Ptr SrtSession::queryPoller(const Buffer::Ptr &buffer) { + uint8_t* data = (uint8_t*)buffer->data(); + size_t size = buffer->size(); + + if(DataPacket::isDataPacket(data,size)){ + uint32_t socket_id = DataPacket::getSocketID(data,size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + return trans ? trans->getPoller() : nullptr; + } + + if(HandshakePacket::isHandshakePacket(data,size)){ + auto type = HandshakePacket::getHandshakeType(data,size); + if(type == HandshakePacket::HS_TYPE_INDUCTION){ + // 握手第一阶段 + return nullptr; + }else if(type == HandshakePacket::HS_TYPE_CONCLUSION){ + // 握手第二阶段 + uint32_t sync_cookie = HandshakePacket::getSynCookie(data,size); + auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); + return trans ? trans->getPoller() : nullptr; + }else{ + WarnL<<" not reach there"; + } + }else{ + uint32_t socket_id = ControlPacket::getSocketID(data,size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + return trans ? trans->getPoller() : nullptr; + } + return nullptr; +} + +void SrtSession::onRecv(const Buffer::Ptr &buffer) { + uint8_t* data = (uint8_t*)buffer->data(); + size_t size = buffer->size(); + + if (_find_transport) { + //只允许寻找一次transport + _find_transport = false; + + if (DataPacket::isDataPacket(data, size)) { + uint32_t socket_id = DataPacket::getSocketID(data, size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + if(trans){ + _transport = std::move(trans); + }else{ + WarnL<<" data packet not find transport "; + } + } + + if (HandshakePacket::isHandshakePacket(data, size)) { + auto type = HandshakePacket::getHandshakeType(data, size); + if (type == HandshakePacket::HS_TYPE_INDUCTION) { + // 握手第一阶段 + _transport = std::make_shared(getPoller()); + + } else if (type == HandshakePacket::HS_TYPE_CONCLUSION) { + // 握手第二阶段 + uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); + auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); + if (trans) { + _transport = std::move(trans); + } else { + WarnL << " hanshake packet not find transport "; + } + } else { + WarnL << " not reach there"; + } + } else { + uint32_t socket_id = ControlPacket::getSocketID(data, size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + if(trans){ + _transport = std::move(trans); + }else{ + WarnL << " not find transport"; + } + } + + if(_transport){ + _transport->setSession(shared_from_this()); + } + InfoP(this); + } + _ticker.resetTime(); + + if(_transport){ + _transport->inputSockData(data,size,&_peer_addr); + }else{ + WarnL<< "ingore data"; + } +} + +void SrtSession::onError(const SockException &err) { + // udp链接超时,但是srt链接不一定超时,因为可能存在udp链接迁移的情况 + //在udp链接迁移时,新的SrtSession对象将接管SrtSession对象的生命周期 + //本SrtSession对象将在超时后自动销毁 + WarnP(this) << err.what(); + + if (!_transport) { + return; + } + + // 防止互相引用导致不释放 + auto transport = std::move(_transport); + getPoller()->async([transport,err] { + //延时减引用,防止使用transport对象时,销毁对象 + transport->onShutdown(err); + }, false); +} + +void SrtSession::onManager() { + GET_CONFIG(float, timeoutSec, kTimeOutSec); + if (_ticker.elapsedTime() > timeoutSec*1000) { + shutdown(SockException(Err_timeout, "srt connection timeout")); + return; + } +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/SrtSession.hpp b/srt/SrtSession.hpp new file mode 100644 index 00000000..9478f1b9 --- /dev/null +++ b/srt/SrtSession.hpp @@ -0,0 +1,31 @@ +#ifndef ZLMEDIAKIT_SRT_SESSION_H +#define ZLMEDIAKIT_SRT_SESSION_H + +#include "Network/Session.h" +#include "SrtTransport.hpp" + +namespace SRT { + +using namespace toolkit; + +class SrtSession : public UdpSession { +public: + SrtSession(const Socket::Ptr &sock); + ~SrtSession() override; + + void onRecv(const Buffer::Ptr &) override; + void onError(const SockException &err) override; + void onManager() override; + + static EventPoller::Ptr queryPoller(const Buffer::Ptr &buffer); + +private: + bool _find_transport = true; + Ticker _ticker; + struct sockaddr_storage _peer_addr; + SrtTransport::Ptr _transport; + +}; + +} // namespace SRT +#endif //ZLMEDIAKIT_SRT_SESSION_H \ No newline at end of file diff --git a/srt/SrtTransport.cpp b/srt/SrtTransport.cpp new file mode 100644 index 00000000..4086ffe2 --- /dev/null +++ b/srt/SrtTransport.cpp @@ -0,0 +1,465 @@ +#include "Util/onceToken.h" + +#include "SrtTransport.hpp" +#include "Packet.hpp" +#include "Ack.hpp" +namespace SRT { +#define SRT_FIELD "srt." +//srt 超时时间 +const std::string kTimeOutSec = SRT_FIELD"timeoutSec"; +//srt 单端口udp服务器 +const std::string kPort = SRT_FIELD"port"; + +static std::atomic s_srt_socket_id_generate{125}; +//////////// SrtTransport ////////////////////////// +SrtTransport::SrtTransport(const EventPoller::Ptr &poller) + : _poller(poller) { + _start_timestamp = SteadyClock::now(); + _socket_id = s_srt_socket_id_generate.fetch_add(1); + } + +SrtTransport::~SrtTransport(){ + TraceL<<" "; +} +const EventPoller::Ptr &SrtTransport::getPoller() const { + return _poller; +} + +void SrtTransport::setSession(Session::Ptr session) { + _history_sessions.emplace(session.get(), session); + if (_selected_session) { + InfoL << "srt network changed: " << _selected_session->get_peer_ip() << ":" + << _selected_session->get_peer_port() << " -> " << session->get_peer_ip() << ":" + << session->get_peer_port() << ", id:" << _selected_session->getIdentifier(); + } + _selected_session = session; +} +const Session::Ptr &SrtTransport::getSession() const { + return _selected_session; +} + +void SrtTransport::switchToOtherTransport(uint8_t *buf, int len,uint32_t socketid, struct sockaddr_storage *addr){ + BufferRaw::Ptr tmp = BufferRaw::create(); + struct sockaddr_storage tmp_addr = *addr; + tmp->assign((char*)buf,len); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socketid)); + if(trans){ + trans->getPoller()->async([tmp,tmp_addr,trans]{ + trans->inputSockData((uint8_t*)tmp->data(),tmp->size(),(struct sockaddr_storage*)&tmp_addr); + }); + } +} + +void SrtTransport::inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr) { + using srt_control_handler = void (SrtTransport::*)(uint8_t* buf,int len,struct sockaddr_storage *addr); + static std::unordered_map s_control_functions; + static onceToken token([]() { + s_control_functions.emplace(ControlPacket::HANDSHAKE, &SrtTransport::handleHandshake); + s_control_functions.emplace(ControlPacket::KEEPALIVE, &SrtTransport::handleKeeplive); + s_control_functions.emplace(ControlPacket::ACK, &SrtTransport::handleACK); + s_control_functions.emplace(ControlPacket::NAK, &SrtTransport::handleNAK); + s_control_functions.emplace(ControlPacket::CONGESTIONWARNING, &SrtTransport::handleCongestionWarning); + s_control_functions.emplace(ControlPacket::SHUTDOWN, &SrtTransport::handleShutDown); + s_control_functions.emplace(ControlPacket::ACKACK, &SrtTransport::handleACKACK); + s_control_functions.emplace(ControlPacket::DROPREQ, &SrtTransport::handleDropReq); + s_control_functions.emplace(ControlPacket::PEERERROR, &SrtTransport::handlePeerError); + s_control_functions.emplace(ControlPacket::USERDEFINEDTYPE, &SrtTransport::handleUserDefinedType); + }); + auto now = SteadyClock::now(); + // 处理srt数据 + if (DataPacket::isDataPacket(buf, len)) { + uint32_t socketId = DataPacket::getSocketID(buf,len); + if(socketId == _socket_id){ + _pkt_recv_rate_context.inputPacket(now); + _estimated_link_capacity_context.inputPacket(now); + _recv_rate_context.inputPacket(now, len); + + handleDataPacket(buf, len, addr); + }else{ + switchToOtherTransport(buf,len,socketId,addr); + } + } else { + if (ControlPacket::isControlPacket(buf, len)) { + uint32_t socketId = ControlPacket::getSocketID(buf,len); + uint16_t type = ControlPacket::getControlType(buf,len); + if(type != ControlPacket::HANDSHAKE && socketId != _socket_id && _socket_id != 0){ + // socket id not same + switchToOtherTransport(buf,len,socketId,addr); + return; + } + _pkt_recv_rate_context.inputPacket(now); + _estimated_link_capacity_context.inputPacket(now); + _recv_rate_context.inputPacket(now, len); + + auto it = s_control_functions.find(type); + if (it == s_control_functions.end()) { + WarnL<<" not support type ignore" << ControlPacket::getControlType(buf,len); + return; + }else{ + (this->*(it->second))(buf,len,addr); + } + } else { + // not reach + WarnL << "not reach this"; + } + } +} + +void SrtTransport::handleHandshakeInduction(HandshakePacket &pkt, struct sockaddr_storage *addr) { + // Induction Phase + TraceL << getIdentifier() << " Induction Phase "; + if (_handleshake_res) { + TraceL << getIdentifier() << " Induction handle repeate "; + sendControlPacket(_handleshake_res, true); + return; + } + + _init_seq_number = pkt.initial_packet_sequence_number; + _max_window_size = pkt.max_flow_window_size; + _mtu = pkt.mtu; + + _peer_socket_id = pkt.srt_socket_id; + HandshakePacket::Ptr res = std::make_shared(); + res->dst_socket_id = _peer_socket_id; + res->timestamp = DurationCountMicroseconds(_start_timestamp.time_since_epoch()); + res->mtu = _mtu; + res->max_flow_window_size = _max_window_size; + res->initial_packet_sequence_number = _init_seq_number; + res->version = 5; + res->encryption_field = HandshakePacket::NO_ENCRYPTION; + res->extension_field = 0x4A17; + res->handshake_type = HandshakePacket::HS_TYPE_INDUCTION; + res->srt_socket_id = _peer_socket_id; + res->syn_cookie = HandshakePacket::generateSynCookie(addr, _start_timestamp); + _sync_cookie = res->syn_cookie; + memcpy(res->peer_ip_addr, pkt.peer_ip_addr, sizeof(pkt.peer_ip_addr) * sizeof(pkt.peer_ip_addr[0])); + _handleshake_res = res; + res->storeToData(); + + registerSelfHandshake(); + sendControlPacket(res, true); +} +void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockaddr_storage *addr) { + if(!_handleshake_res){ + ErrorL<<"must Induction Phase for handleshake "; + return; + } + + if (_handleshake_res->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) { + // first + HSExtMessage::Ptr req; + HSExtStreamID::Ptr sid; + + for (auto ext : pkt.ext_list) { + //TraceL << getIdentifier() << " ext " << ext->dump(); + if (!req) { + req = std::dynamic_pointer_cast(ext); + } + if(!sid){ + sid = std::dynamic_pointer_cast(ext); + } + } + if(sid){ + _stream_id = sid->streamid; + } + TraceL << getIdentifier() << " CONCLUSION Phase "; + HandshakePacket::Ptr res = std::make_shared(); + res->dst_socket_id = _peer_socket_id; + res->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); + res->mtu = _mtu; + res->max_flow_window_size = _max_window_size; + res->initial_packet_sequence_number = _init_seq_number; + res->version = 5; + res->encryption_field = HandshakePacket::NO_ENCRYPTION; + res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ; + res->handshake_type = HandshakePacket::HS_TYPE_CONCLUSION; + res->srt_socket_id = _socket_id; + res->syn_cookie = 0; + res->assignPeerIP(addr); + HSExtMessage::Ptr ext = std::make_shared(); + ext->extension_type = HSExt::SRT_CMD_HSRSP; + ext->srt_version = srtVersion(1, 5, 0); + ext->srt_flag = req->srt_flag; + ext->recv_tsbpd_delay = ext->send_tsbpd_delay = req->recv_tsbpd_delay; + res->ext_list.push_back(std::move(ext)); + res->storeToData(); + _handleshake_res = res; + unregisterSelfHandshake(); + registerSelf(); + sendControlPacket(res, true); + TraceL<<" buf size = "<max_flow_window_size<<" init seq ="<<_init_seq_number<<" lantency="<recv_tsbpd_delay; + _recv_buf = std::make_shared(res->max_flow_window_size,_init_seq_number, req->recv_tsbpd_delay*1e6); + onHandShakeFinished(_stream_id); + } else { + TraceL << getIdentifier() << " CONCLUSION handle repeate "; + sendControlPacket(_handleshake_res, true); + } +} +void SrtTransport::handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr){ + HandshakePacket pkt; + assert(pkt.loadFromData(buf,len)); + + if(pkt.handshake_type == HandshakePacket::HS_TYPE_INDUCTION){ + handleHandshakeInduction(pkt,addr); + }else if(pkt.handshake_type == HandshakePacket::HS_TYPE_CONCLUSION){ + handleHandshakeConclusion(pkt,addr); + }else{ + WarnL<<" not support handshake type = "<< pkt.handshake_type; + } + _ack_ticker.resetTime(); + _nak_ticker.resetTime(); +} +void SrtTransport::handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; +} +void SrtTransport::handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; + auto now = SteadyClock::now(); + ACKPacket ack; + ack.loadFromData(buf,len); + + ACKACKPacket::Ptr pkt = std::make_shared(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(now -_start_timestamp); + pkt->ack_number = ack.ack_number; + pkt->storeToData(); + sendControlPacket(pkt,true); +} +void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; +} +void SrtTransport::handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; +} +void SrtTransport::handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; + onShutdown(SockException(Err_shutdown, "peer close connection")); +} +void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr){ + MsgDropReqPacket pkt; + pkt.loadFromData(buf,len); + TraceL<<"drop "<dropForRecv(pkt.first_pkt_seq_num,pkt.last_pkt_seq_num); +} +void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; +} + +void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr){ + //TraceL; + auto now = SteadyClock::now(); + ACKACKPacket::Ptr pkt = std::make_shared(); + pkt->loadFromData(buf,len); + + uint32_t rtt = DurationCountMicroseconds(now - _ack_send_timestamp[pkt->ack_number]); + _rtt_variance = 3*_rtt_variance/4+abs(_rtt - rtt); + _rtt = 7*rtt/8+_rtt/8; + + _ack_send_timestamp.erase(pkt->ack_number); +} + +void SrtTransport::handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr){ + TraceL; +} + +void SrtTransport::sendACKPacket() { + ACKPacket::Ptr pkt=std::make_shared(); + auto now = SteadyClock::now(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(now - _start_timestamp); + pkt->ack_number = ++_ack_number_count; + pkt->last_ack_pkt_seq_number = _recv_buf->getExpectedSeq(); + pkt->rtt = _rtt; + pkt->rtt_variance = _rtt_variance; + pkt->available_buf_size = _recv_buf->getAvailableBufferSize(); + pkt->pkt_recv_rate = _pkt_recv_rate_context.getPacketRecvRate(); + pkt->estimated_link_capacity = _estimated_link_capacity_context.getEstimatedLinkCapacity(); + pkt->recv_rate = _recv_rate_context.getRecvRate(); + pkt->storeToData(); + _ack_send_timestamp[pkt->ack_number] = now; + sendControlPacket(pkt,true); +} +void SrtTransport::sendLightACKPacket() { + ACKPacket::Ptr pkt=std::make_shared(); + auto now = SteadyClock::now(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(now - _start_timestamp); + pkt->ack_number = 0; + pkt->last_ack_pkt_seq_number = _recv_buf->getExpectedSeq(); + pkt->rtt = 0; + pkt->rtt_variance = 0; + pkt->available_buf_size = 0; + pkt->pkt_recv_rate = 0; + pkt->estimated_link_capacity = 0; + pkt->recv_rate = 0; + pkt->storeToData(); + sendControlPacket(pkt,true); + +} + +void SrtTransport::sendNAKPacket(std::list& lost_list){ + NAKPacket::Ptr pkt = std::make_shared(); + auto now = SteadyClock::now(); + + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(now - _start_timestamp); + pkt->lost_list = lost_list; + + pkt->storeToData(); + + //TraceL<<"send NAK "<dump(); + sendControlPacket(pkt,true); +} +void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr){ + DataPacket::Ptr pkt = std::make_shared(); + pkt->loadFromData(buf,len); + if(_ack_ticker.elapsedTime()>=10){ + _light_ack_pkt_count = 0; + _ack_ticker.resetTime(); + // send a ack per 10 ms for receiver + sendACKPacket(); + }else{ + if(_light_ack_pkt_count >= 64){ + // for high bitrate stream send light ack + // TODO + sendLightACKPacket(); + } + _light_ack_pkt_count = 0; + } + + _light_ack_pkt_count++; + + //TraceL<<" seq="<< pkt->packet_seq_number<<" ts="<timestamp<<" size="<payloadSize()<<\ + " PP="<<(int)pkt->PP<<" O="<<(int)pkt->O<<" kK="<<(int)pkt->KK<<" R="<<(int)pkt->R; +#if 1 + _recv_buf->inputPacket(pkt); +#else + if(pkt->packet_seq_number%100 == 0){ + // drop + TraceL<<"drop packet"; + TraceL<<"expected size "<<_recv_buf->getExpectedSize()<<" real size="<<_recv_buf->getSize(); + }else{ + _recv_buf->inputPacket(pkt); + } +#endif + //TraceL<<" data number size "<20 && _nak_ticker.elapsedTime()>nak_interval){ + auto lost = _recv_buf->getLostSeq(); + if(!lost.empty()){ + sendNAKPacket(lost); + //TraceL<<"send NAK"; + } + _nak_ticker.resetTime(); + } + auto list = _recv_buf->tryGetPacket(); + + for(auto data : list){ + onSRTData(std::move(data)); + } +} + +void SrtTransport::sendDataPacket(DataPacket::Ptr pkt,char* buf,int len, bool flush) { + pkt->storeToData((uint8_t*)buf,len); + sendPacket(pkt,flush); +} +void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) { + sendPacket(pkt,flush); +} +void SrtTransport::sendPacket(Buffer::Ptr pkt,bool flush){ + if(_selected_session){ + auto tmp = _packet_pool.obtain2(); + tmp->assign(pkt->data(),pkt->size()); + _selected_session->setSendFlushFlag(flush); + _selected_session->send(std::move(tmp)); + }else{ + WarnL<<"not reach this"; + } +} +std::string SrtTransport::getIdentifier(){ + return _selected_session ? _selected_session->getIdentifier() : ""; +} + +void SrtTransport::registerSelfHandshake() { + SrtTransportManager::Instance().addHandshakeItem(std::to_string(_sync_cookie),shared_from_this()); +} +void SrtTransport::unregisterSelfHandshake() { + if(_sync_cookie == 0){ + return; + } + SrtTransportManager::Instance().removeHandshakeItem(std::to_string(_sync_cookie)); +} + +void SrtTransport::registerSelf() { + if(_socket_id == 0){ + return; + } + SrtTransportManager::Instance().addItem(std::to_string(_socket_id),shared_from_this()); + +} +void SrtTransport::unregisterSelf() { + SrtTransportManager::Instance().removeItem(std::to_string(_socket_id)); +} + +void SrtTransport::onShutdown(const SockException &ex){ + WarnL << ex.what(); + unregisterSelfHandshake(); + unregisterSelf(); + for (auto &pr : _history_sessions) { + auto session = pr.second.lock(); + if (session) { + session->shutdown(ex); + } + } +} +//////////// SrtTransportManager ////////////////////////// +SrtTransportManager &SrtTransportManager::Instance() { + static SrtTransportManager s_instance; + return s_instance; +} + +void SrtTransportManager::addItem(const std::string &key, const SrtTransport::Ptr &ptr) { + std::lock_guard lck(_mtx); + _map[key] = ptr; +} + +SrtTransport::Ptr SrtTransportManager::getItem(const std::string &key) { + if (key.empty()) { + return nullptr; + } + std::lock_guard lck(_mtx); + auto it = _map.find(key); + if (it == _map.end()) { + return nullptr; + } + return it->second.lock(); +} + +void SrtTransportManager::removeItem(const std::string &key) { + std::lock_guard lck(_mtx); + _map.erase(key); +} + +void SrtTransportManager::addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr) { + std::lock_guard lck(_handshake_mtx); + _handshake_map[key] = ptr; +} +void SrtTransportManager::removeHandshakeItem(const std::string &key) { + std::lock_guard lck(_handshake_mtx); + _handshake_map.erase(key); +} +SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const std::string &key) { + if (key.empty()) { + return nullptr; + } + std::lock_guard lck(_handshake_mtx); + auto it = _handshake_map.find(key); + if (it == _handshake_map.end()) { + return nullptr; + } + return it->second.lock(); +} + + +} // namespace SRT \ No newline at end of file diff --git a/srt/SrtTransport.hpp b/srt/SrtTransport.hpp new file mode 100644 index 00000000..5ecd24ca --- /dev/null +++ b/srt/SrtTransport.hpp @@ -0,0 +1,143 @@ +#ifndef ZLMEDIAKIT_SRT_TRANSPORT_H +#define ZLMEDIAKIT_SRT_TRANSPORT_H + +#include +#include +#include +#include + +#include "Network/Session.h" +#include "Poller/EventPoller.h" +#include "Util/TimeTicker.h" + +#include "Common.hpp" +#include "Packet.hpp" +#include "PacketQueue.hpp" +#include "Statistic.hpp" + +namespace SRT { +using namespace toolkit; + +extern const std::string kPort; +extern const std::string kTimeOutSec; + +class SrtTransport : public std::enable_shared_from_this { +public: + friend class SrtSession; + using Ptr = std::shared_ptr; + + SrtTransport(const EventPoller::Ptr &poller); + virtual ~SrtTransport(); + const EventPoller::Ptr &getPoller() const; + void setSession(Session::Ptr session); + const Session::Ptr &getSession() const; + /** + * socket收到udp数据 + * @param buf 数据指针 + * @param len 数据长度 + * @param addr 数据来源地址 + */ + void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr); + + std::string getIdentifier(); + + void unregisterSelfHandshake(); + void unregisterSelf(); +protected: + virtual void onHandShakeFinished(std::string& streamid){}; + virtual void onSRTData(DataPacket::Ptr pkt){}; + virtual void onShutdown(const SockException &ex); + +private: + void registerSelfHandshake(); + void registerSelf(); + + void switchToOtherTransport(uint8_t *buf, int len,uint32_t socketid, struct sockaddr_storage *addr); + + void handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleHandshakeInduction(HandshakePacket& pkt,struct sockaddr_storage *addr); + void handleHandshakeConclusion(HandshakePacket& pkt,struct sockaddr_storage *addr); + + void handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr); + + void sendNAKPacket(std::list& lost_list); + void sendACKPacket(); + void sendLightACKPacket(); +protected: + void sendDataPacket(DataPacket::Ptr pkt,char* buf,int len,bool flush = false); + void sendControlPacket(ControlPacket::Ptr pkt,bool flush = true); + void sendPacket(Buffer::Ptr pkt,bool flush = true); +private: + //当前选中的udp链接 + Session::Ptr _selected_session; + //链接迁移前后使用过的udp链接 + std::unordered_map > _history_sessions; + + EventPoller::Ptr _poller; + + uint32_t _peer_socket_id; + uint32_t _socket_id = 0; + + TimePoint _start_timestamp; + + uint32_t _mtu = 1500; + uint32_t _max_window_size = 8192; + uint32_t _init_seq_number = 0; + + std::string _stream_id; + uint32_t _sync_cookie = 0; + + PacketQueue::Ptr _recv_buf; + uint32_t _rtt = 100*1000; + uint32_t _rtt_variance =50*1000; + uint32_t _light_ack_pkt_count = 0; + uint32_t _ack_number_count = 0; + Ticker _ack_ticker; + std::map _ack_send_timestamp; + + PacketRecvRateContext _pkt_recv_rate_context; + EstimatedLinkCapacityContext _estimated_link_capacity_context; + RecvRateContext _recv_rate_context; + + Ticker _nak_ticker; + + //保持发送的握手消息,防止丢失重发 + HandshakePacket::Ptr _handleshake_res; + + ResourcePool _packet_pool; + +}; + +class SrtTransportManager { +public: + static SrtTransportManager &Instance(); + SrtTransport::Ptr getItem(const std::string &key); + void addItem(const std::string &key, const SrtTransport::Ptr &ptr); + void removeItem(const std::string &key); + + void addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr); + void removeHandshakeItem(const std::string &key); + SrtTransport::Ptr getHandshakeItem(const std::string &key); +private: + SrtTransportManager() = default; + +private: + std::mutex _mtx; + std::unordered_map> _map; + + std::mutex _handshake_mtx; + std::unordered_map> _handshake_map; +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_TRANSPORT_H \ No newline at end of file diff --git a/srt/Statistic.cpp b/srt/Statistic.cpp new file mode 100644 index 00000000..8728ccdc --- /dev/null +++ b/srt/Statistic.cpp @@ -0,0 +1,76 @@ +#include + +#include "Statistic.hpp" +namespace SRT { +void PacketRecvRateContext::inputPacket(TimePoint ts) { + if(_pkt_map.size()>100){ + _pkt_map.erase(_pkt_map.begin()); + } + _pkt_map.emplace(ts,ts); +} +uint32_t PacketRecvRateContext::getPacketRecvRate() { + if(_pkt_map.size()<2){ + return 0; + } + + auto first = _pkt_map.begin(); + auto last = _pkt_map.rbegin(); + double dur = DurationCountMicroseconds(last->first - first->first)/1000000.0; + double rate = _pkt_map.size()/dur; + return (uint32_t)rate; +} + +void EstimatedLinkCapacityContext::inputPacket(TimePoint ts) { + if(_pkt_map.size()>16){ + _pkt_map.erase(_pkt_map.begin()); + } + _pkt_map.emplace(ts,ts); +} +uint32_t EstimatedLinkCapacityContext::getEstimatedLinkCapacity() { + decltype(_pkt_map.begin()) next; + std::vector tmp; + + for(auto it = _pkt_map.begin();it != _pkt_map.end();++it){ + next = it; + ++next; + if(next != _pkt_map.end()){ + tmp.push_back(next->first -it->first); + }else{ + break; + } + } + std::sort(tmp.begin(),tmp.end()); + if(tmp.empty()){ + return 0; + } + + double dur =DurationCountMicroseconds(tmp[tmp.size()/2])/1e6; + + return (uint32_t)(1.0/dur); + +} + +void RecvRateContext::inputPacket(TimePoint ts, size_t size ) { + if (_pkt_map.size() > 100) { + _pkt_map.erase(_pkt_map.begin()); + } + _pkt_map.emplace(ts, size); +} +uint32_t RecvRateContext::getRecvRate() { + if(_pkt_map.size()<2){ + return 0; + } + + auto first = _pkt_map.begin(); + auto last = _pkt_map.rbegin(); + double dur = DurationCountMicroseconds(last->first - first->first)/1000000.0; + + size_t bytes = 0; + for(auto it : _pkt_map){ + bytes += it.second; + } + double rate = (double)bytes/dur; + return (uint32_t)rate; +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/Statistic.hpp b/srt/Statistic.hpp new file mode 100644 index 00000000..283d4546 --- /dev/null +++ b/srt/Statistic.hpp @@ -0,0 +1,43 @@ +#ifndef ZLMEDIAKIT_SRT_STATISTIC_H +#define ZLMEDIAKIT_SRT_STATISTIC_H +#include + +#include "Common.hpp" +#include "Packet.hpp" + +namespace SRT { +class PacketRecvRateContext { +public: + PacketRecvRateContext() = default; + ~PacketRecvRateContext() = default; + void inputPacket(TimePoint ts); + uint32_t getPacketRecvRate(); +private: + std::map _pkt_map; + +}; + +class EstimatedLinkCapacityContext { +public: + EstimatedLinkCapacityContext() = default; + ~EstimatedLinkCapacityContext() = default; + void inputPacket(TimePoint ts); + uint32_t getEstimatedLinkCapacity(); +private: + std::map _pkt_map; +}; + +class RecvRateContext { +public: + RecvRateContext() = default; + ~RecvRateContext() = default; + void inputPacket(TimePoint ts,size_t size); + uint32_t getRecvRate(); +private: + std::map _pkt_map; +}; + + + +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_STATISTIC_H \ No newline at end of file