diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ffaf764..dd155379 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ option(ENABLE_FFMPEG "Enable FFmpeg" true) option(ENABLE_MSVC_MT "Enable MSVC Mt/Mtd lib" true) option(ENABLE_API_STATIC_LIB "Enable mk_api static lib" false) option(USE_SOLUTION_FOLDERS "Enable solution dir supported" ON) +option(ENABLE_SRT "Enable SRT" true) # ---------------------------------------------------------------------------- # Solution folders: # ---------------------------------------------------------------------------- @@ -486,6 +487,15 @@ if (ENABLE_WEBRTC) endif () endif () +if (ENABLE_SRT) + add_definitions(-DENABLE_SRT) + include_directories(./srt) + file(GLOB SRC_SRT_LIST ./srt/*.cpp ./srt/*.h ./srt/*.hpp) + add_library(srt ${SRC_SRT_LIST}) + list(APPEND LINK_LIB_LIST srt) + message(STATUS "srt 功能已开启") +endif() + #添加c库 if (ENABLE_API) add_subdirectory(api) diff --git a/README.md b/README.md index da0cd2e4..f8ece25a 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ - 支持rtp扩展解析 - 支持GOP缓冲,webrtc播放秒开 - 支持datachannel - +- [SRT支持](./srt/srt.md) - 其他 - 支持丰富的restful api以及web hook事件 - 支持简单的telnet调试 diff --git a/README_en.md b/README_en.md index f0f09470..a1d958d1 100644 --- a/README_en.md +++ b/README_en.md @@ -61,7 +61,7 @@ - Support simulcast - Support rtx/nack - Support transport-cc rtcp/rtp ext - +- [SRT support](./srt/srt_en.md) - Others - Support stream proxy by ffmpeg. - RESTful http api and http hook event api. diff --git a/conf/config.ini b/conf/config.ini index 2c36c22b..5a468bd7 100644 --- a/conf/config.ini +++ b/conf/config.ini @@ -284,6 +284,17 @@ preferredCodecA=PCMU,PCMA,opus,mpeg4-generic #以下范例为所有支持的视频codec preferredCodecV=H264,H265,AV1X,VP9,VP8 +[srt] +#srt播放推流、播放超时时间,单位秒 +timeoutSec=5 +#srt udp服务器监听端口号,所有srt客户端将通过该端口传输srt数据, +#该端口是多线程的,同时支持客户端网络切换导致的连接迁移 +port=9000 + +#srt 协议中延迟缓存的估算参数,在握手阶段估算rtt ,然后latencyMul*rtt 为最大缓存时长,此参数越大,表示等待重传的时长就越大 +latencyMul=4 + + [rtsp] #rtsp专有鉴权方式是采用base64还是md5方式 authBasic=0 diff --git a/server/main.cpp b/server/main.cpp index f2dd1f61..f98e9640 100644 --- a/server/main.cpp +++ b/server/main.cpp @@ -32,6 +32,11 @@ #include "../webrtc/WebRtcSession.h" #endif +#if defined(ENABLE_SRT) +#include "../srt/SrtSession.hpp" +#include "../srt/SrtTransport.hpp" +#endif + #if defined(ENABLE_VERSION) #include "version.h" #endif @@ -284,6 +289,24 @@ int start_main(int argc,char *argv[]) { uint16_t rtcPort = mINI::Instance()[RTC::kPort]; #endif//defined(ENABLE_WEBRTC) + +#if defined(ENABLE_SRT) + auto srtSrv = std::make_shared(); + srtSrv->setOnCreateSocket([](const EventPoller::Ptr &poller, const Buffer::Ptr &buf, struct sockaddr *, int) { + if (!buf) { + return Socket::createSocket(poller, false); + } + auto new_poller = SRT::SrtSession::queryPoller(buf); + if (!new_poller) { + //握手第一阶段 + return Socket::createSocket(poller, false); + } + return Socket::createSocket(new_poller, false); + }); + + uint16_t srtPort = mINI::Instance()[SRT::kPort]; +#endif //defined(ENABLE_SRT) + try { //rtsp服务器,端口默认554 if (rtspPort) { rtspSrv->start(rtspPort); } @@ -313,6 +336,14 @@ int start_main(int argc,char *argv[]) { if (rtcPort) { rtcSrv->start(rtcPort); } #endif//defined(ENABLE_WEBRTC) + +#if defined(ENABLE_SRT) + // srt udp服务器 + if(srtPort){ + srtSrv->start(srtPort); + } +#endif//defined(ENABLE_SRT) + } catch (std::exception &ex) { WarnL << "端口占用或无权限:" << ex.what() << endl; ErrorL << "程序启动失败,请修改配置文件中端口号后重试!" << endl; diff --git a/src/Common/MediaSource.cpp b/src/Common/MediaSource.cpp index 47987424..4a3b4373 100644 --- a/src/Common/MediaSource.cpp +++ b/src/Common/MediaSource.cpp @@ -38,6 +38,7 @@ string getOriginTypeString(MediaOriginType type){ SWITCH_CASE(mp4_vod); SWITCH_CASE(device_chn); SWITCH_CASE(rtc_push); + SWITCH_CASE(srt_push); default : return "unknown"; } } diff --git a/src/Common/MediaSource.h b/src/Common/MediaSource.h index ea9da097..4ecbc22a 100644 --- a/src/Common/MediaSource.h +++ b/src/Common/MediaSource.h @@ -45,6 +45,7 @@ enum class MediaOriginType : uint8_t { mp4_vod, device_chn, rtc_push, + srt_push }; std::string getOriginTypeString(MediaOriginType type); diff --git a/srt/Ack.cpp b/srt/Ack.cpp new file mode 100644 index 00000000..f0afc017 --- /dev/null +++ b/srt/Ack.cpp @@ -0,0 +1,85 @@ +#include "Ack.hpp" +#include "Common.hpp" + +namespace SRT { + +bool ACKPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < ACK_CIF_SIZE + ControlPacket::HEADER_SIZE) { + return false; + } + + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + ControlPacket::loadHeader(); + ack_number = loadUint32(type_specific_info); + uint8_t *ptr = (uint8_t *)_data->data() + ControlPacket::HEADER_SIZE; + + last_ack_pkt_seq_number = loadUint32(ptr); + ptr += 4; + + rtt = loadUint32(ptr); + ptr += 4; + + rtt_variance = loadUint32(ptr); + ptr += 4; + + available_buf_size = loadUint32(ptr); + ptr += 4; + + pkt_recv_rate = loadUint32(ptr); + ptr += 4; + + estimated_link_capacity = loadUint32(ptr); + ptr += 4; + + recv_rate = loadUint32(ptr); + ptr += 4; + + return true; +} + +bool ACKPacket::storeToData() { + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE + ACK_CIF_SIZE); + _data->setSize(HEADER_SIZE + ACK_CIF_SIZE); + control_type = ControlPacket::ACK; + sub_type = 0; + + storeUint32(type_specific_info, ack_number); + storeToHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + ControlPacket::HEADER_SIZE; + + storeUint32(ptr, last_ack_pkt_seq_number); + ptr += 4; + + storeUint32(ptr, rtt); + ptr += 4; + + storeUint32(ptr, rtt_variance); + ptr += 4; + + storeUint32(ptr, pkt_recv_rate); + ptr += 4; + + storeUint32(ptr, available_buf_size); + ptr += 4; + + storeUint32(ptr, estimated_link_capacity); + ptr += 4; + + storeUint32(ptr, recv_rate); + ptr += 4; + + return true; +} + +std::string ACKPacket::dump() { + _StrPrinter printer; + printer << "last_ack_pkt_seq_number=" << last_ack_pkt_seq_number << " rtt=" << rtt + << " rtt_variance=" << rtt_variance << " pkt_recv_rate=" << pkt_recv_rate + << " available_buf_size=" << available_buf_size << " estimated_link_capacity=" << estimated_link_capacity + << " recv_rate=" << recv_rate; + return std::move(printer); +} +} // namespace SRT \ No newline at end of file diff --git a/srt/Ack.hpp b/srt/Ack.hpp new file mode 100644 index 00000000..50e21002 --- /dev/null +++ b/srt/Ack.hpp @@ -0,0 +1,90 @@ +#ifndef ZLMEDIAKIT_SRT_ACK_H +#define ZLMEDIAKIT_SRT_ACK_H +#include "Packet.hpp" + +namespace SRT { +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Acknowledgement Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+- CIF -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Last Acknowledged Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| RTT | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| RTT Variance | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Available Buffer Size | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Packets Receiving Rate | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Estimated Link Capacity | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Receiving Rate | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 13: ACK control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-ack-acknowledgment +*/ +class ACKPacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + ACKPacket() = default; + ~ACKPacket() = default; + + enum { ACK_CIF_SIZE = 7 * 4 }; + std::string dump(); + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + uint32_t ack_number; + + uint32_t last_ack_pkt_seq_number; + uint32_t rtt; + uint32_t rtt_variance; + uint32_t available_buf_size; + uint32_t pkt_recv_rate; + uint32_t estimated_link_capacity; + uint32_t recv_rate; +}; + +class ACKACKPacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + ACKACKPacket() = default; + ~ACKACKPacket() = default; + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override { + if (len < ControlPacket::HEADER_SIZE) { + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + ControlPacket::loadHeader(); + ack_number = loadUint32(type_specific_info); + return true; + } + bool storeToData() override { + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE); + _data->setSize(HEADER_SIZE); + control_type = ControlPacket::ACKACK; + sub_type = 0; + + storeUint32(type_specific_info, ack_number); + storeToHeader(); + return true; + } + + uint32_t ack_number; +}; + +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_ACK_H \ No newline at end of file diff --git a/srt/Common.hpp b/srt/Common.hpp new file mode 100644 index 00000000..de5d3aaa --- /dev/null +++ b/srt/Common.hpp @@ -0,0 +1,96 @@ +#ifndef ZLMEDIAKIT_SRT_COMMON_H +#define ZLMEDIAKIT_SRT_COMMON_H +#if defined(_WIN32) +#include +#include + +#include +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Iphlpapi.lib") +#else +#include +#include +#endif // defined(_WIN32) + +#include +#define MAX_SEQ 0x7fffffff +#define MAX_TS 0xffffffff + +namespace SRT { +using SteadyClock = std::chrono::steady_clock; +using TimePoint = std::chrono::time_point; + +using Microseconds = std::chrono::microseconds; +using Milliseconds = std::chrono::milliseconds; + +static inline int64_t DurationCountMicroseconds(SteadyClock::duration dur) { + return std::chrono::duration_cast(dur).count(); +} + +static inline uint32_t loadUint32(uint8_t *ptr) { + return ptr[0] << 24 | ptr[1] << 16 | ptr[2] << 8 | ptr[3]; +} + +static inline uint16_t loadUint16(uint8_t *ptr) { + return ptr[0] << 8 | ptr[1]; +} + +static inline void storeUint32(uint8_t *buf, uint32_t val) { + buf[0] = val >> 24; + buf[1] = (val >> 16) & 0xff; + buf[2] = (val >> 8) & 0xff; + buf[3] = val & 0xff; +} + +static inline void storeUint16(uint8_t *buf, uint16_t val) { + buf[0] = (val >> 8) & 0xff; + buf[1] = val & 0xff; +} + +static inline void storeUint32LE(uint8_t *buf, uint32_t val) { + buf[0] = val & 0xff; + buf[1] = (val >> 8) & 0xff; + buf[2] = (val >> 16) & 0xff; + buf[3] = (val >> 24) & 0xff; +} + +static inline void storeUint16LE(uint8_t *buf, uint16_t val) { + buf[0] = val & 0xff; + buf[1] = (val >> 8) & 0xff; +} + +static inline uint32_t srtVersion(int major, int minor, int patch) { + return patch + minor * 0x100 + major * 0x10000; +} +static inline uint32_t genExpectedSeq(uint32_t seq) { + return MAX_SEQ & seq; +} + +class UTicker { +public: + UTicker() { _created = _begin = SteadyClock::now(); } + ~UTicker() = default; + + /** + * 获取创建时间,单位微妙 + */ + int64_t elapsedTime(TimePoint now) const { return DurationCountMicroseconds(now - _begin); } + + /** + * 获取上次resetTime后至今的时间,单位毫秒 + */ + int64_t createdTime(TimePoint now) const { return DurationCountMicroseconds(now - _created); } + + /** + * 重置计时器 + */ + void resetTime(TimePoint now) { _begin = now; } + +private: + TimePoint _begin; + TimePoint _created; +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_COMMON_H \ No newline at end of file diff --git a/srt/HSExt.cpp b/srt/HSExt.cpp new file mode 100644 index 00000000..d12b2b3c --- /dev/null +++ b/srt/HSExt.cpp @@ -0,0 +1,134 @@ +#include "HSExt.hpp" + +namespace SRT { + +bool HSExtMessage::loadFromData(uint8_t *buf, size_t len) { + if (buf == NULL || len != HSEXT_MSG_SIZE) { + return false; + } + + _data = BufferRaw::create(); + _data->assign((char *)buf, len); + extension_length = 3; + HSExt::loadHeader(); + + assert(extension_type == SRT_CMD_HSREQ || extension_type == SRT_CMD_HSRSP); + + uint8_t *ptr = (uint8_t *)_data->data() + 4; + srt_version = loadUint32(ptr); + ptr += 4; + + srt_flag = loadUint32(ptr); + ptr += 4; + + recv_tsbpd_delay = loadUint16(ptr); + ptr += 2; + + send_tsbpd_delay = loadUint16(ptr); + ptr += 2; + + return true; +} + +std::string HSExtMessage::dump() { + _StrPrinter printer; + printer << "srt version : " << std::hex << srt_version << " srt flag : " << std::hex << srt_flag + << " recv_tsbpd_delay=" << recv_tsbpd_delay << " send_tsbpd_delay = " << send_tsbpd_delay; + return std::move(printer); +} + +bool HSExtMessage::storeToData() { + _data = BufferRaw::create(); + _data->setCapacity(HSEXT_MSG_SIZE); + _data->setSize(HSEXT_MSG_SIZE); + extension_length = 3; + HSExt::storeHeader(); + uint8_t *ptr = (uint8_t *)_data->data() + 4; + + storeUint32(ptr, srt_version); + ptr += 4; + + storeUint32(ptr, srt_flag); + ptr += 4; + + storeUint16(ptr, recv_tsbpd_delay); + ptr += 2; + + storeUint16(ptr, send_tsbpd_delay); + ptr += 2; + return true; +} + +bool HSExtStreamID::loadFromData(uint8_t *buf, size_t len) { + if (buf == NULL || len < 4) { + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)buf, len); + + HSExt::loadHeader(); + + size_t content_size = extension_length * 4; + if (len < content_size + 4) { + return false; + } + streamid.clear(); + char *ptr = _data->data() + 4; + + for (size_t i = 0; i < extension_length; ++i) { + streamid.push_back(*(ptr + 3)); + streamid.push_back(*(ptr + 2)); + streamid.push_back(*(ptr + 1)); + streamid.push_back(*(ptr)); + ptr += 4; + } + char zero = 0x00; + if (streamid.back() == zero) { + streamid.erase(streamid.find_first_of(zero), streamid.size()); + } + return true; +} + +bool HSExtStreamID::storeToData() { + size_t content_size = ((streamid.length() + 4) + 3) / 4 * 4; + + _data = BufferRaw::create(); + _data->setCapacity(content_size); + _data->setSize(content_size); + extension_length = (content_size - 4) / 4; + extension_type = SRT_CMD_SID; + HSExt::storeHeader(); + auto ptr = _data->data() + 4; + memset(ptr, 0, extension_length * 4); + const char *src = streamid.c_str(); + for (size_t i = 0; i < streamid.length() / 4; ++i) { + *ptr = *(src + 3 + i * 4); + ptr++; + + *ptr = *(src + 2 + i * 4); + ptr++; + + *ptr = *(src + 1 + i * 4); + ptr++; + + *ptr = *(src + 0 + i * 4); + ptr++; + } + + ptr += 3; + size_t offset = streamid.length() / 4 * 4; + for (size_t i = 0; i < streamid.length() % 4; ++i) { + *ptr = *(src + offset + i); + ptr -= 1; + } + + return true; +} + +std::string HSExtStreamID::dump() { + _StrPrinter printer; + printer << " streamid : " << streamid; + return std::move(printer); +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/HSExt.hpp b/srt/HSExt.hpp new file mode 100644 index 00000000..55eba100 --- /dev/null +++ b/srt/HSExt.hpp @@ -0,0 +1,129 @@ +#ifndef ZLMEDIAKIT_SRT_HS_EXT_H +#define ZLMEDIAKIT_SRT_HS_EXT_H + +#include "Network/Buffer.h" + +#include "Common.hpp" + +namespace SRT { +using namespace toolkit; +class HSExt : public Buffer { +public: + HSExt() = default; + virtual ~HSExt() = default; + + enum { + SRT_CMD_REJECT = 0, + SRT_CMD_HSREQ = 1, + SRT_CMD_HSRSP = 2, + SRT_CMD_KMREQ = 3, + SRT_CMD_KMRSP = 4, + SRT_CMD_SID = 5, + SRT_CMD_CONGESTION = 6, + SRT_CMD_FILTER = 7, + SRT_CMD_GROUP = 8, + SRT_CMD_NONE = -1 + }; + + using Ptr = std::shared_ptr; + uint16_t extension_type; + uint16_t extension_length; + virtual bool loadFromData(uint8_t *buf, size_t len) = 0; + virtual bool storeToData() = 0; + virtual std::string dump() = 0; + ///////Buffer override/////// + char *data() const override { + if (_data) { + return _data->data(); + } + return nullptr; + } + size_t size() const override { + if (_data) { + return _data->size(); + } + return 0; + } + +protected: + void loadHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + extension_type = loadUint16(ptr); + ptr += 2; + extension_length = loadUint16(ptr); + ptr += 2; + } + void storeHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + SRT::storeUint16(ptr, extension_type); + ptr += 2; + storeUint16(ptr, extension_length); + } + +protected: + BufferRaw::Ptr _data; +}; + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SRT Version | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SRT Flags | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Receiver TSBPD Delay | Sender TSBPD Delay | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 6: Handshake Extension Message structure + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake-extension-message + +*/ +class HSExtMessage : public HSExt { +public: + using Ptr = std::shared_ptr; + enum { + HS_EXT_MSG_TSBPDSND = 0x00000001, + HS_EXT_MSG_TSBPDRCV = 0x00000002, + HS_EXT_MSG_CRYPT = 0x00000004, + HS_EXT_MSG_TLPKTDROP = 0x00000008, + HS_EXT_MSG_PERIODICNAK = 0x00000010, + HS_EXT_MSG_REXMITFLG = 0x00000020, + HS_EXT_MSG_STREAM = 0x00000040, + HS_EXT_MSG_PACKET_FILTER = 0x00000080 + }; + enum { HSEXT_MSG_SIZE = 16 }; + HSExtMessage() = default; + ~HSExtMessage() = default; + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + std::string dump() override; + uint32_t srt_version; + uint32_t srt_flag; + uint16_t recv_tsbpd_delay; + uint16_t send_tsbpd_delay; +}; + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Stream ID | + ... +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 7: Stream ID Extension Message + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-stream-id-extension-message +*/ +class HSExtStreamID : public HSExt { +public: + using Ptr = std::shared_ptr; + HSExtStreamID() = default; + ~HSExtStreamID() = default; + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + std::string dump() override; + std::string streamid; +}; +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_HS_EXT_H \ No newline at end of file diff --git a/srt/NackContext.cpp b/srt/NackContext.cpp new file mode 100644 index 00000000..9711531c --- /dev/null +++ b/srt/NackContext.cpp @@ -0,0 +1,108 @@ +#include "NackContext.hpp" + +namespace SRT { +void NackContext::update(TimePoint now, std::list &lostlist) { + for (auto item : lostlist) { + mergeItem(now, item); + } +} +void NackContext::getLostList( + TimePoint now, uint32_t rtt, uint32_t rtt_variance, std::list &lostlist) { + lostlist.clear(); + std::list tmp_list; + + for (auto it = _nack_map.begin(); it != _nack_map.end(); ++it) { + if (!it->second._is_nack) { + tmp_list.push_back(it->first); + it->second._ts = now; + it->second._is_nack = true; + } else { + if (DurationCountMicroseconds(now - it->second._ts) > rtt) { + tmp_list.push_back(it->first); + it->second._ts = now; + } + } + } + tmp_list.sort(); + + if (tmp_list.empty()) { + return; + } + + uint32_t min = *tmp_list.begin(); + uint32_t max = *tmp_list.rbegin(); + + if ((max - min) >= (MAX_SEQ >> 1)) { + while ((max - tmp_list.front()) > (MAX_SEQ >> 1)) { + tmp_list.push_back(tmp_list.front()); + tmp_list.pop_front(); + } + } + + PacketQueue::LostPair lost; + bool finish = true; + for (auto cur = tmp_list.begin(); cur != tmp_list.end(); ++cur) { + if (finish) { + lost.first = *cur; + lost.second = genExpectedSeq(*cur + 1); + finish = false; + } else { + if (lost.second == *cur) { + lost.second = genExpectedSeq(*cur + 1); + } else { + finish = true; + lostlist.push_back(lost); + } + } + } +} +void NackContext::drop(uint32_t seq) { + if (_nack_map.empty()) + return; + uint32_t min = _nack_map.begin()->first; + uint32_t max = _nack_map.rbegin()->first; + bool is_cycle = false; + if ((max - min) >= (MAX_SEQ >> 1)) { + is_cycle = true; + } + + for (auto it = _nack_map.begin(); it != _nack_map.end();) { + if (!is_cycle) { + // 不回环 + if (it->first <= seq) { + it = _nack_map.erase(it); + } else { + it++; + } + } else { + if (it->first <= seq) { + if ((seq - it->first) >= (MAX_SEQ >> 1)) { + WarnL << "cycle seq " << seq << " " << it->first; + it++; + } else { + it = _nack_map.erase(it); + } + } else { + if ((it->first - seq) >= (MAX_SEQ >> 1)) { + it = _nack_map.erase(it); + WarnL << "cycle seq " << seq << " " << it->first; + } else { + it++; + } + } + } + } +} + +void NackContext::mergeItem(TimePoint now, PacketQueue::LostPair &item) { + for (uint32_t i = item.first; i < item.second; ++i) { + auto it = _nack_map.find(i); + if (it != _nack_map.end()) { + } else { + NackItem tmp; + tmp._is_nack = false; + _nack_map.emplace(i, tmp); + } + } +} +} // namespace SRT \ No newline at end of file diff --git a/srt/NackContext.hpp b/srt/NackContext.hpp new file mode 100644 index 00000000..fe6a0ba6 --- /dev/null +++ b/srt/NackContext.hpp @@ -0,0 +1,30 @@ +#ifndef ZLMEDIAKIT_SRT_NACK_CONTEXT_H +#define ZLMEDIAKIT_SRT_NACK_CONTEXT_H +#include "Common.hpp" +#include "PacketQueue.hpp" +#include + +namespace SRT { +class NackContext { +public: + NackContext() = default; + ~NackContext() = default; + void update(TimePoint now, std::list &lostlist); + void getLostList(TimePoint now, uint32_t rtt, uint32_t rtt_variance, std::list &lostlist); + void drop(uint32_t seq); + +private: + void mergeItem(TimePoint now, PacketQueue::LostPair &item); + +private: + class NackItem { + public: + bool _is_nack = false; + TimePoint _ts; // send nak time + }; + + std::map _nack_map; +}; + +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_NACK_CONTEXT_H \ No newline at end of file diff --git a/srt/Packet.cpp b/srt/Packet.cpp new file mode 100644 index 00000000..85de1cbf --- /dev/null +++ b/srt/Packet.cpp @@ -0,0 +1,603 @@ +#include "Util/MD5.h" +#include "Util/logger.h" +#include + +#include "Packet.hpp" + +namespace SRT { + +const size_t DataPacket::HEADER_SIZE; +const size_t ControlPacket::HEADER_SIZE; +const size_t HandshakePacket::HS_CONTENT_MIN_SIZE; + +bool DataPacket::isDataPacket(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + if (!(buf[0] & 0x80)) { + return true; + } + return false; +} + +uint32_t DataPacket::getSocketID(uint8_t *buf, size_t len) { + uint8_t *ptr = buf; + ptr += 12; + return loadUint32(ptr); +} + +bool DataPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + uint8_t *ptr = buf; + f = ptr[0] >> 7; + packet_seq_number = loadUint32(ptr) & 0x7fffffff; + ptr += 4; + + PP = ptr[0] >> 6; + O = (ptr[0] & 0x20) >> 5; + KK = (ptr[0] & 0x18) >> 3; + R = (ptr[0] & 0x04) >> 2; + msg_number = (ptr[0] & 0x03) << 24 | ptr[1] << 12 | ptr[2] << 8 | ptr[3]; + ptr += 4; + + timestamp = loadUint32(ptr); + ptr += 4; + + dst_socket_id = loadUint32(ptr); + ptr += 4; + + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + return true; +} + +bool DataPacket::storeToHeader() { + if (!_data || _data->size() < HEADER_SIZE) { + WarnL << "data size less " << HEADER_SIZE; + return false; + } + uint8_t *ptr = (uint8_t *)_data->data(); + + ptr[0] = packet_seq_number >> 24; + ptr[1] = (packet_seq_number >> 16) & 0xff; + ptr[2] = (packet_seq_number >> 8) & 0xff; + ptr[3] = packet_seq_number & 0xff; + ptr += 4; + + ptr[0] = PP << 6; + ptr[0] |= O << 5; + ptr[0] |= KK << 3; + ptr[0] |= R << 2; + ptr[0] |= (msg_number & 0xff000000) >> 24; + ptr[1] = (msg_number & 0xff0000) >> 16; + ptr[2] = (msg_number & 0xff00) >> 8; + ptr[3] = msg_number & 0xff; + ptr += 4; + + storeUint32(ptr, timestamp); + ptr += 4; + + storeUint32(ptr, dst_socket_id); + ptr += 4; + return true; +} + +bool DataPacket::storeToData(uint8_t *buf, size_t len) { + _data = BufferRaw::create(); + _data->setCapacity(len + HEADER_SIZE); + _data->setSize(len + HEADER_SIZE); + + uint8_t *ptr = (uint8_t *)_data->data(); + + ptr[0] = packet_seq_number >> 24; + ptr[1] = (packet_seq_number >> 16) & 0xff; + ptr[2] = (packet_seq_number >> 8) & 0xff; + ptr[3] = packet_seq_number & 0xff; + ptr += 4; + + ptr[0] = PP << 6; + ptr[0] |= O << 5; + ptr[0] |= KK << 3; + ptr[0] |= R << 2; + ptr[0] |= (msg_number & 0xff000000) >> 24; + ptr[1] = (msg_number & 0xff0000) >> 16; + ptr[2] = (msg_number & 0xff00) >> 8; + ptr[3] = msg_number & 0xff; + ptr += 4; + + storeUint32(ptr, timestamp); + ptr += 4; + + storeUint32(ptr, dst_socket_id); + ptr += 4; + + memcpy(ptr, buf, len); + return true; +} + +char *DataPacket::data() const { + if (!_data) + return nullptr; + return _data->data(); +} + +size_t DataPacket::size() const { + if (!_data) { + return 0; + } + return _data->size(); +} + +char *DataPacket::payloadData() { + if (!_data) + return nullptr; + return _data->data() + HEADER_SIZE; +} + +size_t DataPacket::payloadSize() { + if (!_data) { + return 0; + } + return _data->size() - HEADER_SIZE; +} + +bool ControlPacket::isControlPacket(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + if (buf[0] & 0x80) { + return true; + } + return false; +} + +uint16_t ControlPacket::getControlType(uint8_t *buf, size_t len) { + uint8_t *ptr = buf; + uint16_t control_type = (ptr[0] & 0x7f) << 8 | ptr[1]; + return control_type; +} + +bool ControlPacket::loadHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + f = ptr[0] >> 7; + control_type = (ptr[0] & 0x7f) << 8 | ptr[1]; + ptr += 2; + + sub_type = loadUint16(ptr); + ptr += 2; + + type_specific_info[0] = ptr[0]; + type_specific_info[1] = ptr[1]; + type_specific_info[2] = ptr[2]; + type_specific_info[3] = ptr[3]; + ptr += 4; + + timestamp = loadUint32(ptr); + ptr += 4; + + dst_socket_id = loadUint32(ptr); + ptr += 4; + return true; +} + +bool ControlPacket::storeToHeader() { + uint8_t *ptr = (uint8_t *)_data->data(); + ptr[0] = 0x80; + ptr[0] |= control_type >> 8; + ptr[1] = control_type & 0xff; + ptr += 2; + + storeUint16(ptr, sub_type); + ptr += 2; + + ptr[0] = type_specific_info[0]; + ptr[1] = type_specific_info[1]; + ptr[2] = type_specific_info[2]; + ptr[3] = type_specific_info[3]; + ptr += 4; + + storeUint32(ptr, timestamp); + ptr += 4; + + storeUint32(ptr, dst_socket_id); + ptr += 4; + return true; +} + +char *ControlPacket::data() const { + if (!_data) + return nullptr; + return _data->data(); +} + +size_t ControlPacket::size() const { + if (!_data) { + return 0; + } + return _data->size(); +} + +uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len) { + return loadUint32(buf + 12); +} + +bool HandshakePacket::loadFromData(uint8_t *buf, size_t len) { + if (HEADER_SIZE + HS_CONTENT_MIN_SIZE > len) { + ErrorL << "size too smalle " << encryption_field; + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)(buf), len); + ControlPacket::loadHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + // parse CIF + version = loadUint32(ptr); + ptr += 4; + + encryption_field = loadUint16(ptr); + ptr += 2; + + extension_field = loadUint16(ptr); + ptr += 2; + + initial_packet_sequence_number = loadUint32(ptr); + ptr += 4; + + mtu = loadUint32(ptr); + ptr += 4; + + max_flow_window_size = loadUint32(ptr); + ptr += 4; + + handshake_type = loadUint32(ptr); + ptr += 4; + + srt_socket_id = loadUint32(ptr); + ptr += 4; + + syn_cookie = loadUint32(ptr); + ptr += 4; + + memcpy(peer_ip_addr, ptr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0])); + ptr += sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]); + + if (encryption_field != NO_ENCRYPTION) { + ErrorL << "not support encryption " << encryption_field; + } + + if (extension_field == 0) { + return true; + } + + if (len == HEADER_SIZE + HS_CONTENT_MIN_SIZE) { + // ErrorL << "extension filed not exist " << extension_field; + return true; + } + + return loadExtMessage(ptr, len - HS_CONTENT_MIN_SIZE - HEADER_SIZE); +} + +bool HandshakePacket::loadExtMessage(uint8_t *buf, size_t len) { + uint8_t *ptr = buf; + ext_list.clear(); + uint16_t type; + uint16_t length; + HSExt::Ptr ext; + while (ptr < buf + len) { + type = loadUint16(ptr); + length = loadUint16(ptr + 2); + switch (type) { + case HSExt::SRT_CMD_HSREQ: + case HSExt::SRT_CMD_HSRSP: + ext = std::make_shared(); + break; + case HSExt::SRT_CMD_SID: + ext = std::make_shared(); + break; + default: + WarnL << "not support ext " << type; + break; + } + if (ext) { + if (ext->loadFromData(ptr, length * 4 + 4)) { + ext_list.push_back(std::move(ext)); + } else { + WarnL << "parse HS EXT failed type=" << type << " len=" << length; + } + ext = nullptr; + } + + ptr += length * 4 + 4; + } + return true; +} + +bool HandshakePacket::storeExtMessage() { + uint8_t *buf = (uint8_t *)_data->data() + HEADER_SIZE + 48; + size_t len = _data->size() - HEADER_SIZE - 48; + for (auto ex : ext_list) { + memcpy(buf, ex->data(), ex->size()); + buf += ex->size(); + } + return true; +} + +size_t HandshakePacket::getExtSize() { + size_t size = 0; + for (auto it : ext_list) { + size += it->size(); + } + return size; +} +bool HandshakePacket::storeToData() { + _data = BufferRaw::create(); + for (auto ex : ext_list) { + ex->storeToData(); + } + auto ext_size = getExtSize(); + _data->setCapacity(HEADER_SIZE + 48 + ext_size); + _data->setSize(HEADER_SIZE + 48 + ext_size); + + control_type = ControlPacket::HANDSHAKE; + sub_type = 0; + + ControlPacket::storeToHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + + storeUint32(ptr, version); + ptr += 4; + + storeUint16(ptr, encryption_field); + ptr += 2; + + storeUint16(ptr, extension_field); + ptr += 2; + + storeUint32(ptr, initial_packet_sequence_number); + ptr += 4; + + storeUint32(ptr, mtu); + ptr += 4; + + storeUint32(ptr, max_flow_window_size); + ptr += 4; + + storeUint32(ptr, handshake_type); + ptr += 4; + + storeUint32(ptr, srt_socket_id); + ptr += 4; + + storeUint32(ptr, syn_cookie); + ptr += 4; + + memcpy(ptr, peer_ip_addr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0])); + ptr += sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]); + + if (encryption_field != NO_ENCRYPTION) { + ErrorL << "not support encryption " << encryption_field; + } + + assert(encryption_field == NO_ENCRYPTION); + + return storeExtMessage(); +} + +bool HandshakePacket::isHandshakePacket(uint8_t *buf, size_t len) { + if (!ControlPacket::isControlPacket(buf, len)) { + return false; + } + if (len < HEADER_SIZE + 48) { + return false; + } + return ControlPacket::getControlType(buf, len) == HANDSHAKE; +} + +uint32_t HandshakePacket::getHandshakeType(uint8_t *buf, size_t len) { + uint8_t *ptr = buf + HEADER_SIZE + 5 * 4; + return loadUint32(ptr); +} + +uint32_t HandshakePacket::getSynCookie(uint8_t *buf, size_t len) { + uint8_t *ptr = buf + HEADER_SIZE + 7 * 4; + return loadUint32(ptr); +} + +void HandshakePacket::assignPeerIP(struct sockaddr_storage *addr) { + memset(peer_ip_addr, 0, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0])); + if (addr->ss_family == AF_INET) { + struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr; + // 抓包 奇怪好像是小头端??? + storeUint32LE(peer_ip_addr, ipv4->sin_addr.s_addr); + } else if (addr->ss_family == AF_INET6) { + if (IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)addr)->sin6_addr)) { + struct in_addr addr4; + memcpy(&addr4, 12 + (char *)&(((struct sockaddr_in6 *)addr)->sin6_addr), 4); + storeUint32LE(peer_ip_addr, addr4.s_addr); + } else { + const sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr; + memcpy(peer_ip_addr, ipv6->sin6_addr.s6_addr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0])); + } + } +} + +uint32_t HandshakePacket::generateSynCookie( + struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie, int correction) { + static std::atomic distractor { 0 }; + uint32_t rollover = distractor.load() + 10; + + while (true) { + // SYN cookie + char clienthost[NI_MAXHOST]; + char clientport[NI_MAXSERV]; + getnameinfo( + (struct sockaddr *)addr, sizeof(struct sockaddr_storage), clienthost, sizeof(clienthost), clientport, + sizeof(clientport), NI_NUMERICHOST | NI_NUMERICSERV); + int64_t timestamp = (DurationCountMicroseconds(SteadyClock::now() - ts) / 60000000) + distractor.load() + + correction; // secret changes every one minute + std::stringstream cookiestr; + cookiestr << clienthost << ":" << clientport << ":" << timestamp; + union { + unsigned char cookie[16]; + uint32_t cookie_val; + }; + MD5 md5(cookiestr.str()); + memcpy(cookie, md5.rawdigest().c_str(), 16); + + if (cookie_val != current_cookie) { + return cookie_val; + } + + ++distractor; + + // This is just to make the loop formally breakable, + // but this is virtually impossible to happen. + if (distractor == rollover) { + return cookie_val; + } + } +} + +bool KeepLivePacket::loadFromData(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)buf, len); + + return loadHeader(); +} +bool KeepLivePacket::storeToData() { + control_type = ControlPacket::KEEPALIVE; + sub_type = 0; + + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE); + _data->setSize(HEADER_SIZE); + return storeToHeader(); +} + +bool NAKPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)buf, len); + loadHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + uint8_t *end = (uint8_t *)_data->data() + _data->size(); + LostPair lost; + while (ptr < end) { + if ((*ptr) & 0x80) { + lost.first = loadUint32(ptr) & 0x7fffffff; + lost.second = loadUint32(ptr + 4) & 0x7fffffff; + lost.second += 1; + ptr += 8; + } else { + lost.first = loadUint32(ptr); + lost.second = lost.first + 1; + ptr += 4; + } + lost_list.push_back(lost); + } + return true; +} +bool NAKPacket::storeToData() { + control_type = NAK; + sub_type = 0; + size_t cif_size = getCIFSize(); + + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE + cif_size); + _data->setSize(HEADER_SIZE + cif_size); + + storeToHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + + for (auto it : lost_list) { + if (it.first + 1 == it.second) { + storeUint32(ptr, it.first); + ptr[0] = ptr[0] & 0x7f; + ptr += 4; + } else { + storeUint32(ptr, it.first); + ptr[0] |= 0x80; + + storeUint32(ptr + 4, it.second - 1); + // ptr[4] = ptr[4]&0x7f; + + ptr += 8; + } + } + + return true; +} + +size_t NAKPacket::getCIFSize() { + size_t size = 0; + for (auto it : lost_list) { + if (it.first + 1 == it.second) { + size += 4; + } else { + size += 8; + } + } + return size; +} + +std::string NAKPacket::dump() { + _StrPrinter printer; + for (auto it : lost_list) { + printer << "[ " << it.first << " , " << it.second - 1 << " ]"; + } + return std::move(printer); +} + +bool MsgDropReqPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < HEADER_SIZE + 8) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)buf, len); + loadHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + + first_pkt_seq_num = loadUint32(ptr); + ptr += 4; + + last_pkt_seq_num = loadUint32(ptr); + ptr += 4; + return true; +} +bool MsgDropReqPacket::storeToData() { + control_type = DROPREQ; + sub_type = 0; + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE + 8); + _data->setSize(HEADER_SIZE + 8); + + storeToHeader(); + + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + + storeUint32(ptr, first_pkt_seq_num); + ptr += 4; + + storeUint32(ptr, last_pkt_seq_num); + ptr += 4; + return true; +} +} // namespace SRT \ No newline at end of file diff --git a/srt/Packet.hpp b/srt/Packet.hpp new file mode 100644 index 00000000..8710e940 --- /dev/null +++ b/srt/Packet.hpp @@ -0,0 +1,360 @@ +#ifndef ZLMEDIAKIT_SRT_PACKET_H +#define ZLMEDIAKIT_SRT_PACKET_H + +#include +#include + +#include "Network/Buffer.h" +#include "Network/sockutil.h" +#include "Util/logger.h" + +#include "Common.hpp" +#include "HSExt.hpp" + +namespace SRT { + +using namespace toolkit; +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0| Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|P P|O|K K|R| Message Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ Data + +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 3: Data packet structure + reference https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-packet-structure +*/ +class DataPacket : public Buffer { +public: + using Ptr = std::shared_ptr; + DataPacket() = default; + ~DataPacket() = default; + + static const size_t HEADER_SIZE = 16; + static bool isDataPacket(uint8_t *buf, size_t len); + static uint32_t getSocketID(uint8_t *buf, size_t len); + bool loadFromData(uint8_t *buf, size_t len); + bool storeToData(uint8_t *buf, size_t len); + bool storeToHeader(); + + ///////Buffer override/////// + char *data() const override; + size_t size() const override; + + char *payloadData(); + size_t payloadSize(); + + uint8_t f; + uint32_t packet_seq_number; + uint8_t PP; + uint8_t O; + uint8_t KK; + uint8_t R; + uint32_t msg_number; + uint32_t timestamp; + uint32_t dst_socket_id; + +private: + BufferRaw::Ptr _data; +}; +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Subtype | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+- CIF -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ Control Information Field + +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 4: Control packet structure + reference https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-control-packets +*/ +class ControlPacket : public Buffer { +public: + using Ptr = std::shared_ptr; + static const size_t HEADER_SIZE = 16; + static bool isControlPacket(uint8_t *buf, size_t len); + static uint16_t getControlType(uint8_t *buf, size_t len); + static uint32_t getSocketID(uint8_t *buf, size_t len); + + ControlPacket() = default; + virtual ~ControlPacket() = default; + virtual bool loadFromData(uint8_t *buf, size_t len) = 0; + virtual bool storeToData() = 0; + + bool loadHeader(); + bool storeToHeader(); + + ///////Buffer override/////// + char *data() const override; + size_t size() const override; + + enum { + HANDSHAKE = 0x0000, + KEEPALIVE = 0x0001, + ACK = 0x0002, + NAK = 0x0003, + CONGESTIONWARNING = 0x0004, + SHUTDOWN = 0x0005, + ACKACK = 0x0006, + DROPREQ = 0x0007, + PEERERROR = 0x0008, + USERDEFINEDTYPE = 0x7FFF + }; + + uint32_t sub_type : 16; + uint32_t control_type : 15; + uint32_t f : 1; + uint8_t type_specific_info[4]; + uint32_t timestamp; + uint32_t dst_socket_id; + +protected: + BufferRaw::Ptr _data; +}; + +/** + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Version | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Encryption Field | Extension Field | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Initial Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Maximum Transmission Unit Size | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Maximum Flow Window Size | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Handshake Type | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SRT Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| SYN Cookie | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ + +| | ++ Peer IP Address + +| | ++ + +| | ++=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ +| Extension Type | Extension Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ Extension Contents + +| | ++=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + Figure 5: Handshake packet structure + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake + */ +class HandshakePacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + enum { NO_ENCRYPTION = 0, AES_128 = 1, AES_196 = 2, AES_256 = 3 }; + static const size_t HS_CONTENT_MIN_SIZE = 48; + enum { + HS_TYPE_DONE = 0xFFFFFFFD, + HS_TYPE_AGREEMENT = 0xFFFFFFFE, + HS_TYPE_CONCLUSION = 0xFFFFFFFF, + HS_TYPE_WAVEHAND = 0x00000000, + HS_TYPE_INDUCTION = 0x00000001 + }; + + enum { HS_EXT_FILED_HSREQ = 0x00000001, HS_EXT_FILED_KMREQ = 0x00000002, HS_EXT_FILED_CONFIG = 0x00000004 }; + + HandshakePacket() = default; + ~HandshakePacket() = default; + + static bool isHandshakePacket(uint8_t *buf, size_t len); + static uint32_t getHandshakeType(uint8_t *buf, size_t len); + static uint32_t getSynCookie(uint8_t *buf, size_t len); + static uint32_t + generateSynCookie(struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie = 0, int correction = 0); + + void assignPeerIP(struct sockaddr_storage *addr); + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + uint32_t version; + uint16_t encryption_field; + uint16_t extension_field; + uint32_t initial_packet_sequence_number; + uint32_t mtu; + uint32_t max_flow_window_size; + uint32_t handshake_type; + uint32_t srt_socket_id; + uint32_t syn_cookie; + uint8_t peer_ip_addr[16]; + + std::vector ext_list; + +private: + bool loadExtMessage(uint8_t *buf, size_t len); + bool storeExtMessage(); + size_t getExtSize(); +}; +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 12: Keep-Alive control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-keep-alive +*/ +class KeepLivePacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + KeepLivePacket() = default; + ~KeepLivePacket() = default; + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; +}; + +/* +An SRT NAK packet is formatted as follows: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+- CIF (Loss List) -+-+-+-+-+-+-+-+-+-+-+-+ +|0| Lost packet sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Range of lost packets from sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0| Up to sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0| Lost packet sequence number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 14: NAK control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-nak-control-packet +*/ +class NAKPacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + using LostPair = std::pair; + NAKPacket() = default; + ~NAKPacket() = default; + std::string dump(); + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + std::list lost_list; + +private: + size_t getCIFSize(); +}; + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type = 7 | Reserved = 0 | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| First Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Last Packet Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 18: Drop Request control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-message-drop-request +*/ +class MsgDropReqPacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + MsgDropReqPacket() = default; + ~MsgDropReqPacket() = default; + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override; + bool storeToData() override; + + uint32_t first_pkt_seq_num; + uint32_t last_pkt_seq_num; +}; + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+ +|1| Control Type | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type-specific Information | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Timestamp | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Destination Socket ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 16: Shutdown control packet + https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-shutdown + +*/ +class ShutDownPacket : public ControlPacket { +public: + using Ptr = std::shared_ptr; + ShutDownPacket() = default; + ~ShutDownPacket() = default; + + ///////ControlPacket override/////// + bool loadFromData(uint8_t *buf, size_t len) override { + if (len < HEADER_SIZE) { + WarnL << "data size" << len << " less " << HEADER_SIZE; + return false; + } + _data = BufferRaw::create(); + _data->assign((char *)buf, len); + + return loadHeader(); + } + bool storeToData() override { + control_type = ControlPacket::SHUTDOWN; + sub_type = 0; + _data = BufferRaw::create(); + _data->setCapacity(HEADER_SIZE); + _data->setSize(HEADER_SIZE); + return storeToHeader(); + } +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_PACKET_H \ No newline at end of file diff --git a/srt/PacketQueue.cpp b/srt/PacketQueue.cpp new file mode 100644 index 00000000..cb43bd13 --- /dev/null +++ b/srt/PacketQueue.cpp @@ -0,0 +1,223 @@ +#include "PacketQueue.hpp" + +namespace SRT { + +static inline bool isSeqEdge(uint32_t seq, uint32_t cap) { + if (seq > (MAX_SEQ - cap)) { + return true; + } + return false; +} + +static inline bool isTSCycle(uint32_t first, uint32_t second) { + uint32_t diff; + if (first > second) { + diff = first - second; + } else { + diff = second - first; + } + + if (diff > (MAX_TS >> 1)) { + return true; + } else { + return false; + } +} + +PacketQueue::PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t latency) + : _pkt_cap(max_size) + , _pkt_latency(latency) + , _pkt_expected_seq(init_seq) {} + +void PacketQueue::tryInsertPkt(DataPacket::Ptr pkt) { + if (_pkt_expected_seq <= pkt->packet_seq_number) { + auto diff = pkt->packet_seq_number - _pkt_expected_seq; + if (diff >= (MAX_SEQ >> 1)) { + TraceL << "drop packet too later for cycle " + << "expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; + return; + } else { + _pkt_map.emplace(pkt->packet_seq_number, pkt); + } + } else { + auto diff = _pkt_expected_seq - pkt->packet_seq_number; + if (diff >= (MAX_SEQ >> 1)) { + _pkt_map.emplace(pkt->packet_seq_number, pkt); + TraceL << " cycle packet " + << "expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; + } else { + // TraceL << "drop packet too later "<< "expected seq=" << _pkt_expected_seq << " pkt seq=" << + // pkt->packet_seq_number; + } + } +} + +bool PacketQueue::inputPacket(DataPacket::Ptr pkt, std::list &out) { + tryInsertPkt(pkt); + auto it = _pkt_map.find(_pkt_expected_seq); + while (it != _pkt_map.end()) { + out.push_back(it->second); + _pkt_map.erase(it); + _pkt_expected_seq = genExpectedSeq(_pkt_expected_seq + 1); + it = _pkt_map.find(_pkt_expected_seq); + } + + while (_pkt_map.size() > _pkt_cap) { + // 防止回环 + it = _pkt_map.find(_pkt_expected_seq); + if (it != _pkt_map.end()) { + out.push_back(it->second); + _pkt_map.erase(it); + } + _pkt_expected_seq = genExpectedSeq(_pkt_expected_seq + 1); + } + + while (timeLatency() > _pkt_latency) { + it = _pkt_map.find(_pkt_expected_seq); + if (it != _pkt_map.end()) { + out.push_back(it->second); + _pkt_map.erase(it); + } + _pkt_expected_seq = genExpectedSeq(_pkt_expected_seq + 1); + } + + return true; +} + +bool PacketQueue::drop(uint32_t first, uint32_t last, std::list &out) { + uint32_t end = genExpectedSeq(last + 1); + decltype(_pkt_map.end()) it; + for (uint32_t i = _pkt_expected_seq; i < end;) { + it = _pkt_map.find(i); + if (it != _pkt_map.end()) { + out.push_back(it->second); + _pkt_map.erase(it); + } + i = genExpectedSeq(i + 1); + } + _pkt_expected_seq = end; + return true; +} + +uint32_t PacketQueue::timeLatency() { + if (_pkt_map.empty()) { + return 0; + } + + auto first = _pkt_map.begin()->second->timestamp; + auto last = _pkt_map.rbegin()->second->timestamp; + uint32_t dur; + if (last > first) { + dur = last - first; + } else { + dur = first - last; + } + + if (dur > 0x80000000) { + dur = MAX_TS - dur; + WarnL << "cycle dur " << dur; + } + + return dur; +} + +std::list PacketQueue::getLostSeq() { + std::list re; + if (_pkt_map.empty()) { + return re; + } + + if (getExpectedSize() == getSize()) { + return re; + } + + uint32_t end = 0; + uint32_t first, last; + + first = _pkt_map.begin()->second->packet_seq_number; + last = _pkt_map.rbegin()->second->packet_seq_number; + if ((last - first) > (MAX_SEQ >> 1)) { + TraceL << " cycle seq first " << first << " last " << last << " size " << _pkt_map.size(); + end = first; + } else { + end = last; + } + PacketQueue::LostPair lost; + lost.first = 0; + lost.second = 0; + + uint32_t i = _pkt_expected_seq; + bool finish = true; + for (i = _pkt_expected_seq; i <= end;) { + if (_pkt_map.find(i) == _pkt_map.end()) { + if (finish) { + finish = false; + lost.first = i; + lost.second = genExpectedSeq(i + 1); + } else { + lost.second = genExpectedSeq(i + 1); + } + } else { + if (!finish) { + finish = true; + re.push_back(lost); + } + } + i = genExpectedSeq(i + 1); + } + + return re; +} + +size_t PacketQueue::getSize() { + return _pkt_map.size(); +} + +size_t PacketQueue::getExpectedSize() { + if (_pkt_map.empty()) { + return 0; + } + + uint32_t max = _pkt_map.rbegin()->first; + uint32_t min = _pkt_map.begin()->first; + if ((max - min) >= (MAX_SEQ >> 1)) { + TraceL << "cycle " + << "expected seq " << _pkt_expected_seq << " min " << min << " max " << max << " size " + << _pkt_map.size(); + return MAX_SEQ - _pkt_expected_seq + min + 1; + } else { + return max - _pkt_expected_seq + 1; + } +} + +size_t PacketQueue::getAvailableBufferSize() { + auto size = getExpectedSize(); + if (_pkt_cap > size) { + return _pkt_cap - size; + } + + if (_pkt_cap > _pkt_map.size()) { + return _pkt_cap - _pkt_map.size(); + } + WarnL << " cap " << _pkt_cap << " expected size " << size << " map size " << _pkt_map.size(); + return _pkt_cap; +} + +uint32_t PacketQueue::getExpectedSeq() { + return _pkt_expected_seq; +} + +std::string PacketQueue::dump() { + _StrPrinter printer; + if (_pkt_map.empty()) { + printer << " expected seq :" << _pkt_expected_seq; + } else { + printer << " expected seq :" << _pkt_expected_seq << " size:" << _pkt_map.size() + << " first:" << _pkt_map.begin()->second->packet_seq_number; + printer << " last:" << _pkt_map.rbegin()->second->packet_seq_number; + printer << " latency:" << timeLatency() / 1e3; + } + return std::move(printer); +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/PacketQueue.hpp b/srt/PacketQueue.hpp new file mode 100644 index 00000000..58d5f18a --- /dev/null +++ b/srt/PacketQueue.hpp @@ -0,0 +1,46 @@ +#ifndef ZLMEDIAKIT_SRT_PACKET_QUEUE_H +#define ZLMEDIAKIT_SRT_PACKET_QUEUE_H +#include "Packet.hpp" +#include +#include +#include +#include +#include +#include + +namespace SRT { + +// for recv +class PacketQueue { +public: + using Ptr = std::shared_ptr; + using LostPair = std::pair; + + PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t latency); + ~PacketQueue() = default; + bool inputPacket(DataPacket::Ptr pkt, std::list &out); + + uint32_t timeLatency(); + std::list getLostSeq(); + + size_t getSize(); + size_t getExpectedSize(); + size_t getAvailableBufferSize(); + uint32_t getExpectedSeq(); + + std::string dump(); + bool drop(uint32_t first, uint32_t last, std::list &out); + +private: + void tryInsertPkt(DataPacket::Ptr pkt); + +private: + uint32_t _pkt_cap; + uint32_t _pkt_latency; + uint32_t _pkt_expected_seq = 0; + std::map _pkt_map; +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_PACKET_QUEUE_H \ No newline at end of file diff --git a/srt/PacketSendQueue.cpp b/srt/PacketSendQueue.cpp new file mode 100644 index 00000000..92730a4a --- /dev/null +++ b/srt/PacketSendQueue.cpp @@ -0,0 +1,79 @@ +#include "PacketSendQueue.hpp" + +namespace SRT { + +PacketSendQueue::PacketSendQueue(uint32_t max_size, uint32_t latency) + : _pkt_cap(max_size) + , _pkt_latency(latency) {} + +bool PacketSendQueue::drop(uint32_t num) { + decltype(_pkt_cache.begin()) it; + for (it = _pkt_cache.begin(); it != _pkt_cache.end(); ++it) { + if ((*it)->packet_seq_number == num) { + break; + } + } + if (it != _pkt_cache.end()) { + _pkt_cache.erase(_pkt_cache.begin(), it); + } + return true; +} + +bool PacketSendQueue::inputPacket(DataPacket::Ptr pkt) { + _pkt_cache.push_back(pkt); + while (_pkt_cache.size() > _pkt_cap) { + _pkt_cache.pop_front(); + } + while (timeLatency() > _pkt_latency) { + _pkt_cache.pop_front(); + } + return true; +} + +std::list PacketSendQueue::findPacketBySeq(uint32_t start, uint32_t end) { + std::list re; + decltype(_pkt_cache.begin()) it; + for (it = _pkt_cache.begin(); it != _pkt_cache.end(); ++it) { + if ((*it)->packet_seq_number == start) { + break; + } + } + + if (start == end) { + if (it != _pkt_cache.end()) { + re.push_back(*it); + } + return re; + } + + for (; it != _pkt_cache.end(); ++it) { + re.push_back(*it); + if ((*it)->packet_seq_number == end) { + break; + } + } + return re; +} + +uint32_t PacketSendQueue::timeLatency() { + if (_pkt_cache.empty()) { + return 0; + } + auto first = _pkt_cache.front()->timestamp; + auto last = _pkt_cache.back()->timestamp; + uint32_t dur; + + if (last > first) { + dur = last - first; + } else { + dur = first - last; + } + if (dur > ((uint32_t)0x01 << 31)) { + TraceL << "cycle timeLatency " << dur; + dur = 0xffffffff - dur; + } + + return dur; +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/PacketSendQueue.hpp b/srt/PacketSendQueue.hpp new file mode 100644 index 00000000..be91c663 --- /dev/null +++ b/srt/PacketSendQueue.hpp @@ -0,0 +1,36 @@ +#ifndef ZLMEDIAKIT_SRT_PACKET_SEND_QUEUE_H +#define ZLMEDIAKIT_SRT_PACKET_SEND_QUEUE_H + +#include "Packet.hpp" +#include +#include +#include +#include +#include +#include + +namespace SRT { + +class PacketSendQueue { +public: + using Ptr = std::shared_ptr; + using LostPair = std::pair; + + PacketSendQueue(uint32_t max_size, uint32_t latency); + ~PacketSendQueue() = default; + + bool drop(uint32_t num); + bool inputPacket(DataPacket::Ptr pkt); + std::list findPacketBySeq(uint32_t start, uint32_t end); + +private: + uint32_t timeLatency(); +private: + uint32_t _pkt_cap; + uint32_t _pkt_latency; + std::list _pkt_cache; +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_PACKET_SEND_QUEUE_H \ No newline at end of file diff --git a/srt/SrtSession.cpp b/srt/SrtSession.cpp new file mode 100644 index 00000000..d62cb7cd --- /dev/null +++ b/srt/SrtSession.cpp @@ -0,0 +1,146 @@ +#include "SrtSession.hpp" +#include "Packet.hpp" +#include "SrtTransportImp.hpp" + +#include "Common/config.h" + +namespace SRT { +using namespace mediakit; + +SrtSession::SrtSession(const Socket::Ptr &sock) + : UdpSession(sock) { + socklen_t addr_len = sizeof(_peer_addr); + memset(&_peer_addr, 0, addr_len); + // TraceL<<"before addr len "<rawFD(), (struct sockaddr *)&_peer_addr, &addr_len); + // TraceL<<"after addr len "<data(); + size_t size = buffer->size(); + + if (DataPacket::isDataPacket(data, size)) { + uint32_t socket_id = DataPacket::getSocketID(data, size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + return trans ? trans->getPoller() : nullptr; + } + + if (HandshakePacket::isHandshakePacket(data, size)) { + auto type = HandshakePacket::getHandshakeType(data, size); + if (type == HandshakePacket::HS_TYPE_INDUCTION) { + // 握手第一阶段 + return nullptr; + } else if (type == HandshakePacket::HS_TYPE_CONCLUSION) { + // 握手第二阶段 + uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); + auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); + return trans ? trans->getPoller() : nullptr; + } else { + WarnL << " not reach there"; + } + } else { + uint32_t socket_id = ControlPacket::getSocketID(data, size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + return trans ? trans->getPoller() : nullptr; + } + return nullptr; +} + +void SrtSession::attachServer(const toolkit::Server &server) { + SockUtil::setRecvBuf(getSock()->rawFD(), 1024 * 1024); +} + +void SrtSession::onRecv(const Buffer::Ptr &buffer) { + uint8_t *data = (uint8_t *)buffer->data(); + size_t size = buffer->size(); + + if (_find_transport) { + //只允许寻找一次transport + _find_transport = false; + + if (DataPacket::isDataPacket(data, size)) { + uint32_t socket_id = DataPacket::getSocketID(data, size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + if (trans) { + _transport = std::move(trans); + } else { + WarnL << " data packet not find transport "; + } + } + + if (HandshakePacket::isHandshakePacket(data, size)) { + auto type = HandshakePacket::getHandshakeType(data, size); + if (type == HandshakePacket::HS_TYPE_INDUCTION) { + // 握手第一阶段 + _transport = std::make_shared(getPoller()); + + } else if (type == HandshakePacket::HS_TYPE_CONCLUSION) { + // 握手第二阶段 + uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); + auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); + if (trans) { + _transport = std::move(trans); + } else { + WarnL << " hanshake packet not find transport "; + } + } else { + WarnL << " not reach there"; + } + } else { + uint32_t socket_id = ControlPacket::getSocketID(data, size); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); + if (trans) { + _transport = std::move(trans); + } else { + WarnL << " not find transport"; + } + } + + if (_transport) { + _transport->setSession(shared_from_this()); + } + InfoP(this); + } + _ticker.resetTime(); + + if (_transport) { + _transport->inputSockData(data, size, &_peer_addr); + } else { + // WarnL<< "ingore data"; + } +} + +void SrtSession::onError(const SockException &err) { + // udp链接超时,但是srt链接不一定超时,因为可能存在udp链接迁移的情况 + //在udp链接迁移时,新的SrtSession对象将接管SrtSession对象的生命周期 + //本SrtSession对象将在超时后自动销毁 + WarnP(this) << err.what(); + + if (!_transport) { + return; + } + + // 防止互相引用导致不释放 + auto transport = std::move(_transport); + getPoller()->async( + [transport, err] { + //延时减引用,防止使用transport对象时,销毁对象 + transport->onShutdown(err); + }, + false); +} + +void SrtSession::onManager() { + GET_CONFIG(float, timeoutSec, kTimeOutSec); + if (_ticker.elapsedTime() > timeoutSec * 1000) { + shutdown(SockException(Err_timeout, "srt connection timeout")); + return; + } +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/SrtSession.hpp b/srt/SrtSession.hpp new file mode 100644 index 00000000..342a4a91 --- /dev/null +++ b/srt/SrtSession.hpp @@ -0,0 +1,30 @@ +#ifndef ZLMEDIAKIT_SRT_SESSION_H +#define ZLMEDIAKIT_SRT_SESSION_H + +#include "Network/Session.h" +#include "SrtTransport.hpp" + +namespace SRT { + +using namespace toolkit; + +class SrtSession : public UdpSession { +public: + SrtSession(const Socket::Ptr &sock); + ~SrtSession() override; + + void onRecv(const Buffer::Ptr &) override; + void onError(const SockException &err) override; + void onManager() override; + void attachServer(const toolkit::Server &server) override; + static EventPoller::Ptr queryPoller(const Buffer::Ptr &buffer); + +private: + bool _find_transport = true; + Ticker _ticker; + struct sockaddr_storage _peer_addr; + SrtTransport::Ptr _transport; +}; + +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_SESSION_H \ No newline at end of file diff --git a/srt/SrtTransport.cpp b/srt/SrtTransport.cpp new file mode 100644 index 00000000..4714bc64 --- /dev/null +++ b/srt/SrtTransport.cpp @@ -0,0 +1,673 @@ +#include "Util/onceToken.h" +#include + +#include "Ack.hpp" +#include "Packet.hpp" +#include "SrtTransport.hpp" + +namespace SRT { +#define SRT_FIELD "srt." +// srt 超时时间 +const std::string kTimeOutSec = SRT_FIELD "timeoutSec"; +// srt 单端口udp服务器 +const std::string kPort = SRT_FIELD "port"; +const std::string kLatencyMul = SRT_FIELD "latencyMul"; + +static std::atomic s_srt_socket_id_generate { 125 }; +//////////// SrtTransport ////////////////////////// +SrtTransport::SrtTransport(const EventPoller::Ptr &poller) + : _poller(poller) { + _start_timestamp = SteadyClock::now(); + _socket_id = s_srt_socket_id_generate.fetch_add(1); + _pkt_recv_rate_context = std::make_shared(_start_timestamp); + _recv_rate_context = std::make_shared(_start_timestamp); + _estimated_link_capacity_context = std::make_shared(_start_timestamp); +} + +SrtTransport::~SrtTransport() { + TraceL << " "; +} + +const EventPoller::Ptr &SrtTransport::getPoller() const { + return _poller; +} + +void SrtTransport::setSession(Session::Ptr session) { + _history_sessions.emplace(session.get(), session); + if (_selected_session) { + InfoL << "srt network changed: " << _selected_session->get_peer_ip() << ":" + << _selected_session->get_peer_port() << " -> " << session->get_peer_ip() << ":" + << session->get_peer_port() << ", id:" << _selected_session->getIdentifier(); + } + _selected_session = session; +} + +const Session::Ptr &SrtTransport::getSession() const { + return _selected_session; +} + +void SrtTransport::switchToOtherTransport(uint8_t *buf, int len, uint32_t socketid, struct sockaddr_storage *addr) { + BufferRaw::Ptr tmp = BufferRaw::create(); + struct sockaddr_storage tmp_addr = *addr; + tmp->assign((char *)buf, len); + auto trans = SrtTransportManager::Instance().getItem(std::to_string(socketid)); + if (trans) { + trans->getPoller()->async([tmp, tmp_addr, trans] { + trans->inputSockData((uint8_t *)tmp->data(), tmp->size(), (struct sockaddr_storage *)&tmp_addr); + }); + } +} + +void SrtTransport::inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr) { + using srt_control_handler = void (SrtTransport::*)(uint8_t * buf, int len, struct sockaddr_storage *addr); + static std::unordered_map s_control_functions; + static onceToken token([]() { + s_control_functions.emplace(ControlPacket::HANDSHAKE, &SrtTransport::handleHandshake); + s_control_functions.emplace(ControlPacket::KEEPALIVE, &SrtTransport::handleKeeplive); + s_control_functions.emplace(ControlPacket::ACK, &SrtTransport::handleACK); + s_control_functions.emplace(ControlPacket::NAK, &SrtTransport::handleNAK); + s_control_functions.emplace(ControlPacket::CONGESTIONWARNING, &SrtTransport::handleCongestionWarning); + s_control_functions.emplace(ControlPacket::SHUTDOWN, &SrtTransport::handleShutDown); + s_control_functions.emplace(ControlPacket::ACKACK, &SrtTransport::handleACKACK); + s_control_functions.emplace(ControlPacket::DROPREQ, &SrtTransport::handleDropReq); + s_control_functions.emplace(ControlPacket::PEERERROR, &SrtTransport::handlePeerError); + s_control_functions.emplace(ControlPacket::USERDEFINEDTYPE, &SrtTransport::handleUserDefinedType); + }); + _now = SteadyClock::now(); + // 处理srt数据 + if (DataPacket::isDataPacket(buf, len)) { + uint32_t socketId = DataPacket::getSocketID(buf, len); + if (socketId == _socket_id) { + _pkt_recv_rate_context->inputPacket(_now); + _estimated_link_capacity_context->inputPacket(_now); + _recv_rate_context->inputPacket(_now, len); + + handleDataPacket(buf, len, addr); + } else { + switchToOtherTransport(buf, len, socketId, addr); + } + } else { + if (ControlPacket::isControlPacket(buf, len)) { + uint32_t socketId = ControlPacket::getSocketID(buf, len); + uint16_t type = ControlPacket::getControlType(buf, len); + if (type != ControlPacket::HANDSHAKE && socketId != _socket_id && _socket_id != 0) { + // socket id not same + switchToOtherTransport(buf, len, socketId, addr); + return; + } + _pkt_recv_rate_context->inputPacket(_now); + _estimated_link_capacity_context->inputPacket(_now); + _recv_rate_context->inputPacket(_now, len); + + auto it = s_control_functions.find(type); + if (it == s_control_functions.end()) { + WarnL << " not support type ignore" << ControlPacket::getControlType(buf, len); + return; + } else { + (this->*(it->second))(buf, len, addr); + } + } else { + // not reach + WarnL << "not reach this"; + } + } +} + +void SrtTransport::handleHandshakeInduction(HandshakePacket &pkt, struct sockaddr_storage *addr) { + // Induction Phase + TraceL << getIdentifier() << " Induction Phase "; + if (_handleshake_res) { + TraceL << getIdentifier() << " Induction handle repeate "; + sendControlPacket(_handleshake_res, true); + return; + } + _induction_ts = _now; + _start_timestamp = _now; + _init_seq_number = pkt.initial_packet_sequence_number; + _max_window_size = pkt.max_flow_window_size; + _mtu = pkt.mtu; + + _last_pkt_seq = _init_seq_number - 1; + + _peer_socket_id = pkt.srt_socket_id; + HandshakePacket::Ptr res = std::make_shared(); + res->dst_socket_id = _peer_socket_id; + res->timestamp = DurationCountMicroseconds(_start_timestamp.time_since_epoch()); + res->mtu = _mtu; + res->max_flow_window_size = _max_window_size; + res->initial_packet_sequence_number = _init_seq_number; + res->version = 5; + res->encryption_field = HandshakePacket::NO_ENCRYPTION; + res->extension_field = 0x4A17; + res->handshake_type = HandshakePacket::HS_TYPE_INDUCTION; + res->srt_socket_id = _peer_socket_id; + res->syn_cookie = HandshakePacket::generateSynCookie(addr, _start_timestamp); + _sync_cookie = res->syn_cookie; + memcpy(res->peer_ip_addr, pkt.peer_ip_addr, sizeof(pkt.peer_ip_addr) * sizeof(pkt.peer_ip_addr[0])); + _handleshake_res = res; + res->storeToData(); + + registerSelfHandshake(); + sendControlPacket(res, true); +} + +void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockaddr_storage *addr) { + if (!_handleshake_res) { + ErrorL << "must Induction Phase for handleshake "; + return; + } + + if (_handleshake_res->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) { + // first + HSExtMessage::Ptr req; + HSExtStreamID::Ptr sid; + uint32_t srt_flag = 0xbf; + uint16_t delay = DurationCountMicroseconds(_now - _induction_ts) * getLatencyMul() / 1000; + + for (auto ext : pkt.ext_list) { + // TraceL << getIdentifier() << " ext " << ext->dump(); + if (!req) { + req = std::dynamic_pointer_cast(ext); + } + if (!sid) { + sid = std::dynamic_pointer_cast(ext); + } + } + if (sid) { + _stream_id = sid->streamid; + } + if (req) { + srt_flag = req->srt_flag; + delay = delay <= req->recv_tsbpd_delay ? req->recv_tsbpd_delay : delay; + } + TraceL << getIdentifier() << " CONCLUSION Phase "; + HandshakePacket::Ptr res = std::make_shared(); + res->dst_socket_id = _peer_socket_id; + res->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + res->mtu = _mtu; + res->max_flow_window_size = _max_window_size; + res->initial_packet_sequence_number = _init_seq_number; + res->version = 5; + res->encryption_field = HandshakePacket::NO_ENCRYPTION; + res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ; + res->handshake_type = HandshakePacket::HS_TYPE_CONCLUSION; + res->srt_socket_id = _socket_id; + res->syn_cookie = 0; + res->assignPeerIP(addr); + HSExtMessage::Ptr ext = std::make_shared(); + ext->extension_type = HSExt::SRT_CMD_HSRSP; + ext->srt_version = srtVersion(1, 5, 0); + ext->srt_flag = srt_flag; + ext->recv_tsbpd_delay = ext->send_tsbpd_delay = delay; + res->ext_list.push_back(std::move(ext)); + res->storeToData(); + _handleshake_res = res; + unregisterSelfHandshake(); + registerSelf(); + sendControlPacket(res, true); + TraceL << " buf size = " << res->max_flow_window_size << " init seq =" << _init_seq_number + << " latency=" << delay; + _recv_buf = std::make_shared(res->max_flow_window_size, _init_seq_number, delay * 1e3); + _send_buf = std::make_shared(res->max_flow_window_size, delay * 1e3); + _send_packet_seq_number = _init_seq_number; + _buf_delay = delay; + onHandShakeFinished(_stream_id, addr); + } else { + TraceL << getIdentifier() << " CONCLUSION handle repeate "; + sendControlPacket(_handleshake_res, true); + } + _last_ack_pkt_seq_num = _init_seq_number; +} + +void SrtTransport::handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr) { + HandshakePacket pkt; + assert(pkt.loadFromData(buf, len)); + + if (pkt.handshake_type == HandshakePacket::HS_TYPE_INDUCTION) { + handleHandshakeInduction(pkt, addr); + } else if (pkt.handshake_type == HandshakePacket::HS_TYPE_CONCLUSION) { + handleHandshakeConclusion(pkt, addr); + } else { + WarnL << " not support handshake type = " << pkt.handshake_type; + } + _ack_ticker.resetTime(_now); + _nak_ticker.resetTime(_now); +} + +void SrtTransport::handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; + sendKeepLivePacket(); +} + +void SrtTransport::sendKeepLivePacket() { + KeepLivePacket::Ptr pkt = std::make_shared(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->storeToData(); + sendControlPacket(pkt, true); +} + +void SrtTransport::handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; + ACKPacket ack; + if (!ack.loadFromData(buf, len)) { + return; + } + + ACKACKPacket::Ptr pkt = std::make_shared(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->ack_number = ack.ack_number; + pkt->storeToData(); + _send_buf->drop(ack.last_ack_pkt_seq_number); + sendControlPacket(pkt, true); + // TraceL<<"ack number "<(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->first_pkt_seq_num = first; + pkt->last_pkt_seq_num = last; + pkt->storeToData(); + sendControlPacket(pkt, true); +} + +void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; + NAKPacket pkt; + pkt.loadFromData(buf, len); + bool empty = false; + bool flush = false; + + for (auto it : pkt.lost_list) { + if (pkt.lost_list.back() == it) { + flush = true; + } + empty = true; + auto re_list = _send_buf->findPacketBySeq(it.first, it.second - 1); + for (auto pkt : re_list) { + pkt->R = 1; + pkt->storeToHeader(); + sendPacket(pkt, flush); + empty = false; + } + if (empty) { + sendMsgDropReq(it.first, it.second - 1); + } + } +} + +void SrtTransport::handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr) { + TraceL; +} + +void SrtTransport::handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr) { + TraceL; + onShutdown(SockException(Err_shutdown, "peer close connection")); +} + +void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr) { + MsgDropReqPacket pkt; + pkt.loadFromData(buf, len); + std::list list; + // TraceL<<"drop "<drop(pkt.first_pkt_seq_num, pkt.last_pkt_seq_num, list); + if (list.empty()) { + return; + } + uint32_t max_seq = 0; + for (auto data : list) { + max_seq = data->packet_seq_number; + if (_last_pkt_seq + 1 != data->packet_seq_number) { + TraceL << "pkt lost " << _last_pkt_seq + 1 << "->" << data->packet_seq_number; + } + _last_pkt_seq = data->packet_seq_number; + onSRTData(std::move(data)); + } + _recv_nack.drop(max_seq); + + auto lost = _recv_buf->getLostSeq(); + _recv_nack.update(_now, lost); + lost.clear(); + _recv_nack.getLostList(_now, _rtt, _rtt_variance, lost); + if (!lost.empty()) { + sendNAKPacket(lost); + // TraceL << "check lost send nack"; + } + + auto nak_interval = (_rtt + _rtt_variance * 4) / 2; + if (nak_interval <= 20 * 1000) { + nak_interval = 20 * 1000; + } + if (_nak_ticker.elapsedTime(_now) > nak_interval) { + auto lost = _recv_buf->getLostSeq(); + if (!lost.empty()) { + sendNAKPacket(lost); + } + _nak_ticker.resetTime(_now); + } + + if (_ack_ticker.elapsedTime(_now) > 10 * 1000) { + _light_ack_pkt_count = 0; + _ack_ticker.resetTime(_now); + // send a ack per 10 ms for receiver + sendACKPacket(); + } else { + if (_light_ack_pkt_count >= 64) { + // for high bitrate stream send light ack + // TODO + sendLightACKPacket(); + TraceL << "send light ack"; + } + _light_ack_pkt_count = 0; + } + _light_ack_pkt_count++; +} + +void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr) { + TraceL; +} + +void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; + ACKACKPacket::Ptr pkt = std::make_shared(); + pkt->loadFromData(buf, len); + + uint32_t rtt = DurationCountMicroseconds(_now - _ack_send_timestamp[pkt->ack_number]); + _rtt_variance = (3 * _rtt_variance + abs((long)_rtt - (long)rtt)) / 4; + _rtt = (7 * rtt + _rtt) / 8; + + // TraceL<<" rtt:"<<_rtt<<" rtt variance:"<<_rtt_variance; + _ack_send_timestamp.erase(pkt->ack_number); +} + +void SrtTransport::handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr) { + TraceL; +} + +void SrtTransport::sendACKPacket() { + ACKPacket::Ptr pkt = std::make_shared(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->ack_number = ++_ack_number_count; + pkt->last_ack_pkt_seq_number = _recv_buf->getExpectedSeq(); + pkt->rtt = _rtt; + pkt->rtt_variance = _rtt_variance; + pkt->available_buf_size = _recv_buf->getAvailableBufferSize(); + pkt->pkt_recv_rate = _pkt_recv_rate_context->getPacketRecvRate(); + pkt->estimated_link_capacity = _estimated_link_capacity_context->getEstimatedLinkCapacity(); + pkt->recv_rate = _recv_rate_context->getRecvRate(); + pkt->storeToData(); + _ack_send_timestamp[pkt->ack_number] = _now; + _last_ack_pkt_seq_num = pkt->last_ack_pkt_seq_number; + sendControlPacket(pkt, true); + // TraceL<<"send ack "<dump(); +} + +void SrtTransport::sendLightACKPacket() { + ACKPacket::Ptr pkt = std::make_shared(); + + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->ack_number = 0; + pkt->last_ack_pkt_seq_number = _recv_buf->getExpectedSeq(); + pkt->rtt = 0; + pkt->rtt_variance = 0; + pkt->available_buf_size = 0; + pkt->pkt_recv_rate = 0; + pkt->estimated_link_capacity = 0; + pkt->recv_rate = 0; + pkt->storeToData(); + _last_ack_pkt_seq_num = pkt->last_ack_pkt_seq_number; + sendControlPacket(pkt, true); + TraceL << "send ack " << pkt->dump(); +} + +void SrtTransport::sendNAKPacket(std::list &lost_list) { + NAKPacket::Ptr pkt = std::make_shared(); + + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->lost_list = lost_list; + + pkt->storeToData(); + + // TraceL<<"send NAK "<dump(); + sendControlPacket(pkt, true); +} + +void SrtTransport::sendShutDown() { + ShutDownPacket::Ptr pkt = std::make_shared(); + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + pkt->storeToData(); + sendControlPacket(pkt, true); +} + +void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr) { + DataPacket::Ptr pkt = std::make_shared(); + pkt->loadFromData(buf, len); + + std::list list; + //TraceL<<" seq="<< pkt->packet_seq_number<<" ts="<timestamp<<" size="<payloadSize()<<\ + //" PP="<<(int)pkt->PP<<" O="<<(int)pkt->O<<" kK="<<(int)pkt->KK<<" R="<<(int)pkt->R; + _recv_buf->inputPacket(pkt, list); + if (list.empty()) { + // when no data ok send nack to sender immediately + } else { + uint32_t last_seq; + for (auto data : list) { + last_seq = data->packet_seq_number; + if (_last_pkt_seq + 1 != data->packet_seq_number) { + TraceL << "pkt lost " << _last_pkt_seq + 1 << "->" << data->packet_seq_number; + } + _last_pkt_seq = data->packet_seq_number; + onSRTData(std::move(data)); + } + _recv_nack.drop(last_seq); + } + + auto lost = _recv_buf->getLostSeq(); + _recv_nack.update(_now, lost); + lost.clear(); + _recv_nack.getLostList(_now, _rtt, _rtt_variance, lost); + if (!lost.empty()) { + // TraceL << "check lost send nack immediately"; + sendNAKPacket(lost); + } + + auto nak_interval = (_rtt + _rtt_variance * 4) / 2; + if (nak_interval <= 20 * 1000) { + nak_interval = 20 * 1000; + } + + if (_nak_ticker.elapsedTime(_now) > nak_interval) { + // Periodic NAK reports + auto lost = _recv_buf->getLostSeq(); + if (!lost.empty()) { + sendNAKPacket(lost); + // TraceL<<"send NAK"; + } else { + // TraceL<<"lost is empty"; + } + _nak_ticker.resetTime(_now); + } + + if (_ack_ticker.elapsedTime(_now) > 10 * 1000) { + _light_ack_pkt_count = 0; + _ack_ticker.resetTime(_now); + // send a ack per 10 ms for receiver + sendACKPacket(); + } else { + if (_light_ack_pkt_count >= 64) { + // for high bitrate stream send light ack + // TODO + sendLightACKPacket(); + TraceL << "send light ack"; + } + _light_ack_pkt_count = 0; + } + _light_ack_pkt_count++; + // bufCheckInterval(); +} + +void SrtTransport::sendDataPacket(DataPacket::Ptr pkt, char *buf, int len, bool flush) { + pkt->storeToData((uint8_t *)buf, len); + sendPacket(pkt, flush); + _send_buf->inputPacket(pkt); +} + +void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) { + sendPacket(pkt, flush); +} + +void SrtTransport::sendPacket(Buffer::Ptr pkt, bool flush) { + if (_selected_session) { + auto tmp = _packet_pool.obtain2(); + tmp->assign(pkt->data(), pkt->size()); + _selected_session->setSendFlushFlag(flush); + _selected_session->send(std::move(tmp)); + } else { + WarnL << "not reach this"; + } +} + +std::string SrtTransport::getIdentifier() { + return _selected_session ? _selected_session->getIdentifier() : ""; +} + +void SrtTransport::registerSelfHandshake() { + SrtTransportManager::Instance().addHandshakeItem(std::to_string(_sync_cookie), shared_from_this()); +} + +void SrtTransport::unregisterSelfHandshake() { + if (_sync_cookie == 0) { + return; + } + SrtTransportManager::Instance().removeHandshakeItem(std::to_string(_sync_cookie)); +} + +void SrtTransport::registerSelf() { + if (_socket_id == 0) { + return; + } + SrtTransportManager::Instance().addItem(std::to_string(_socket_id), shared_from_this()); +} + +void SrtTransport::unregisterSelf() { + SrtTransportManager::Instance().removeItem(std::to_string(_socket_id)); +} + +void SrtTransport::onShutdown(const SockException &ex) { + sendShutDown(); + WarnL << ex.what(); + unregisterSelfHandshake(); + unregisterSelf(); + for (auto &pr : _history_sessions) { + auto session = pr.second.lock(); + if (session) { + session->shutdown(ex); + } + } +} + +size_t SrtTransport::getPayloadSize() { + size_t ret = (_mtu - 28 - 16) / 188 * 188; + return ret; +} + +void SrtTransport::onSendTSData(const Buffer::Ptr &buffer, bool flush) { + // TraceL; + DataPacket::Ptr pkt; + size_t payloadSize = getPayloadSize(); + size_t size = buffer->size(); + char *ptr = buffer->data(); + char *end = buffer->data() + size; + + while (ptr < end && size >= payloadSize) { + pkt = std::make_shared(); + pkt->f = 0; + pkt->packet_seq_number = _send_packet_seq_number & 0x7fffffff; + _send_packet_seq_number = (_send_packet_seq_number + 1) & 0x7fffffff; + pkt->PP = 3; + pkt->O = 0; + pkt->KK = 0; + pkt->R = 0; + pkt->msg_number = _send_msg_number++; + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); + sendDataPacket(pkt, ptr, (int)payloadSize, flush); + ptr += payloadSize; + size -= payloadSize; + } + + if (size > 0 && ptr < end) { + pkt = std::make_shared(); + pkt->f = 0; + pkt->packet_seq_number = _send_packet_seq_number & 0x7fffffff; + _send_packet_seq_number = (_send_packet_seq_number + 1) & 0x7fffffff; + pkt->PP = 3; + pkt->O = 0; + pkt->KK = 0; + pkt->R = 0; + pkt->msg_number = _send_msg_number++; + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); + sendDataPacket(pkt, ptr, (int)size, flush); + } +} + +//////////// SrtTransportManager ////////////////////////// + +SrtTransportManager &SrtTransportManager::Instance() { + static SrtTransportManager s_instance; + return s_instance; +} + +void SrtTransportManager::addItem(const std::string &key, const SrtTransport::Ptr &ptr) { + std::lock_guard lck(_mtx); + _map[key] = ptr; +} + +SrtTransport::Ptr SrtTransportManager::getItem(const std::string &key) { + if (key.empty()) { + return nullptr; + } + std::lock_guard lck(_mtx); + auto it = _map.find(key); + if (it == _map.end()) { + return nullptr; + } + return it->second.lock(); +} + +void SrtTransportManager::removeItem(const std::string &key) { + std::lock_guard lck(_mtx); + _map.erase(key); +} + +void SrtTransportManager::addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr) { + std::lock_guard lck(_handshake_mtx); + _handshake_map[key] = ptr; +} + +void SrtTransportManager::removeHandshakeItem(const std::string &key) { + std::lock_guard lck(_handshake_mtx); + _handshake_map.erase(key); +} + +SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const std::string &key) { + if (key.empty()) { + return nullptr; + } + std::lock_guard lck(_handshake_mtx); + auto it = _handshake_map.find(key); + if (it == _handshake_map.end()) { + return nullptr; + } + return it->second.lock(); +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/SrtTransport.hpp b/srt/SrtTransport.hpp new file mode 100644 index 00000000..27281063 --- /dev/null +++ b/srt/SrtTransport.hpp @@ -0,0 +1,169 @@ +#ifndef ZLMEDIAKIT_SRT_TRANSPORT_H +#define ZLMEDIAKIT_SRT_TRANSPORT_H + +#include +#include +#include +#include + +#include "Network/Session.h" +#include "Poller/EventPoller.h" +#include "Poller/Timer.h" + +#include "Common.hpp" +#include "NackContext.hpp" +#include "Packet.hpp" +#include "PacketQueue.hpp" +#include "PacketSendQueue.hpp" +#include "Statistic.hpp" +namespace SRT { + +using namespace toolkit; + +extern const std::string kPort; +extern const std::string kTimeOutSec; +extern const std::string kLatencyMul; + +class SrtTransport : public std::enable_shared_from_this { +public: + friend class SrtSession; + using Ptr = std::shared_ptr; + + SrtTransport(const EventPoller::Ptr &poller); + virtual ~SrtTransport(); + const EventPoller::Ptr &getPoller() const; + void setSession(Session::Ptr session); + const Session::Ptr &getSession() const; + + /** + * socket收到udp数据 + * @param buf 数据指针 + * @param len 数据长度 + * @param addr 数据来源地址 + */ + virtual void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr); + virtual void onSendTSData(const Buffer::Ptr &buffer, bool flush); + + std::string getIdentifier(); + void unregisterSelf(); + void unregisterSelfHandshake(); + +protected: + virtual bool isPusher() { return true; }; + virtual void onSRTData(DataPacket::Ptr pkt) {}; + virtual void onShutdown(const SockException &ex); + virtual void onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) {}; + virtual void sendPacket(Buffer::Ptr pkt, bool flush = true); + virtual int getLatencyMul() { return 4; }; + +private: + void registerSelf(); + void registerSelfHandshake(); + + void switchToOtherTransport(uint8_t *buf, int len, uint32_t socketid, struct sockaddr_storage *addr); + + void handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleHandshakeInduction(HandshakePacket &pkt, struct sockaddr_storage *addr); + void handleHandshakeConclusion(HandshakePacket &pkt, struct sockaddr_storage *addr); + + void handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr); + void handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr); + + void sendNAKPacket(std::list &lost_list); + void sendACKPacket(); + void sendLightACKPacket(); + void sendKeepLivePacket(); + void sendShutDown(); + void sendMsgDropReq(uint32_t first, uint32_t last); + + size_t getPayloadSize(); + +protected: + void sendDataPacket(DataPacket::Ptr pkt, char *buf, int len, bool flush = false); + void sendControlPacket(ControlPacket::Ptr pkt, bool flush = true); + +private: + // 当前选中的udp链接 + Session::Ptr _selected_session; + // 链接迁移前后使用过的udp链接 + std::unordered_map> _history_sessions; + + EventPoller::Ptr _poller; + + uint32_t _peer_socket_id; + uint32_t _socket_id = 0; + + TimePoint _now; + TimePoint _start_timestamp; + + // for calculate rtt for delay + TimePoint _induction_ts; + + uint32_t _mtu = 1500; + uint32_t _max_window_size = 8192; + uint32_t _init_seq_number = 0; + + std::string _stream_id; + uint32_t _sync_cookie = 0; + uint32_t _send_packet_seq_number = 0; + uint32_t _send_msg_number = 1; + + PacketSendQueue::Ptr _send_buf; + uint32_t _buf_delay = 120; + PacketQueue::Ptr _recv_buf; + NackContext _recv_nack; + uint32_t _rtt = 100 * 1000; + uint32_t _rtt_variance = 50 * 1000; + uint32_t _light_ack_pkt_count = 0; + uint32_t _ack_number_count = 0; + uint32_t _last_ack_pkt_seq_num = 0; + + uint32_t _last_pkt_seq = 0; + UTicker _ack_ticker; + std::map _ack_send_timestamp; + + std::shared_ptr _pkt_recv_rate_context; + std::shared_ptr _estimated_link_capacity_context; + std::shared_ptr _recv_rate_context; + + UTicker _nak_ticker; + + // 保持发送的握手消息,防止丢失重发 + HandshakePacket::Ptr _handleshake_res; + + ResourcePool _packet_pool; +}; + +class SrtTransportManager { +public: + static SrtTransportManager &Instance(); + SrtTransport::Ptr getItem(const std::string &key); + void addItem(const std::string &key, const SrtTransport::Ptr &ptr); + void removeItem(const std::string &key); + + void addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr); + void removeHandshakeItem(const std::string &key); + SrtTransport::Ptr getHandshakeItem(const std::string &key); + +private: + SrtTransportManager() = default; + +private: + std::mutex _mtx; + std::unordered_map> _map; + + std::mutex _handshake_mtx; + std::unordered_map> _handshake_map; +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_TRANSPORT_H \ No newline at end of file diff --git a/srt/SrtTransportImp.cpp b/srt/SrtTransportImp.cpp new file mode 100644 index 00000000..1ae37064 --- /dev/null +++ b/srt/SrtTransportImp.cpp @@ -0,0 +1,282 @@ +#include +#include "Util/util.h" + +#include "SrtTransportImp.hpp" + +namespace SRT { +SrtTransportImp::SrtTransportImp(const EventPoller::Ptr &poller) : SrtTransport(poller) {} + +SrtTransportImp::~SrtTransportImp() { + InfoP(this); + uint64_t duration = _alive_ticker.createdTime() / 1000; + WarnP(this) << (_is_pusher ? "srt 推流器(" : "srt 播放器(") << _media_info._vhost << "/" << _media_info._app << "/" + << _media_info._streamid << ")断开,耗时(s):" << duration; + + //流量统计事件广播 + GET_CONFIG(uint32_t, iFlowThreshold, General::kFlowThreshold); + if (_total_bytes >= iFlowThreshold * 1024) { + NoticeCenter::Instance().emitEvent( + Broadcast::kBroadcastFlowReport, _media_info, _total_bytes, duration, false, + static_cast(*this)); + } +} + +void SrtTransportImp::onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) { + // TODO parse stream id like this zlmediakit.com/live/test?token=1213444&type=push + if (!_addr) { + _addr.reset(new sockaddr_storage(*((sockaddr_storage *)addr))); + } + _is_pusher = false; + TraceL << " stream id " << streamid; + if (streamid.empty()) { + onShutdown(SockException(Err_shutdown, "stream id not empty")); + return; + } + + _media_info.parse("srt://" + streamid); + + auto params = Parser::parseArgs(_media_info._param_strs); + if (params["type"] == "push") { + _is_pusher = true; + _decoder = DecoderImp::createDecoder(DecoderImp::decoder_ts, this); + emitOnPublish(); + } else { + _is_pusher = false; + emitOnPlay(); + } +} + +void SrtTransportImp::onSRTData(DataPacket::Ptr pkt) { + if (!_is_pusher) { + WarnP(this) << "this is a player data ignore"; + return; + } + if (_decoder) { + _decoder->input(reinterpret_cast(pkt->payloadData()), pkt->payloadSize()); + } else { + WarnP(this) << " not reach this"; + } +} + +void SrtTransportImp::onShutdown(const SockException &ex) { + SrtTransport::onShutdown(ex); +} + +bool SrtTransportImp::close(mediakit::MediaSource &sender, bool force) { + if (!force && totalReaderCount(sender)) { + return false; + } + std::string err = StrPrinter << "close media:" << sender.getSchema() << "/" + << sender.getVhost() << "/" + << sender.getApp() << "/" + << sender.getId() << " " << force; + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, err]() { + auto strong_self = weak_self.lock(); + if (strong_self) { + strong_self->onShutdown(SockException(Err_shutdown, err)); + //主动关闭推流,那么不延时注销 + strong_self->_muxer = nullptr; + } + }); + return true; +} + +// 播放总人数 +int SrtTransportImp::totalReaderCount(mediakit::MediaSource &sender) { + return _muxer ? _muxer->totalReaderCount() : sender.readerCount(); +} + +// 获取媒体源类型 +mediakit::MediaOriginType SrtTransportImp::getOriginType(mediakit::MediaSource &sender) const { + return MediaOriginType::srt_push; +} + +// 获取媒体源url或者文件路径 +std::string SrtTransportImp::getOriginUrl(mediakit::MediaSource &sender) const { + return _media_info._full_url; +} + +// 获取媒体源客户端相关信息 +std::shared_ptr SrtTransportImp::getOriginSock(mediakit::MediaSource &sender) const { + return static_pointer_cast(getSession()); +} + +void SrtTransportImp::emitOnPublish() { + std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); + Broadcast::PublishAuthInvoker invoker = [weak_self](const std::string &err, const ProtocolOption &option) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + if (err.empty()) { + strong_self->_muxer = std::make_shared( + strong_self->_media_info._vhost, strong_self->_media_info._app, strong_self->_media_info._streamid, + 0.0f, option); + strong_self->_muxer->setMediaListener(strong_self); + strong_self->doCachedFunc(); + InfoP(strong_self) << "允许 srt 推流"; + } else { + WarnP(strong_self) << "禁止 srt 推流:" << err; + strong_self->onShutdown(SockException(Err_refused, err)); + } + }; + + //触发推流鉴权事件 + auto flag = NoticeCenter::Instance().emitEvent( + Broadcast::kBroadcastMediaPublish, MediaOriginType::srt_push, _media_info, invoker, + static_cast(*this)); + if (!flag) { + //该事件无人监听,默认不鉴权 + invoker("", ProtocolOption()); + } +} + +void SrtTransportImp::emitOnPlay() { + std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); + Broadcast::AuthInvoker invoker = [weak_self](const string &err) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->getPoller()->async([strong_self, err] { + if (err != "") { + strong_self->onShutdown(SockException(Err_refused, err)); + } else { + strong_self->doPlay(); + } + }); + }; + + auto flag = NoticeCenter::Instance().emitEvent( + Broadcast::kBroadcastMediaPlayed, _media_info, invoker, static_cast(*this)); + if (!flag) { + doPlay(); + } +} + +void SrtTransportImp::doPlay() { + //异步查找直播流 + MediaInfo info = _media_info; + info._schema = TS_SCHEMA; + std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); + MediaSource::findAsync(info, getSession(), [weak_self](const MediaSource::Ptr &src) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + TraceL << "本对象已经销毁"; + return; + } + if (!src) { + //未找到该流 + TraceL << "未找到该流"; + strong_self->onShutdown(SockException(Err_shutdown)); + } else { + TraceL << "找到该流"; + auto ts_src = dynamic_pointer_cast(src); + assert(ts_src); + ts_src->pause(false); + strong_self->_ts_reader = ts_src->getRing()->attach(strong_self->getPoller()); + strong_self->_ts_reader->setDetachCB([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + return; + } + strong_self->onShutdown(SockException(Err_shutdown)); + }); + strong_self->_ts_reader->setReadCB([weak_self](const TSMediaSource::RingDataType &ts_list) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + return; + } + size_t i = 0; + auto size = ts_list->size(); + ts_list->for_each([&](const TSPacket::Ptr &ts) { strong_self->onSendTSData(ts, ++i == size); }); + }); + } + }); +} + +std::string SrtTransportImp::get_peer_ip() { + if (!_addr) { + return "::"; + } + return SockUtil::inet_ntoa((sockaddr *)_addr.get()); +} + +uint16_t SrtTransportImp::get_peer_port() { + if (!_addr) { + return 0; + } + return SockUtil::inet_port((sockaddr *)_addr.get()); +} + +std::string SrtTransportImp::get_local_ip() { + auto s = getSession(); + if (s) { + return s->get_local_ip(); + } + return "::"; +} + +uint16_t SrtTransportImp::get_local_port() { + auto s = getSession(); + if (s) { + return s->get_local_port(); + } + return 0; +} + +std::string SrtTransportImp::getIdentifier() const { + return _media_info._streamid; +} + +bool SrtTransportImp::inputFrame(const Frame::Ptr &frame) { + if (_muxer) { + return _muxer->inputFrame(frame); + } + if (_cached_func.size() > 200) { + WarnL << "cached frame of track(" << frame->getCodecName() << ") is too much, now dropped"; + return false; + } + auto frame_cached = Frame::getCacheAbleFrame(frame); + lock_guard lck(_func_mtx); + _cached_func.emplace_back([this, frame_cached]() { _muxer->inputFrame(frame_cached); }); + return true; +} + +bool SrtTransportImp::addTrack(const Track::Ptr &track) { + if (_muxer) { + return _muxer->addTrack(track); + } + + lock_guard lck(_func_mtx); + _cached_func.emplace_back([this, track]() { _muxer->addTrack(track); }); + return true; +} + +void SrtTransportImp::addTrackCompleted() { + if (_muxer) { + _muxer->addTrackCompleted(); + } else { + lock_guard lck(_func_mtx); + _cached_func.emplace_back([this]() { _muxer->addTrackCompleted(); }); + } +} + +void SrtTransportImp::doCachedFunc() { + lock_guard lck(_func_mtx); + for (auto &func : _cached_func) { + func(); + } + _cached_func.clear(); +} + +int SrtTransportImp::getLatencyMul() { + GET_CONFIG(int, latencyMul, kLatencyMul); + return latencyMul; +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/SrtTransportImp.hpp b/srt/SrtTransportImp.hpp new file mode 100644 index 00000000..67cd23ec --- /dev/null +++ b/srt/SrtTransportImp.hpp @@ -0,0 +1,92 @@ +#ifndef ZLMEDIAKIT_SRT_TRANSPORT_IMP_H +#define ZLMEDIAKIT_SRT_TRANSPORT_IMP_H + +#include +#include "Rtp/Decoder.h" +#include "SrtTransport.hpp" +#include "TS/TSMediaSource.h" +#include "Common/MultiMediaSourceMuxer.h" + +namespace SRT { + +using namespace std; +using namespace toolkit; +using namespace mediakit; +class SrtTransportImp + : public SrtTransport + , public toolkit::SockInfo + , public MediaSinkInterface + , public mediakit::MediaSourceEvent { +public: + SrtTransportImp(const EventPoller::Ptr &poller); + ~SrtTransportImp(); + + void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr) override { + SrtTransport::inputSockData(buf, len, addr); + _total_bytes += len; + } + void onSendTSData(const Buffer::Ptr &buffer, bool flush) override { SrtTransport::onSendTSData(buffer, flush); } + /// SockInfo override + std::string get_local_ip() override; + uint16_t get_local_port() override; + std::string get_peer_ip() override; + uint16_t get_peer_port() override; + std::string getIdentifier() const override; + +protected: + ///////SrtTransport override/////// + int getLatencyMul() override; + void onSRTData(DataPacket::Ptr pkt) override; + void onShutdown(const SockException &ex) override; + void onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) override; + + void sendPacket(Buffer::Ptr pkt, bool flush = true) override { + _total_bytes += pkt->size(); + SrtTransport::sendPacket(pkt, flush); + } + + bool isPusher() override { return _is_pusher; } + + ///////MediaSourceEvent override/////// + // 关闭 + bool close(mediakit::MediaSource &sender, bool force) override; + // 播放总人数 + int totalReaderCount(mediakit::MediaSource &sender) override; + // 获取媒体源类型 + mediakit::MediaOriginType getOriginType(mediakit::MediaSource &sender) const override; + // 获取媒体源url或者文件路径 + std::string getOriginUrl(mediakit::MediaSource &sender) const override; + // 获取媒体源客户端相关信息 + std::shared_ptr getOriginSock(mediakit::MediaSource &sender) const override; + + ///////MediaSinkInterface override/////// + void resetTracks() override {}; + void addTrackCompleted() override; + bool addTrack(const Track::Ptr &track) override; + bool inputFrame(const Frame::Ptr &frame) override; + +private: + void emitOnPublish(); + void emitOnPlay(); + + void doPlay(); + void doCachedFunc(); + +private: + bool _is_pusher = true; + MediaInfo _media_info; + uint64_t _total_bytes = 0; + Ticker _alive_ticker; + std::unique_ptr _addr; + // for player + TSMediaSource::RingType::RingReader::Ptr _ts_reader; + // for pusher + MultiMediaSourceMuxer::Ptr _muxer; + DecoderImp::Ptr _decoder; + std::recursive_mutex _func_mtx; + std::deque> _cached_func; +}; + +} // namespace SRT + +#endif // ZLMEDIAKIT_SRT_TRANSPORT_IMP_H diff --git a/srt/Statistic.cpp b/srt/Statistic.cpp new file mode 100644 index 00000000..446fc6fe --- /dev/null +++ b/srt/Statistic.cpp @@ -0,0 +1,98 @@ +#include + +#include "Statistic.hpp" + +namespace SRT { + +void PacketRecvRateContext::inputPacket(TimePoint &ts) { + if (_pkt_map.size() > 100) { + _pkt_map.erase(_pkt_map.begin()); + } + auto tmp = DurationCountMicroseconds(ts - _start); + _pkt_map.emplace(tmp, tmp); +} + +uint32_t PacketRecvRateContext::getPacketRecvRate() { + if (_pkt_map.size() < 2) { + return 50000; + } + int64_t dur = 1000; + for (auto it = _pkt_map.begin(); it != _pkt_map.end(); ++it) { + auto next = it; + ++next; + if (next == _pkt_map.end()) { + break; + } + + if ((next->first - it->first) < dur) { + dur = next->first - it->first; + } + } + + double rate = 1e6 / (double)dur; + if (rate <= 1000) { + return 50000; + } + return rate; +} + +void EstimatedLinkCapacityContext::inputPacket(TimePoint &ts) { + if (_pkt_map.size() > 16) { + _pkt_map.erase(_pkt_map.begin()); + } + auto tmp = DurationCountMicroseconds(ts - _start); + _pkt_map.emplace(tmp, tmp); +} + +uint32_t EstimatedLinkCapacityContext::getEstimatedLinkCapacity() { + decltype(_pkt_map.begin()) next; + std::vector tmp; + + for (auto it = _pkt_map.begin(); it != _pkt_map.end(); ++it) { + next = it; + ++next; + if (next != _pkt_map.end()) { + tmp.push_back(next->first - it->first); + } else { + break; + } + } + std::sort(tmp.begin(), tmp.end()); + if (tmp.empty()) { + return 1000; + } + + if (tmp.size() < 16) { + return 1000; + } + + double dur = tmp[0] / 1e6; + return (uint32_t)(1.0 / dur); +} + +void RecvRateContext::inputPacket(TimePoint &ts, size_t size) { + if (_pkt_map.size() > 100) { + _pkt_map.erase(_pkt_map.begin()); + } + auto tmp = DurationCountMicroseconds(ts - _start); + _pkt_map.emplace(tmp, tmp); +} + +uint32_t RecvRateContext::getRecvRate() { + if (_pkt_map.size() < 2) { + return 0; + } + + auto first = _pkt_map.begin(); + auto last = _pkt_map.rbegin(); + double dur = (last->first - first->first) / 1000000.0; + + size_t bytes = 0; + for (auto it : _pkt_map) { + bytes += it.second; + } + double rate = (double)bytes / dur; + return (uint32_t)rate; +} + +} // namespace SRT \ No newline at end of file diff --git a/srt/Statistic.hpp b/srt/Statistic.hpp new file mode 100644 index 00000000..d2a5036e --- /dev/null +++ b/srt/Statistic.hpp @@ -0,0 +1,49 @@ +#ifndef ZLMEDIAKIT_SRT_STATISTIC_H +#define ZLMEDIAKIT_SRT_STATISTIC_H +#include + +#include "Common.hpp" +#include "Packet.hpp" + +namespace SRT { + +class PacketRecvRateContext { +public: + PacketRecvRateContext(TimePoint start) + : _start(start) {}; + ~PacketRecvRateContext() = default; + void inputPacket(TimePoint &ts); + uint32_t getPacketRecvRate(); + +private: + TimePoint _start; + std::map _pkt_map; +}; + +class EstimatedLinkCapacityContext { +public: + EstimatedLinkCapacityContext(TimePoint start) : _start(start) {}; + ~EstimatedLinkCapacityContext() = default; + void inputPacket(TimePoint &ts); + uint32_t getEstimatedLinkCapacity(); + +private: + TimePoint _start; + std::map _pkt_map; +}; + +class RecvRateContext { +public: + RecvRateContext(TimePoint start) + : _start(start) {}; + ~RecvRateContext() = default; + void inputPacket(TimePoint &ts, size_t size); + uint32_t getRecvRate(); + +private: + TimePoint _start; + std::map _pkt_map; +}; + +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_STATISTIC_H \ No newline at end of file diff --git a/srt/srt.md b/srt/srt.md new file mode 100644 index 00000000..2d504720 --- /dev/null +++ b/srt/srt.md @@ -0,0 +1,25 @@ +## 特性 +- NACK(重传) +- listener 支持 +- 推流只支持ts推流 +- 拉流只支持ts拉流 +- 协议实现 [参考](https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html) +- 版本支持(>=1.3.0) +- fec与加密没有实现 + +## 使用 + +zlm中的srt更加streamid 来确定是推流还是拉流,来确定vhost,app,streamid(ZLM中的) +srt中的streamid 为 `//?type=& =` + +- OBS 推流地址 + + `srt://192.168.1.105:9000?streamid=__defaultVhost__/live/test?type=push` +- ffmpeg 推流 + + `ffmpeg -re -stream_loop -1 -i test.ts -c:v copy -c:a copy -f mpegts srt://192.168.1.105:9000?streamid="__defaultVhost__/live/test?type=push"` +- ffplay 拉流 + + `ffplay -i srt://192.168.1.105:9000?streamid=__defaultVhost__/live/test` + +- vlc 不支持,因为无法指定streamid[参考](https://github.com/Haivision/srt/issues/1015) \ No newline at end of file diff --git a/srt/srt_en.md b/srt/srt_en.md new file mode 100644 index 00000000..c4ef2b1d --- /dev/null +++ b/srt/srt_en.md @@ -0,0 +1,24 @@ +## feature +- NACK support +- listener support +- push stream payload must ts +- pull stream payload is ts +- protocol impliment [reference](https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html) +- version support (>=1.3.0) +- fec and encriyped not support + +## usage + +zlm get vhost,app,streamid and push or play by streamid of srt like this `//?type=& =` + +- OBS push stream url + + `srt://192.168.1.105:9000?streamid=__defaultVhost__/live/test?type=push` +- ffmpeg push + + `ffmpeg -re -stream_loop -1 -i test.ts -c:v copy -c:a copy -f mpegts srt://192.168.1.105:9000?streamid="__defaultVhost__/live/test?type=push"` +- ffplay pull + + `ffplay -i srt://192.168.1.105:9000?streamid=__defaultVhost__/live/test` + +- vlc not support ,because can't set stream id [reference](https://github.com/Haivision/srt/issues/1015) \ No newline at end of file