first commit for srt intergrate

This commit is contained in:
xiongguangjie 2022-06-03 13:25:32 +08:00
parent 59dcd03b70
commit aa2ff01d9f
18 changed files with 2484 additions and 0 deletions

View File

@ -106,6 +106,7 @@ option(ENABLE_FFMPEG "Enable FFmpeg" true)
option(ENABLE_MSVC_MT "Enable MSVC Mt/Mtd lib" true)
option(ENABLE_API_STATIC_LIB "Enable mk_api static lib" false)
option(USE_SOLUTION_FOLDERS "Enable solution dir supported" ON)
option(ENABLE_SRT "Enable SRT" true)
# ----------------------------------------------------------------------------
# Solution folders:
# ----------------------------------------------------------------------------
@ -486,6 +487,15 @@ if (ENABLE_WEBRTC)
endif ()
endif ()
if (ENABLE_SRT)
add_definitions(-DENABLE_SRT)
include_directories(./srt)
file(GLOB SRC_SRT_LIST ./srt/*.cpp ./srt/*.h ./srt/*.hpp)
add_library(srt ${SRC_SRT_LIST})
list(APPEND LINK_LIB_LIST srt)
message(STATUS "srt 功能已开启")
endif()
#c
if (ENABLE_API)
add_subdirectory(api)

View File

@ -284,6 +284,14 @@ preferredCodecA=PCMU,PCMA,opus,mpeg4-generic
#以下范例为所有支持的视频codec
preferredCodecV=H264,H265,AV1X,VP9,VP8
[srt]
#srt播放推流、播放超时时间,单位秒
timeoutSec=5
#srt udp服务器监听端口号所有srt客户端将通过该端口传输srt数据
#该端口是多线程的,同时支持客户端网络切换导致的连接迁移
port=9000
[rtsp]
#rtsp专有鉴权方式是采用base64还是md5方式
authBasic=0

View File

@ -32,6 +32,11 @@
#include "../webrtc/WebRtcSession.h"
#endif
#if defined(ENABLE_SRT)
#include "../srt/SrtSession.hpp"
#include "../srt/SrtTransport.hpp"
#endif
#if defined(ENABLE_VERSION)
#include "version.h"
#endif
@ -284,6 +289,24 @@ int start_main(int argc,char *argv[]) {
uint16_t rtcPort = mINI::Instance()[RTC::kPort];
#endif//defined(ENABLE_WEBRTC)
#if defined(ENABLE_SRT)
auto srtSrv = std::make_shared<UdpServer>();
srtSrv->setOnCreateSocket([](const EventPoller::Ptr &poller, const Buffer::Ptr &buf, struct sockaddr *, int) {
if (!buf) {
return Socket::createSocket(poller, false);
}
auto new_poller = SRT::SrtSession::queryPoller(buf);
if (!new_poller) {
//握手第一阶段
return Socket::createSocket(poller, false);
}
return Socket::createSocket(new_poller, false);
});
uint16_t srtPort = mINI::Instance()[SRT::kPort];
#endif //defined(ENABLE_SRT)
try {
//rtsp服务器端口默认554
if (rtspPort) { rtspSrv->start<RtspSession>(rtspPort); }
@ -313,6 +336,14 @@ int start_main(int argc,char *argv[]) {
if (rtcPort) { rtcSrv->start<WebRtcSession>(rtcPort); }
#endif//defined(ENABLE_WEBRTC)
#if defined(ENABLE_SRT)
// srt udp服务器
if(srtPort){
srtSrv->start<SRT::SrtSession>(srtPort);
}
#endif//defined(ENABLE_SRT)
} catch (std::exception &ex) {
WarnL << "端口占用或无权限:" << ex.what() << endl;
ErrorL << "程序启动失败,请修改配置文件中端口号后重试!" << endl;

74
srt/Ack.cpp Normal file
View File

@ -0,0 +1,74 @@
#include "Ack.hpp"
#include "Common.hpp"
namespace SRT {
bool ACKPacket::loadFromData(uint8_t *buf, size_t len) {
if(len < ACK_CIF_SIZE + ControlPacket::HEADER_SIZE){
return false;
}
_data = BufferRaw::create();
_data->assign((char *)(buf), len);
ControlPacket::loadHeader();
ack_number = loadUint32(type_specific_info);
uint8_t* ptr = (uint8_t*)_data->data()+ControlPacket::HEADER_SIZE;
last_ack_pkt_seq_number = loadUint32(ptr);
ptr += 4;
rtt = loadUint32(ptr);
ptr += 4;
rtt_variance = loadUint32(ptr);
ptr += 4;
available_buf_size = loadUint32(ptr);
ptr += 4;
pkt_recv_rate = loadUint32(ptr);
ptr += 4;
estimated_link_capacity = loadUint32(ptr);
ptr += 4;
recv_rate = loadUint32(ptr);
ptr += 4;
return true;
}
bool ACKPacket::storeToData() {
_data = BufferRaw::create();
_data->setCapacity(HEADER_SIZE + ACK_CIF_SIZE);
_data->setSize(HEADER_SIZE + ACK_CIF_SIZE);
control_type = ControlPacket::ACK;
sub_type = 0;
storeUint32(type_specific_info,ack_number);
storeToHeader();
uint8_t* ptr = (uint8_t*)_data->data()+ControlPacket::HEADER_SIZE;
storeUint32(ptr,last_ack_pkt_seq_number);
ptr += 4;
storeUint32(ptr,rtt);
ptr += 4;
storeUint32(ptr,rtt_variance);
ptr += 4;
storeUint32(ptr,pkt_recv_rate);
ptr += 4;
storeUint32(ptr,available_buf_size);
ptr += 4;
storeUint32(ptr,estimated_link_capacity);
ptr += 4;
storeUint32(ptr,recv_rate);
ptr += 4;
return true;
}
} // namespace

96
srt/Ack.hpp Normal file
View File

@ -0,0 +1,96 @@
#ifndef ZLMEDIAKIT_SRT_ACK_H
#define ZLMEDIAKIT_SRT_ACK_H
#include "Packet.hpp"
namespace SRT{
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Acknowledgement Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- CIF -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Last Acknowledged Packet Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| RTT |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| RTT Variance |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Available Buffer Size |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Packets Receiving Rate |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Estimated Link Capacity |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Receiving Rate |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 13: ACK control packet
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-ack-acknowledgment
*/
class ACKPacket : public ControlPacket
{
public:
using Ptr = std::shared_ptr<ACKPacket>;
ACKPacket() = default;
~ACKPacket() = default;
enum{
ACK_CIF_SIZE = 7*4
};
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
uint32_t ack_number;
uint32_t last_ack_pkt_seq_number;
uint32_t rtt;
uint32_t rtt_variance;
uint32_t available_buf_size;
uint32_t pkt_recv_rate;
uint32_t estimated_link_capacity;
uint32_t recv_rate;
};
class ACKACKPacket : public ControlPacket{
public:
using Ptr = std::shared_ptr<ACKACKPacket>;
ACKACKPacket() = default;
~ACKACKPacket() = default;
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override{
if(len < ControlPacket::HEADER_SIZE){
return false;
}
_data = BufferRaw::create();
_data->assign((char *)(buf), len);
ControlPacket::loadHeader();
ack_number = loadUint32(type_specific_info);
return true;
}
bool storeToData() override{
_data = BufferRaw::create();
_data->setCapacity(HEADER_SIZE);
_data->setSize(HEADER_SIZE );
control_type = ControlPacket::ACKACK;
sub_type = 0;
storeUint32(type_specific_info,ack_number);
storeToHeader();
return true;
}
uint32_t ack_number;
};
} //namespace SRT
#endif // ZLMEDIAKIT_SRT_ACK_H

54
srt/Common.hpp Normal file
View File

@ -0,0 +1,54 @@
#ifndef ZLMEDIAKIT_SRT_COMMON_H
#define ZLMEDIAKIT_SRT_COMMON_H
#include <chrono>
namespace SRT
{
using SteadyClock = std::chrono::steady_clock;
using TimePoint = std::chrono::time_point<SteadyClock>;
using Microseconds = std::chrono::microseconds;
using Milliseconds = std::chrono::milliseconds;
inline int64_t DurationCountMicroseconds( SteadyClock::duration dur){
return std::chrono::duration_cast<std::chrono::microseconds>(dur).count();
}
inline uint32_t loadUint32(uint8_t *ptr) {
return ptr[0] << 24 | ptr[1] << 16 | ptr[2] << 8 | ptr[3];
}
inline uint16_t loadUint16(uint8_t *ptr) {
return ptr[0] << 8 | ptr[1];
}
inline void storeUint32(uint8_t *buf, uint32_t val) {
buf[0] = val >> 24;
buf[1] = (val >> 16) & 0xff;
buf[2] = (val >> 8) & 0xff;
buf[3] = val & 0xff;
}
inline void storeUint16(uint8_t *buf, uint16_t val) {
buf[0] = (val >> 8) & 0xff;
buf[1] = val & 0xff;
}
inline void storeUint32LE(uint8_t *buf, uint32_t val) {
buf[0] = val & 0xff;
buf[1] = (val >> 8) & 0xff;
buf[2] = (val >> 16) & 0xff;
buf[3] = (val >>24) & 0xff;
}
inline void storeUint16LE(uint8_t *buf, uint16_t val) {
buf[0] = val & 0xff;
buf[1] = (val>>8) & 0xff;
}
inline uint32_t srtVersion(int major, int minor, int patch)
{
return patch + minor*0x100 + major*0x10000;
}
} // namespace SRT
#endif //ZLMEDIAKIT_SRT_COMMON_H

127
srt/HSExt.cpp Normal file
View File

@ -0,0 +1,127 @@
#include "HSExt.hpp"
namespace SRT {
bool HSExtMessage::loadFromData(uint8_t *buf, size_t len) {
if(buf == NULL || len != HSEXT_MSG_SIZE){
return false;
}
_data = BufferRaw::create();
_data->assign((char*)buf,len);
extension_length = 3;
HSExt::loadHeader();
assert(extension_type == SRT_CMD_HSREQ || extension_type == SRT_CMD_HSRSP);
uint8_t* ptr = (uint8_t*)_data->data()+4;
srt_version = loadUint32(ptr);
ptr += 4;
srt_flag = loadUint32(ptr);
ptr += 4;
recv_tsbpd_delay = loadUint16(ptr);
ptr += 2;
send_tsbpd_delay = loadUint16(ptr);
ptr += 2;
return true;
}
std::string HSExtMessage::dump(){
_StrPrinter printer;
printer << "srt version : "<<std::hex<<srt_version<<" srt flag : "<<std::hex<<srt_flag<<\
" recv_tsbpd_delay="<<recv_tsbpd_delay<<" send_tsbpd_delay = "<<send_tsbpd_delay;
return std::move(printer);
}
bool HSExtMessage::storeToData() {
_data = BufferRaw::create();
_data->setCapacity(HSEXT_MSG_SIZE);
_data->setSize(HSEXT_MSG_SIZE);
extension_length = 3;
HSExt::storeHeader();
uint8_t* ptr = (uint8_t*)_data->data()+4;
storeUint32(ptr,srt_version);
ptr += 4;
storeUint32(ptr,srt_flag);
ptr += 4;
storeUint16(ptr,recv_tsbpd_delay);
ptr += 2;
storeUint16(ptr,send_tsbpd_delay);
ptr += 2;
return true;
}
bool HSExtStreamID::loadFromData(uint8_t *buf, size_t len) {
if(buf == NULL || len < 4){
return false;
}
_data = BufferRaw::create();
_data->assign((char*)buf,len);
HSExt::loadHeader();
size_t content_size = extension_length*4;
if(len < content_size+4){
return false;
}
streamid.clear();
char* ptr = _data->data()+4;
for(size_t i = 0; i<extension_length; ++i){
streamid.push_back(*(ptr+3));
streamid.push_back(*(ptr+2));
streamid.push_back(*(ptr+1));
streamid.push_back(*(ptr));
ptr+=4;
}
return true;
}
bool HSExtStreamID::storeToData() {
size_t content_size = ((streamid.length()+4)+3)/4*4;
_data = BufferRaw::create();
_data->setCapacity(content_size);
_data->setSize(content_size);
extension_length = (content_size-4)/4;
extension_type = SRT_CMD_SID;
HSExt::storeHeader();
auto ptr = _data->data()+4;
memset(ptr,0,extension_length*4);
const char* src = streamid.c_str();
for(size_t i = 0; i< streamid.length()/4;++i){
*ptr = *(src+3+i*4);
ptr++;
*ptr = *(src+2+i*4);
ptr++;
*ptr = *(src+1+i*4);
ptr++;
*ptr = *(src+0+i*4);
ptr++;
}
ptr += 3;
size_t offset = streamid.length()/4*4;
for(size_t i = 0; i<streamid.length()%4;++i){
*ptr = *(src+offset+i);
ptr -= 1;
}
return true;
}
std::string HSExtStreamID::dump(){
_StrPrinter printer;
printer << " streamid : "<< streamid;
return std::move(printer);
}
} // namespace SRT

128
srt/HSExt.hpp Normal file
View File

@ -0,0 +1,128 @@
#ifndef ZLMEDIAKIT_SRT_HS_EXT_H
#define ZLMEDIAKIT_SRT_HS_EXT_H
#include "Network/Buffer.h"
#include "Common.hpp"
namespace SRT {
using namespace toolkit;
class HSExt : public Buffer {
public:
HSExt() = default;
virtual ~HSExt() = default;
enum {
SRT_CMD_REJECT = 0,
SRT_CMD_HSREQ = 1,
SRT_CMD_HSRSP = 2,
SRT_CMD_KMREQ = 3,
SRT_CMD_KMRSP = 4,
SRT_CMD_SID = 5,
SRT_CMD_CONGESTION = 6,
SRT_CMD_FILTER = 7,
SRT_CMD_GROUP = 8,
SRT_CMD_NONE = -1
};
using Ptr = std::shared_ptr<HSExt>;
uint16_t extension_type;
uint16_t extension_length;
virtual bool loadFromData(uint8_t *buf, size_t len) = 0;
virtual bool storeToData() = 0;
virtual std::string dump() = 0;
///////Buffer override///////
char *data() const override {
if (_data) {
return _data->data();
}
return nullptr;
};
size_t size() const override {
if (_data) {
return _data->size();
}
return 0;
};
protected:
void loadHeader() {
uint8_t *ptr = (uint8_t *)_data->data();
extension_type = loadUint16(ptr);
ptr += 2;
extension_length = loadUint16(ptr);
ptr += 2;
}
void storeHeader() {
uint8_t *ptr = (uint8_t *)_data->data();
SRT::storeUint16(ptr, extension_type);
ptr += 2;
storeUint16(ptr, extension_length);
}
protected:
BufferRaw::Ptr _data;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SRT Version |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SRT Flags |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Receiver TSBPD Delay | Sender TSBPD Delay |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 6: Handshake Extension Message structure
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake-extension-message
*/
class HSExtMessage : public HSExt {
public:
using Ptr = std::shared_ptr<HSExtMessage>;
enum {
HS_EXT_MSG_TSBPDSND = 0x00000001,
HS_EXT_MSG_TSBPDRCV = 0x00000002,
HS_EXT_MSG_CRYPT = 0x00000004,
HS_EXT_MSG_TLPKTDROP = 0x00000008,
HS_EXT_MSG_PERIODICNAK = 0x00000010,
HS_EXT_MSG_REXMITFLG = 0x00000020,
HS_EXT_MSG_STREAM = 0x00000040,
HS_EXT_MSG_PACKET_FILTER = 0x00000080
};
enum { HSEXT_MSG_SIZE = 16 };
HSExtMessage() = default;
~HSExtMessage() = default;
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::string dump() override;
uint32_t srt_version;
uint32_t srt_flag;
uint16_t recv_tsbpd_delay;
uint16_t send_tsbpd_delay;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Stream ID |
...
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 7: Stream ID Extension Message
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-stream-id-extension-message
*/
class HSExtStreamID : public HSExt {
public:
using Ptr = std::shared_ptr<HSExtStreamID>;
HSExtStreamID() = default;
~HSExtStreamID() = default;
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::string dump() override;
std::string streamid;
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_HS_EXT_H

574
srt/Packet.cpp Normal file
View File

@ -0,0 +1,574 @@

#include "sys/socket.h"
#include "netdb.h"
#include <atomic>
#include "Util/logger.h"
#include "Util/MD5.h"
#include "Packet.hpp"
namespace SRT {
const size_t DataPacket::HEADER_SIZE;
const size_t ControlPacket::HEADER_SIZE;
const size_t HandshakePacket::HS_CONTENT_MIN_SIZE;
bool DataPacket::isDataPacket(uint8_t *buf, size_t len) {
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
if (!(buf[0] & 0x80)) {
return true;
}
return false;
}
uint32_t DataPacket::getSocketID(uint8_t *buf, size_t len){
uint8_t *ptr = buf;
ptr += 12;
return loadUint32(ptr);
}
bool DataPacket::loadFromData(uint8_t *buf, size_t len) {
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
uint8_t *ptr = buf;
f = ptr[0] >> 7;
packet_seq_number = loadUint32(ptr)&0x7fffffff;
ptr += 4;
PP = ptr[0] >> 6;
O = (ptr[0] & 0x20) >> 5;
KK = (ptr[0] & 0x18) >> 3;
R = (ptr[0] & 0x04) >> 2;
msg_number = (ptr[0] & 0x03) << 24 | ptr[1] << 12 | ptr[2] << 8 | ptr[3];
ptr += 4;
timestamp = loadUint32(ptr);
ptr += 4;
dst_socket_id = loadUint32(ptr);
ptr += 4;
_data = BufferRaw::create();
_data->assign((char *)(buf), len);
return true;
}
bool DataPacket::storeToData(uint8_t *buf, size_t len) {
_data = BufferRaw::create();
_data->setCapacity(len + HEADER_SIZE);
_data->setSize(len + HEADER_SIZE);
uint8_t *ptr = (uint8_t *)_data->data();
ptr[0] = packet_seq_number >> 24;
ptr[1] = (packet_seq_number >> 16) & 0xff;
ptr[2] = (packet_seq_number >> 8) & 0xff;
ptr[3] = packet_seq_number & 0xff;
ptr += 4;
ptr[0] = PP << 6;
ptr[0] |= O << 5;
ptr[0] |= KK << 3;
ptr[0] |= R << 2;
ptr[0] |= (msg_number & 0xff000000) >> 24;
ptr[1] = (msg_number & 0xff0000) >> 16;
ptr[2] = (msg_number & 0xff00) >> 8;
ptr[3] = msg_number & 0xff;
ptr += 4;
storeUint32(ptr, timestamp);
ptr += 4;
storeUint32(ptr, dst_socket_id);
ptr += 4;
memcpy(ptr, buf, len);
return true;
}
char *DataPacket::data() const {
if (!_data)
return nullptr;
return _data->data();
}
size_t DataPacket::size() const {
if (!_data) {
return 0;
}
return _data->size();
}
char *DataPacket::payloadData() {
if (!_data)
return nullptr;
return _data->data() + HEADER_SIZE;
}
size_t DataPacket::payloadSize() {
if (!_data) {
return 0;
}
return _data->size() - HEADER_SIZE;
}
bool ControlPacket::isControlPacket(uint8_t *buf, size_t len) {
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
if (buf[0] & 0x80) {
return true;
}
return false;
}
uint16_t ControlPacket::getControlType(uint8_t *buf, size_t len) {
uint8_t *ptr = buf;
uint16_t control_type = (ptr[0] & 0x7f) << 8 | ptr[1];
return control_type;
}
bool ControlPacket::loadHeader() {
uint8_t *ptr = (uint8_t *)_data->data();
f = ptr[0] >> 7;
control_type = (ptr[0] & 0x7f) << 8 | ptr[1];
ptr += 2;
sub_type = loadUint16(ptr);
ptr += 2;
type_specific_info[0] = ptr[0];
type_specific_info[1] = ptr[1];
type_specific_info[2] = ptr[2];
type_specific_info[3] = ptr[3];
ptr += 4;
timestamp = loadUint32(ptr);
ptr += 4;
dst_socket_id = loadUint32(ptr);
ptr += 4;
return true;
}
bool ControlPacket::storeToHeader() {
uint8_t *ptr = (uint8_t *)_data->data();
ptr[0] = 0x80;
ptr[0] |= control_type >> 8;
ptr[1] = control_type & 0xff;
ptr += 2;
storeUint16(ptr, sub_type);
ptr += 2;
ptr[0] = type_specific_info[0];
ptr[1] = type_specific_info[1];
ptr[2] = type_specific_info[2];
ptr[3] = type_specific_info[3];
ptr += 4;
storeUint32(ptr, timestamp);
ptr += 4;
storeUint32(ptr, dst_socket_id);
ptr += 4;
return true;
}
char *ControlPacket::data() const {
if (!_data)
return nullptr;
return _data->data();
}
size_t ControlPacket::size() const {
if (!_data) {
return 0;
}
return _data->size();
}
uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len){
return loadUint32(buf+12);
}
bool HandshakePacket::loadFromData(uint8_t *buf, size_t len) {
if(HEADER_SIZE+HS_CONTENT_MIN_SIZE > len){
ErrorL << "size too smalle " << encryption_field;
return false;
}
_data = BufferRaw::create();
_data->assign((char *)(buf), len);
ControlPacket::loadHeader();
uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE;
// parse CIF
version = loadUint32(ptr);
ptr += 4;
encryption_field = loadUint16(ptr);
ptr += 2;
extension_field = loadUint16(ptr);
ptr += 2;
initial_packet_sequence_number = loadUint32(ptr);
ptr += 4;
mtu = loadUint32(ptr);
ptr += 4;
max_flow_window_size = loadUint32(ptr);
ptr += 4;
handshake_type = loadUint32(ptr);
ptr += 4;
srt_socket_id = loadUint32(ptr);
ptr += 4;
syn_cookie = loadUint32(ptr);
ptr += 4;
memcpy(peer_ip_addr, ptr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]));
ptr += sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]);
if (encryption_field != NO_ENCRYPTION) {
ErrorL << "not support encryption " << encryption_field;
}
if(extension_field == 0){
return true;
}
if(len == HEADER_SIZE+HS_CONTENT_MIN_SIZE){
//ErrorL << "extension filed not exist " << extension_field;
return true;
}
return loadExtMessage(ptr,len-HS_CONTENT_MIN_SIZE-HEADER_SIZE);
}
bool HandshakePacket::loadExtMessage(uint8_t *buf,size_t len){
uint8_t* ptr = buf;
ext_list.clear();
uint16_t type;
uint16_t length;
HSExt::Ptr ext;
while(ptr<buf+len){
type = loadUint16(ptr);
length = loadUint16(ptr+2);
switch (type)
{
case HSExt::SRT_CMD_HSREQ:
case HSExt::SRT_CMD_HSRSP:
ext = std::make_shared<HSExtMessage>();
break;
case HSExt::SRT_CMD_SID:
ext = std::make_shared<HSExtStreamID>();
break;
default:
WarnL<<"not support ext "<<type;
break;
}
if(ext){
if(ext->loadFromData(ptr,length*4+4)){
ext_list.push_back(std::move(ext));
}else{
WarnL<<"parse HS EXT failed type="<<type<<" len="<<length;
}
ext = nullptr;
}
ptr += length*4+4;
}
return true;
}
bool HandshakePacket::storeExtMessage()
{
uint8_t* buf = (uint8_t*)_data->data()+HEADER_SIZE+48;
size_t len = _data->size()- HEADER_SIZE-48;
for(auto ex : ext_list){
memcpy(buf,ex->data(),ex->size());
buf += ex->size();
}
return true;
}
size_t HandshakePacket::getExtSize(){
size_t size = 0;
for(auto it : ext_list){
size += it->size();
}
return size;
}
bool HandshakePacket::storeToData() {
_data = BufferRaw::create();
for(auto ex : ext_list){
ex->storeToData();
}
auto ext_size = getExtSize();
_data->setCapacity(HEADER_SIZE + 48+ext_size);
_data->setSize(HEADER_SIZE + 48+ext_size);
control_type = ControlPacket::HANDSHAKE;
sub_type = 0;
ControlPacket::storeToHeader();
uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE;
storeUint32(ptr, version);
ptr += 4;
storeUint16(ptr, encryption_field);
ptr += 2;
storeUint16(ptr, extension_field);
ptr += 2;
storeUint32(ptr, initial_packet_sequence_number);
ptr += 4;
storeUint32(ptr, mtu);
ptr += 4;
storeUint32(ptr, max_flow_window_size);
ptr += 4;
storeUint32(ptr, handshake_type);
ptr += 4;
storeUint32(ptr, srt_socket_id);
ptr += 4;
storeUint32(ptr, syn_cookie);
ptr += 4;
memcpy(ptr, peer_ip_addr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]));
ptr += sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]);
if (encryption_field != NO_ENCRYPTION) {
ErrorL << "not support encryption " << encryption_field;
}
assert(encryption_field == NO_ENCRYPTION);
return storeExtMessage();
}
bool HandshakePacket::isHandshakePacket(uint8_t *buf, size_t len){
if(!ControlPacket::isControlPacket(buf,len)){
return false;
}
if(len < HEADER_SIZE+48){
return false;
}
return ControlPacket::getControlType(buf,len) == HANDSHAKE;
}
uint32_t HandshakePacket::getHandshakeType(uint8_t *buf, size_t len){
uint8_t *ptr = buf+HEADER_SIZE+5*4;
return loadUint32(ptr);
}
uint32_t HandshakePacket::getSynCookie(uint8_t *buf, size_t len){
uint8_t *ptr = buf+HEADER_SIZE+7*4;
return loadUint32(ptr);
}
void HandshakePacket::assignPeerIP(struct sockaddr_storage* addr){
memset(peer_ip_addr,0,sizeof(peer_ip_addr)*sizeof(peer_ip_addr[0]));
if(addr->ss_family == AF_INET){
struct sockaddr_in * ipv4 = (struct sockaddr_in *)addr;
//抓包 奇怪好像是小头端???
storeUint32LE(peer_ip_addr,ipv4->sin_addr.s_addr);
}else{
const sockaddr_in6* ipv6 = (struct sockaddr_in6 *)addr;
memcpy(peer_ip_addr,ipv6->sin6_addr.s6_addr,sizeof(peer_ip_addr)*sizeof(peer_ip_addr[0]));
}
}
uint32_t HandshakePacket::generateSynCookie(struct sockaddr_storage* addr,TimePoint ts,uint32_t current_cookie, int correction ){
static std::atomic<uint32_t> distractor{0};
uint32_t rollover = distractor.load() + 10;
for (;;)
{
// SYN cookie
char clienthost[NI_MAXHOST];
char clientport[NI_MAXSERV];
getnameinfo((struct sockaddr*)addr,
sizeof(struct sockaddr_storage),
clienthost,
sizeof(clienthost),
clientport,
sizeof(clientport),
NI_NUMERICHOST | NI_NUMERICSERV);
int64_t timestamp = (DurationCountMicroseconds(SteadyClock::now() - ts) / 60000000) + distractor.load() +
correction; // secret changes every one minute
std::stringstream cookiestr;
cookiestr << clienthost << ":" << clientport << ":" << timestamp;
union {
unsigned char cookie[16];
uint32_t cookie_val;
};
MD5 md5(cookiestr.str());
memcpy(cookie,md5.rawdigest().c_str(),16);
if (cookie_val != current_cookie)
return cookie_val;
++distractor;
// This is just to make the loop formally breakable,
// but this is virtually impossible to happen.
if (distractor == rollover)
return cookie_val;
}
}
bool KeepLivePacket::loadFromData(uint8_t *buf, size_t len){
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
_data = BufferRaw::create();
_data->assign((char*)buf,len);
return loadHeader();
}
bool KeepLivePacket::storeToData(){
control_type = ControlPacket::KEEPALIVE;
sub_type = 0;
_data = BufferRaw::create();
_data->setCapacity(HEADER_SIZE);
_data->setSize(HEADER_SIZE);
return storeToHeader();
}
bool NAKPacket::loadFromData(uint8_t *buf, size_t len) {
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
_data = BufferRaw::create();
_data->assign((char*)buf,len);
loadHeader();
uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE;
uint8_t* end = (uint8_t*)_data->data()+_data->size();
LostPair lost;
while (ptr<end)
{
if((*ptr)&0x80){
lost.first = loadUint32(ptr)&0x7fffffff;
lost.second = loadUint32(ptr+4)&0x7fffffff;
lost.second += 1;
ptr += 8;
}else{
lost.first = loadUint32(ptr);
lost.second = lost.first +1;
ptr += 4;
}
lost_list.push_back(lost);
}
return true;
}
bool NAKPacket::storeToData() {
control_type = NAK;
sub_type = 0;
size_t cif_size = getCIFSize();
_data = BufferRaw::create();
_data->setCapacity(HEADER_SIZE+cif_size);
_data->setSize(HEADER_SIZE+cif_size);
storeToHeader();
uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE;
for(auto it : lost_list){
if(it.first+1 ==it.second){
storeUint32(ptr,it.first);
ptr[0] = ptr[0]&0x7f;
ptr += 4;
}else{
storeUint32(ptr,it.first);
ptr[0] |= 0x80;
storeUint32(ptr+4,it.second-1);
//ptr[4] = ptr[4]&0x7f;
ptr += 8;
}
}
return true;
}
size_t NAKPacket::getCIFSize(){
size_t size = 0;
for(auto it : lost_list){
if(it.first+1 ==it.second){
size += 4;
}else{
size += 8;
}
}
return size;
}
std::string NAKPacket::dump(){
_StrPrinter printer;
for (auto it : lost_list) {
printer<<"[ "<<it.first<<" , "<<it.second-1<<" ]";
}
return std::move(printer);
}
bool MsgDropReqPacket::loadFromData(uint8_t *buf, size_t len) {
if (len < HEADER_SIZE+8) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
_data = BufferRaw::create();
_data->assign((char*)buf,len);
loadHeader();
uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE;
first_pkt_seq_num = loadUint32(ptr);
ptr += 4;
last_pkt_seq_num = loadUint32(ptr);
ptr += 4;
return true;
}
bool MsgDropReqPacket::storeToData() {
control_type = DROPREQ;
sub_type = 0;
_data = BufferRaw::create();
_data->setCapacity(HEADER_SIZE+8);
_data->setSize(HEADER_SIZE+8);
storeToHeader();
uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE;
storeUint32(ptr,first_pkt_seq_num);
ptr += 4;
storeUint32(ptr,last_pkt_seq_num);
ptr += 4;
return true;
}
} // namespace SRT

317
srt/Packet.hpp Normal file
View File

@ -0,0 +1,317 @@
#ifndef ZLMEDIAKIT_SRT_PACKET_H
#define ZLMEDIAKIT_SRT_PACKET_H
#include <stdint.h>
#include <vector>
#include "Network/Buffer.h"
#include "Common.hpp"
#include "HSExt.hpp"
namespace SRT {
using namespace toolkit;
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|0| Packet Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|P P|O|K K|R| Message Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Data +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 3: Data packet structure
reference https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-packet-structure
*/
class DataPacket : public Buffer {
public:
using Ptr = std::shared_ptr<DataPacket>;
DataPacket() = default;
~DataPacket() = default;
static const size_t HEADER_SIZE = 16;
static bool isDataPacket(uint8_t *buf, size_t len);
static uint32_t getSocketID(uint8_t *buf, size_t len);
bool loadFromData(uint8_t *buf, size_t len);
bool storeToData(uint8_t *buf, size_t len);
///////Buffer override///////
char *data() const override;
size_t size() const override;
char *payloadData();
size_t payloadSize();
uint8_t f;
uint32_t packet_seq_number;
uint8_t PP;
uint8_t O;
uint8_t KK;
uint8_t R;
uint32_t msg_number;
uint32_t timestamp;
uint32_t dst_socket_id;
private:
BufferRaw::Ptr _data;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type | Subtype |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type-specific Information |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- CIF -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Control Information Field +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 4: Control packet structure
reference https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-control-packets
*/
class ControlPacket : public Buffer {
public:
using Ptr = std::shared_ptr<ControlPacket>;
static const size_t HEADER_SIZE = 16;
static bool isControlPacket(uint8_t *buf, size_t len);
static uint16_t getControlType(uint8_t *buf, size_t len);
static uint32_t getSocketID(uint8_t *buf, size_t len);
ControlPacket() = default;
virtual ~ControlPacket() = default;
virtual bool loadFromData(uint8_t *buf, size_t len) = 0;
virtual bool storeToData() = 0;
bool loadHeader();
bool storeToHeader();
///////Buffer override///////
char *data() const override;
size_t size() const override;
enum {
HANDSHAKE = 0x0000,
KEEPALIVE = 0x0001,
ACK = 0x0002,
NAK = 0x0003,
CONGESTIONWARNING = 0x0004,
SHUTDOWN = 0x0005,
ACKACK = 0x0006,
DROPREQ = 0x0007,
PEERERROR = 0x0008,
USERDEFINEDTYPE = 0x7FFF
};
uint32_t sub_type : 16;
uint32_t control_type : 15;
uint32_t f : 1;
uint8_t type_specific_info[4];
uint32_t timestamp;
uint32_t dst_socket_id;
protected:
BufferRaw::Ptr _data;
};
/**
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Encryption Field | Extension Field |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Initial Packet Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Maximum Transmission Unit Size |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Maximum Flow Window Size |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Handshake Type |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SRT Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SYN Cookie |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Peer IP Address +
| |
+ +
| |
+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
| Extension Type | Extension Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Extension Contents +
| |
+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
Figure 5: Handshake packet structure
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake
*/
class HandshakePacket : public ControlPacket {
public:
using Ptr = std::shared_ptr<HandshakePacket>;
enum { NO_ENCRYPTION = 0, AES_128 = 1, AES_196 = 2, AES_256 = 3 };
static const size_t HS_CONTENT_MIN_SIZE = 48;
enum {
HS_TYPE_DONE = 0xFFFFFFFD,
HS_TYPE_AGREEMENT = 0xFFFFFFFE,
HS_TYPE_CONCLUSION = 0xFFFFFFFF,
HS_TYPE_WAVEHAND = 0x00000000,
HS_TYPE_INDUCTION = 0x00000001
};
enum { HS_EXT_FILED_HSREQ = 0x00000001, HS_EXT_FILED_KMREQ = 0x00000002, HS_EXT_FILED_CONFIG = 0x00000004 };
HandshakePacket() = default;
~HandshakePacket() = default;
static bool isHandshakePacket(uint8_t *buf, size_t len);
static uint32_t getHandshakeType(uint8_t *buf, size_t len);
static uint32_t getSynCookie(uint8_t *buf, size_t len);
static uint32_t generateSynCookie(struct sockaddr_storage* addr,TimePoint ts,uint32_t current_cookie = 0, int correction = 0);
void assignPeerIP(struct sockaddr_storage* addr);
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
uint32_t version;
uint16_t encryption_field;
uint16_t extension_field;
uint32_t initial_packet_sequence_number;
uint32_t mtu;
uint32_t max_flow_window_size;
uint32_t handshake_type;
uint32_t srt_socket_id;
uint32_t syn_cookie;
uint8_t peer_ip_addr[16];
std::vector<HSExt::Ptr> ext_list;
private:
bool loadExtMessage(uint8_t *buf,size_t len);
bool storeExtMessage();
size_t getExtSize();
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type-specific Information |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 12: Keep-Alive control packet
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-keep-alive
*/
class KeepLivePacket : public ControlPacket
{
public:
using Ptr = std::shared_ptr<KeepLivePacket>;
KeepLivePacket() = default;
~KeepLivePacket() = default;
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
};
/*
An SRT NAK packet is formatted as follows:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type-specific Information |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+- CIF (Loss List) -+-+-+-+-+-+-+-+-+-+-+-+
|0| Lost packet sequence number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Range of lost packets from sequence number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0| Up to sequence number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0| Lost packet sequence number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 14: NAK control packet
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-nak-control-packet
*/
class NAKPacket : public ControlPacket
{
public:
using Ptr = std::shared_ptr<NAKPacket>;
using LostPair = std::pair<uint32_t,uint32_t>;
NAKPacket() = default;
~NAKPacket() = default;
std::string dump();
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::list<LostPair> lost_list;
private:
size_t getCIFSize();
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type = 7 | Reserved = 0 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Message Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| First Packet Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Last Packet Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 18: Drop Request control packet
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-message-drop-request
*/
class MsgDropReqPacket : public ControlPacket
{
public:
using Ptr = std::shared_ptr<MsgDropReqPacket>;
MsgDropReqPacket() = default;
~MsgDropReqPacket() = default;
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
uint32_t first_pkt_seq_num;
uint32_t last_pkt_seq_num;
};
} // namespace SRT
#endif //ZLMEDIAKIT_SRT_PACKET_H

126
srt/PacketQueue.cpp Normal file
View File

@ -0,0 +1,126 @@
#include "PacketQueue.hpp"
namespace SRT {
PacketQueue::PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t lantency)
: _pkt_expected_seq(init_seq)
, _pkt_cap(max_size)
, _pkt_lantency(lantency) {
}
bool PacketQueue::inputPacket(DataPacket::Ptr pkt) {
if (pkt->packet_seq_number < _pkt_expected_seq) {
// TOO later drop this packet
return false;
}
_pkt_map[pkt->packet_seq_number] = pkt;
return true;
}
std::list<DataPacket::Ptr> PacketQueue::tryGetPacket() {
std::list<DataPacket::Ptr> re;
while (_pkt_map.find(_pkt_expected_seq) != _pkt_map.end()) {
re.push_back(_pkt_map[_pkt_expected_seq]);
_pkt_map.erase(_pkt_expected_seq);
_pkt_expected_seq++;
}
while (_pkt_map.size() > _pkt_cap) {
// force pop some packet
auto it = _pkt_map.begin();
re.push_back(it->second);
_pkt_expected_seq = it->second->packet_seq_number + 1;
_pkt_map.erase(it);
}
while (timeLantency() > _pkt_lantency) {
auto it = _pkt_map.begin();
re.push_back(it->second);
_pkt_expected_seq = it->second->packet_seq_number + 1;
_pkt_map.erase(it);
}
return std::move(re);
}
bool PacketQueue::dropForRecv(uint32_t first,uint32_t last){
if(first >= last){
return false;
}
if(_pkt_expected_seq <= last){
_pkt_expected_seq = last+1;
return true;
}
return false;
}
uint32_t PacketQueue::timeLantency() {
if (_pkt_map.empty()) {
return 0;
}
auto first = _pkt_map.begin()->second;
auto last = _pkt_map.rbegin()->second;
return last->timestamp - first->timestamp;
}
std::list<PacketQueue::LostPair> PacketQueue::getLostSeq() {
std::list<PacketQueue::LostPair> re;
if(_pkt_map.empty()){
return re;
}
if(getExpectedSize() == getSize()){
return re;
}
PacketQueue::LostPair lost;
lost.first = 0;
lost.second = 0;
uint32_t i = _pkt_expected_seq;
bool finish = true;
for(i = _pkt_expected_seq;i<=_pkt_map.rbegin()->first;++i){
if(_pkt_map.find(i) == _pkt_map.end()){
if(finish){
finish = false;
lost.first = i;
lost.second = i+1;
}else{
lost.second = i+1;
}
}else{
if(!finish){
finish = true;
re.push_back(lost);
}
}
}
return re;
}
size_t PacketQueue::getSize(){
return _pkt_map.size();
}
size_t PacketQueue::getExpectedSize() {
if(_pkt_map.empty()){
return 0;
}
return _pkt_map.rbegin()->first - _pkt_expected_seq+1;
}
size_t PacketQueue::getAvailableBufferSize(){
return _pkt_cap - getExpectedSize();
}
uint32_t PacketQueue::getExpectedSeq(){
return _pkt_expected_seq;
}
} // namespace SRT

44
srt/PacketQueue.hpp Normal file
View File

@ -0,0 +1,44 @@
#ifndef ZLMEDIAKIT_SRT_PACKET_QUEUE_H
#define ZLMEDIAKIT_SRT_PACKET_QUEUE_H
#include <memory>
#include <map>
#include <list>
#include <utility>
#include <tuple>
#include "Packet.hpp"
namespace SRT{
class PacketQueue
{
public:
using Ptr = std::shared_ptr<PacketQueue>;
using LostPair = std::pair<uint32_t,uint32_t>;
PacketQueue(uint32_t max_size,uint32_t init_seq,uint32_t lantency);
~PacketQueue() = default;
bool inputPacket(DataPacket::Ptr pkt);
std::list<DataPacket::Ptr> tryGetPacket();
uint32_t timeLantency();
std::list<LostPair> getLostSeq();
size_t getSize();
size_t getExpectedSize();
size_t getAvailableBufferSize();
uint32_t getExpectedSeq();
bool dropForRecv(uint32_t first,uint32_t last);
private:
std::map<uint32_t,DataPacket::Ptr> _pkt_map;
uint32_t _pkt_expected_seq = 0;
uint32_t _pkt_cap;
uint32_t _pkt_lantency;
};
}
#endif //ZLMEDIAKIT_SRT_PACKET_QUEUE_H

137
srt/SrtSession.cpp Normal file
View File

@ -0,0 +1,137 @@
#include "SrtSession.hpp"
#include "Packet.hpp"
#include "SrtTransport.hpp"
#include "Common/config.h"
namespace SRT {
using namespace mediakit;
SrtSession::SrtSession(const Socket::Ptr &sock)
: UdpSession(sock) {
socklen_t addr_len = sizeof(_peer_addr);
getpeername(sock->rawFD(), (struct sockaddr *)&_peer_addr, &addr_len);
}
SrtSession::~SrtSession() {
InfoP(this);
}
EventPoller::Ptr SrtSession::queryPoller(const Buffer::Ptr &buffer) {
uint8_t* data = (uint8_t*)buffer->data();
size_t size = buffer->size();
if(DataPacket::isDataPacket(data,size)){
uint32_t socket_id = DataPacket::getSocketID(data,size);
auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id));
return trans ? trans->getPoller() : nullptr;
}
if(HandshakePacket::isHandshakePacket(data,size)){
auto type = HandshakePacket::getHandshakeType(data,size);
if(type == HandshakePacket::HS_TYPE_INDUCTION){
// 握手第一阶段
return nullptr;
}else if(type == HandshakePacket::HS_TYPE_CONCLUSION){
// 握手第二阶段
uint32_t sync_cookie = HandshakePacket::getSynCookie(data,size);
auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie));
return trans ? trans->getPoller() : nullptr;
}else{
WarnL<<" not reach there";
}
}else{
uint32_t socket_id = ControlPacket::getSocketID(data,size);
auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id));
return trans ? trans->getPoller() : nullptr;
}
return nullptr;
}
void SrtSession::onRecv(const Buffer::Ptr &buffer) {
uint8_t* data = (uint8_t*)buffer->data();
size_t size = buffer->size();
if (_find_transport) {
//只允许寻找一次transport
_find_transport = false;
if (DataPacket::isDataPacket(data, size)) {
uint32_t socket_id = DataPacket::getSocketID(data, size);
auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id));
if(trans){
_transport = std::move(trans);
}else{
WarnL<<" data packet not find transport ";
}
}
if (HandshakePacket::isHandshakePacket(data, size)) {
auto type = HandshakePacket::getHandshakeType(data, size);
if (type == HandshakePacket::HS_TYPE_INDUCTION) {
// 握手第一阶段
_transport = std::make_shared<SrtTransport>(getPoller());
} else if (type == HandshakePacket::HS_TYPE_CONCLUSION) {
// 握手第二阶段
uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size);
auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie));
if (trans) {
_transport = std::move(trans);
} else {
WarnL << " hanshake packet not find transport ";
}
} else {
WarnL << " not reach there";
}
} else {
uint32_t socket_id = ControlPacket::getSocketID(data, size);
auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id));
if(trans){
_transport = std::move(trans);
}else{
WarnL << " not find transport";
}
}
if(_transport){
_transport->setSession(shared_from_this());
}
InfoP(this);
}
_ticker.resetTime();
if(_transport){
_transport->inputSockData(data,size,&_peer_addr);
}else{
WarnL<< "ingore data";
}
}
void SrtSession::onError(const SockException &err) {
// udp链接超时但是srt链接不一定超时因为可能存在udp链接迁移的情况
//在udp链接迁移时新的SrtSession对象将接管SrtSession对象的生命周期
//本SrtSession对象将在超时后自动销毁
WarnP(this) << err.what();
if (!_transport) {
return;
}
// 防止互相引用导致不释放
auto transport = std::move(_transport);
getPoller()->async([transport,err] {
//延时减引用防止使用transport对象时销毁对象
transport->onShutdown(err);
}, false);
}
void SrtSession::onManager() {
GET_CONFIG(float, timeoutSec, kTimeOutSec);
if (_ticker.elapsedTime() > timeoutSec*1000) {
shutdown(SockException(Err_timeout, "srt connection timeout"));
return;
}
}
} // namespace SRT

31
srt/SrtSession.hpp Normal file
View File

@ -0,0 +1,31 @@
#ifndef ZLMEDIAKIT_SRT_SESSION_H
#define ZLMEDIAKIT_SRT_SESSION_H
#include "Network/Session.h"
#include "SrtTransport.hpp"
namespace SRT {
using namespace toolkit;
class SrtSession : public UdpSession {
public:
SrtSession(const Socket::Ptr &sock);
~SrtSession() override;
void onRecv(const Buffer::Ptr &) override;
void onError(const SockException &err) override;
void onManager() override;
static EventPoller::Ptr queryPoller(const Buffer::Ptr &buffer);
private:
bool _find_transport = true;
Ticker _ticker;
struct sockaddr_storage _peer_addr;
SrtTransport::Ptr _transport;
};
} // namespace SRT
#endif //ZLMEDIAKIT_SRT_SESSION_H

465
srt/SrtTransport.cpp Normal file
View File

@ -0,0 +1,465 @@
#include "Util/onceToken.h"
#include "SrtTransport.hpp"
#include "Packet.hpp"
#include "Ack.hpp"
namespace SRT {
#define SRT_FIELD "srt."
//srt 超时时间
const std::string kTimeOutSec = SRT_FIELD"timeoutSec";
//srt 单端口udp服务器
const std::string kPort = SRT_FIELD"port";
static std::atomic<uint32_t> s_srt_socket_id_generate{125};
//////////// SrtTransport //////////////////////////
SrtTransport::SrtTransport(const EventPoller::Ptr &poller)
: _poller(poller) {
_start_timestamp = SteadyClock::now();
_socket_id = s_srt_socket_id_generate.fetch_add(1);
}
SrtTransport::~SrtTransport(){
TraceL<<" ";
}
const EventPoller::Ptr &SrtTransport::getPoller() const {
return _poller;
}
void SrtTransport::setSession(Session::Ptr session) {
_history_sessions.emplace(session.get(), session);
if (_selected_session) {
InfoL << "srt network changed: " << _selected_session->get_peer_ip() << ":"
<< _selected_session->get_peer_port() << " -> " << session->get_peer_ip() << ":"
<< session->get_peer_port() << ", id:" << _selected_session->getIdentifier();
}
_selected_session = session;
}
const Session::Ptr &SrtTransport::getSession() const {
return _selected_session;
}
void SrtTransport::switchToOtherTransport(uint8_t *buf, int len,uint32_t socketid, struct sockaddr_storage *addr){
BufferRaw::Ptr tmp = BufferRaw::create();
struct sockaddr_storage tmp_addr = *addr;
tmp->assign((char*)buf,len);
auto trans = SrtTransportManager::Instance().getItem(std::to_string(socketid));
if(trans){
trans->getPoller()->async([tmp,tmp_addr,trans]{
trans->inputSockData((uint8_t*)tmp->data(),tmp->size(),(struct sockaddr_storage*)&tmp_addr);
});
}
}
void SrtTransport::inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr) {
using srt_control_handler = void (SrtTransport::*)(uint8_t* buf,int len,struct sockaddr_storage *addr);
static std::unordered_map<uint16_t, srt_control_handler> s_control_functions;
static onceToken token([]() {
s_control_functions.emplace(ControlPacket::HANDSHAKE, &SrtTransport::handleHandshake);
s_control_functions.emplace(ControlPacket::KEEPALIVE, &SrtTransport::handleKeeplive);
s_control_functions.emplace(ControlPacket::ACK, &SrtTransport::handleACK);
s_control_functions.emplace(ControlPacket::NAK, &SrtTransport::handleNAK);
s_control_functions.emplace(ControlPacket::CONGESTIONWARNING, &SrtTransport::handleCongestionWarning);
s_control_functions.emplace(ControlPacket::SHUTDOWN, &SrtTransport::handleShutDown);
s_control_functions.emplace(ControlPacket::ACKACK, &SrtTransport::handleACKACK);
s_control_functions.emplace(ControlPacket::DROPREQ, &SrtTransport::handleDropReq);
s_control_functions.emplace(ControlPacket::PEERERROR, &SrtTransport::handlePeerError);
s_control_functions.emplace(ControlPacket::USERDEFINEDTYPE, &SrtTransport::handleUserDefinedType);
});
auto now = SteadyClock::now();
// 处理srt数据
if (DataPacket::isDataPacket(buf, len)) {
uint32_t socketId = DataPacket::getSocketID(buf,len);
if(socketId == _socket_id){
_pkt_recv_rate_context.inputPacket(now);
_estimated_link_capacity_context.inputPacket(now);
_recv_rate_context.inputPacket(now, len);
handleDataPacket(buf, len, addr);
}else{
switchToOtherTransport(buf,len,socketId,addr);
}
} else {
if (ControlPacket::isControlPacket(buf, len)) {
uint32_t socketId = ControlPacket::getSocketID(buf,len);
uint16_t type = ControlPacket::getControlType(buf,len);
if(type != ControlPacket::HANDSHAKE && socketId != _socket_id && _socket_id != 0){
// socket id not same
switchToOtherTransport(buf,len,socketId,addr);
return;
}
_pkt_recv_rate_context.inputPacket(now);
_estimated_link_capacity_context.inputPacket(now);
_recv_rate_context.inputPacket(now, len);
auto it = s_control_functions.find(type);
if (it == s_control_functions.end()) {
WarnL<<" not support type ignore" << ControlPacket::getControlType(buf,len);
return;
}else{
(this->*(it->second))(buf,len,addr);
}
} else {
// not reach
WarnL << "not reach this";
}
}
}
void SrtTransport::handleHandshakeInduction(HandshakePacket &pkt, struct sockaddr_storage *addr) {
// Induction Phase
TraceL << getIdentifier() << " Induction Phase ";
if (_handleshake_res) {
TraceL << getIdentifier() << " Induction handle repeate ";
sendControlPacket(_handleshake_res, true);
return;
}
_init_seq_number = pkt.initial_packet_sequence_number;
_max_window_size = pkt.max_flow_window_size;
_mtu = pkt.mtu;
_peer_socket_id = pkt.srt_socket_id;
HandshakePacket::Ptr res = std::make_shared<HandshakePacket>();
res->dst_socket_id = _peer_socket_id;
res->timestamp = DurationCountMicroseconds(_start_timestamp.time_since_epoch());
res->mtu = _mtu;
res->max_flow_window_size = _max_window_size;
res->initial_packet_sequence_number = _init_seq_number;
res->version = 5;
res->encryption_field = HandshakePacket::NO_ENCRYPTION;
res->extension_field = 0x4A17;
res->handshake_type = HandshakePacket::HS_TYPE_INDUCTION;
res->srt_socket_id = _peer_socket_id;
res->syn_cookie = HandshakePacket::generateSynCookie(addr, _start_timestamp);
_sync_cookie = res->syn_cookie;
memcpy(res->peer_ip_addr, pkt.peer_ip_addr, sizeof(pkt.peer_ip_addr) * sizeof(pkt.peer_ip_addr[0]));
_handleshake_res = res;
res->storeToData();
registerSelfHandshake();
sendControlPacket(res, true);
}
void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockaddr_storage *addr) {
if(!_handleshake_res){
ErrorL<<"must Induction Phase for handleshake ";
return;
}
if (_handleshake_res->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) {
// first
HSExtMessage::Ptr req;
HSExtStreamID::Ptr sid;
for (auto ext : pkt.ext_list) {
//TraceL << getIdentifier() << " ext " << ext->dump();
if (!req) {
req = std::dynamic_pointer_cast<HSExtMessage>(ext);
}
if(!sid){
sid = std::dynamic_pointer_cast<HSExtStreamID>(ext);
}
}
if(sid){
_stream_id = sid->streamid;
}
TraceL << getIdentifier() << " CONCLUSION Phase ";
HandshakePacket::Ptr res = std::make_shared<HandshakePacket>();
res->dst_socket_id = _peer_socket_id;
res->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp);
res->mtu = _mtu;
res->max_flow_window_size = _max_window_size;
res->initial_packet_sequence_number = _init_seq_number;
res->version = 5;
res->encryption_field = HandshakePacket::NO_ENCRYPTION;
res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ;
res->handshake_type = HandshakePacket::HS_TYPE_CONCLUSION;
res->srt_socket_id = _socket_id;
res->syn_cookie = 0;
res->assignPeerIP(addr);
HSExtMessage::Ptr ext = std::make_shared<HSExtMessage>();
ext->extension_type = HSExt::SRT_CMD_HSRSP;
ext->srt_version = srtVersion(1, 5, 0);
ext->srt_flag = req->srt_flag;
ext->recv_tsbpd_delay = ext->send_tsbpd_delay = req->recv_tsbpd_delay;
res->ext_list.push_back(std::move(ext));
res->storeToData();
_handleshake_res = res;
unregisterSelfHandshake();
registerSelf();
sendControlPacket(res, true);
TraceL<<" buf size = "<<res->max_flow_window_size<<" init seq ="<<_init_seq_number<<" lantency="<<req->recv_tsbpd_delay;
_recv_buf = std::make_shared<PacketQueue>(res->max_flow_window_size,_init_seq_number, req->recv_tsbpd_delay*1e6);
onHandShakeFinished(_stream_id);
} else {
TraceL << getIdentifier() << " CONCLUSION handle repeate ";
sendControlPacket(_handleshake_res, true);
}
}
void SrtTransport::handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr){
HandshakePacket pkt;
assert(pkt.loadFromData(buf,len));
if(pkt.handshake_type == HandshakePacket::HS_TYPE_INDUCTION){
handleHandshakeInduction(pkt,addr);
}else if(pkt.handshake_type == HandshakePacket::HS_TYPE_CONCLUSION){
handleHandshakeConclusion(pkt,addr);
}else{
WarnL<<" not support handshake type = "<< pkt.handshake_type;
}
_ack_ticker.resetTime();
_nak_ticker.resetTime();
}
void SrtTransport::handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
}
void SrtTransport::handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
auto now = SteadyClock::now();
ACKPacket ack;
ack.loadFromData(buf,len);
ACKACKPacket::Ptr pkt = std::make_shared<ACKACKPacket>();
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = DurationCountMicroseconds(now -_start_timestamp);
pkt->ack_number = ack.ack_number;
pkt->storeToData();
sendControlPacket(pkt,true);
}
void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
}
void SrtTransport::handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
}
void SrtTransport::handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
onShutdown(SockException(Err_shutdown, "peer close connection"));
}
void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr){
MsgDropReqPacket pkt;
pkt.loadFromData(buf,len);
TraceL<<"drop "<<pkt.first_pkt_seq_num<<" last "<<pkt.last_pkt_seq_num;
_recv_buf->dropForRecv(pkt.first_pkt_seq_num,pkt.last_pkt_seq_num);
}
void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
}
void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr){
//TraceL;
auto now = SteadyClock::now();
ACKACKPacket::Ptr pkt = std::make_shared<ACKACKPacket>();
pkt->loadFromData(buf,len);
uint32_t rtt = DurationCountMicroseconds(now - _ack_send_timestamp[pkt->ack_number]);
_rtt_variance = 3*_rtt_variance/4+abs(_rtt - rtt);
_rtt = 7*rtt/8+_rtt/8;
_ack_send_timestamp.erase(pkt->ack_number);
}
void SrtTransport::handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr){
TraceL;
}
void SrtTransport::sendACKPacket() {
ACKPacket::Ptr pkt=std::make_shared<ACKPacket>();
auto now = SteadyClock::now();
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = DurationCountMicroseconds(now - _start_timestamp);
pkt->ack_number = ++_ack_number_count;
pkt->last_ack_pkt_seq_number = _recv_buf->getExpectedSeq();
pkt->rtt = _rtt;
pkt->rtt_variance = _rtt_variance;
pkt->available_buf_size = _recv_buf->getAvailableBufferSize();
pkt->pkt_recv_rate = _pkt_recv_rate_context.getPacketRecvRate();
pkt->estimated_link_capacity = _estimated_link_capacity_context.getEstimatedLinkCapacity();
pkt->recv_rate = _recv_rate_context.getRecvRate();
pkt->storeToData();
_ack_send_timestamp[pkt->ack_number] = now;
sendControlPacket(pkt,true);
}
void SrtTransport::sendLightACKPacket() {
ACKPacket::Ptr pkt=std::make_shared<ACKPacket>();
auto now = SteadyClock::now();
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = DurationCountMicroseconds(now - _start_timestamp);
pkt->ack_number = 0;
pkt->last_ack_pkt_seq_number = _recv_buf->getExpectedSeq();
pkt->rtt = 0;
pkt->rtt_variance = 0;
pkt->available_buf_size = 0;
pkt->pkt_recv_rate = 0;
pkt->estimated_link_capacity = 0;
pkt->recv_rate = 0;
pkt->storeToData();
sendControlPacket(pkt,true);
}
void SrtTransport::sendNAKPacket(std::list<PacketQueue::LostPair>& lost_list){
NAKPacket::Ptr pkt = std::make_shared<NAKPacket>();
auto now = SteadyClock::now();
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = DurationCountMicroseconds(now - _start_timestamp);
pkt->lost_list = lost_list;
pkt->storeToData();
//TraceL<<"send NAK "<<pkt->dump();
sendControlPacket(pkt,true);
}
void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr){
DataPacket::Ptr pkt = std::make_shared<DataPacket>();
pkt->loadFromData(buf,len);
if(_ack_ticker.elapsedTime()>=10){
_light_ack_pkt_count = 0;
_ack_ticker.resetTime();
// send a ack per 10 ms for receiver
sendACKPacket();
}else{
if(_light_ack_pkt_count >= 64){
// for high bitrate stream send light ack
// TODO
sendLightACKPacket();
}
_light_ack_pkt_count = 0;
}
_light_ack_pkt_count++;
//TraceL<<" seq="<< pkt->packet_seq_number<<" ts="<<pkt->timestamp<<" size="<<pkt->payloadSize()<<\
" PP="<<(int)pkt->PP<<" O="<<(int)pkt->O<<" kK="<<(int)pkt->KK<<" R="<<(int)pkt->R;
#if 1
_recv_buf->inputPacket(pkt);
#else
if(pkt->packet_seq_number%100 == 0){
// drop
TraceL<<"drop packet";
TraceL<<"expected size "<<_recv_buf->getExpectedSize()<<" real size="<<_recv_buf->getSize();
}else{
_recv_buf->inputPacket(pkt);
}
#endif
//TraceL<<" data number size "<<list.size();
auto nak_interval = (_rtt+_rtt_variance*4)/2/1000;
if(_nak_ticker.elapsedTime()>20 && _nak_ticker.elapsedTime()>nak_interval){
auto lost = _recv_buf->getLostSeq();
if(!lost.empty()){
sendNAKPacket(lost);
//TraceL<<"send NAK";
}
_nak_ticker.resetTime();
}
auto list = _recv_buf->tryGetPacket();
for(auto data : list){
onSRTData(std::move(data));
}
}
void SrtTransport::sendDataPacket(DataPacket::Ptr pkt,char* buf,int len, bool flush) {
pkt->storeToData((uint8_t*)buf,len);
sendPacket(pkt,flush);
}
void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) {
sendPacket(pkt,flush);
}
void SrtTransport::sendPacket(Buffer::Ptr pkt,bool flush){
if(_selected_session){
auto tmp = _packet_pool.obtain2();
tmp->assign(pkt->data(),pkt->size());
_selected_session->setSendFlushFlag(flush);
_selected_session->send(std::move(tmp));
}else{
WarnL<<"not reach this";
}
}
std::string SrtTransport::getIdentifier(){
return _selected_session ? _selected_session->getIdentifier() : "";
}
void SrtTransport::registerSelfHandshake() {
SrtTransportManager::Instance().addHandshakeItem(std::to_string(_sync_cookie),shared_from_this());
}
void SrtTransport::unregisterSelfHandshake() {
if(_sync_cookie == 0){
return;
}
SrtTransportManager::Instance().removeHandshakeItem(std::to_string(_sync_cookie));
}
void SrtTransport::registerSelf() {
if(_socket_id == 0){
return;
}
SrtTransportManager::Instance().addItem(std::to_string(_socket_id),shared_from_this());
}
void SrtTransport::unregisterSelf() {
SrtTransportManager::Instance().removeItem(std::to_string(_socket_id));
}
void SrtTransport::onShutdown(const SockException &ex){
WarnL << ex.what();
unregisterSelfHandshake();
unregisterSelf();
for (auto &pr : _history_sessions) {
auto session = pr.second.lock();
if (session) {
session->shutdown(ex);
}
}
}
//////////// SrtTransportManager //////////////////////////
SrtTransportManager &SrtTransportManager::Instance() {
static SrtTransportManager s_instance;
return s_instance;
}
void SrtTransportManager::addItem(const std::string &key, const SrtTransport::Ptr &ptr) {
std::lock_guard<std::mutex> lck(_mtx);
_map[key] = ptr;
}
SrtTransport::Ptr SrtTransportManager::getItem(const std::string &key) {
if (key.empty()) {
return nullptr;
}
std::lock_guard<std::mutex> lck(_mtx);
auto it = _map.find(key);
if (it == _map.end()) {
return nullptr;
}
return it->second.lock();
}
void SrtTransportManager::removeItem(const std::string &key) {
std::lock_guard<std::mutex> lck(_mtx);
_map.erase(key);
}
void SrtTransportManager::addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr) {
std::lock_guard<std::mutex> lck(_handshake_mtx);
_handshake_map[key] = ptr;
}
void SrtTransportManager::removeHandshakeItem(const std::string &key) {
std::lock_guard<std::mutex> lck(_handshake_mtx);
_handshake_map.erase(key);
}
SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const std::string &key) {
if (key.empty()) {
return nullptr;
}
std::lock_guard<std::mutex> lck(_handshake_mtx);
auto it = _handshake_map.find(key);
if (it == _handshake_map.end()) {
return nullptr;
}
return it->second.lock();
}
} // namespace SRT

143
srt/SrtTransport.hpp Normal file
View File

@ -0,0 +1,143 @@
#ifndef ZLMEDIAKIT_SRT_TRANSPORT_H
#define ZLMEDIAKIT_SRT_TRANSPORT_H
#include <mutex>
#include <chrono>
#include <memory>
#include <atomic>
#include "Network/Session.h"
#include "Poller/EventPoller.h"
#include "Util/TimeTicker.h"
#include "Common.hpp"
#include "Packet.hpp"
#include "PacketQueue.hpp"
#include "Statistic.hpp"
namespace SRT {
using namespace toolkit;
extern const std::string kPort;
extern const std::string kTimeOutSec;
class SrtTransport : public std::enable_shared_from_this<SrtTransport> {
public:
friend class SrtSession;
using Ptr = std::shared_ptr<SrtTransport>;
SrtTransport(const EventPoller::Ptr &poller);
virtual ~SrtTransport();
const EventPoller::Ptr &getPoller() const;
void setSession(Session::Ptr session);
const Session::Ptr &getSession() const;
/**
* socket收到udp数据
* @param buf
* @param len
* @param addr
*/
void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr);
std::string getIdentifier();
void unregisterSelfHandshake();
void unregisterSelf();
protected:
virtual void onHandShakeFinished(std::string& streamid){};
virtual void onSRTData(DataPacket::Ptr pkt){};
virtual void onShutdown(const SockException &ex);
private:
void registerSelfHandshake();
void registerSelf();
void switchToOtherTransport(uint8_t *buf, int len,uint32_t socketid, struct sockaddr_storage *addr);
void handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleHandshakeInduction(HandshakePacket& pkt,struct sockaddr_storage *addr);
void handleHandshakeConclusion(HandshakePacket& pkt,struct sockaddr_storage *addr);
void handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void sendNAKPacket(std::list<PacketQueue::LostPair>& lost_list);
void sendACKPacket();
void sendLightACKPacket();
protected:
void sendDataPacket(DataPacket::Ptr pkt,char* buf,int len,bool flush = false);
void sendControlPacket(ControlPacket::Ptr pkt,bool flush = true);
void sendPacket(Buffer::Ptr pkt,bool flush = true);
private:
//当前选中的udp链接
Session::Ptr _selected_session;
//链接迁移前后使用过的udp链接
std::unordered_map<Session *, std::weak_ptr<Session> > _history_sessions;
EventPoller::Ptr _poller;
uint32_t _peer_socket_id;
uint32_t _socket_id = 0;
TimePoint _start_timestamp;
uint32_t _mtu = 1500;
uint32_t _max_window_size = 8192;
uint32_t _init_seq_number = 0;
std::string _stream_id;
uint32_t _sync_cookie = 0;
PacketQueue::Ptr _recv_buf;
uint32_t _rtt = 100*1000;
uint32_t _rtt_variance =50*1000;
uint32_t _light_ack_pkt_count = 0;
uint32_t _ack_number_count = 0;
Ticker _ack_ticker;
std::map<uint32_t,TimePoint> _ack_send_timestamp;
PacketRecvRateContext _pkt_recv_rate_context;
EstimatedLinkCapacityContext _estimated_link_capacity_context;
RecvRateContext _recv_rate_context;
Ticker _nak_ticker;
//保持发送的握手消息,防止丢失重发
HandshakePacket::Ptr _handleshake_res;
ResourcePool<BufferRaw> _packet_pool;
};
class SrtTransportManager {
public:
static SrtTransportManager &Instance();
SrtTransport::Ptr getItem(const std::string &key);
void addItem(const std::string &key, const SrtTransport::Ptr &ptr);
void removeItem(const std::string &key);
void addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr);
void removeHandshakeItem(const std::string &key);
SrtTransport::Ptr getHandshakeItem(const std::string &key);
private:
SrtTransportManager() = default;
private:
std::mutex _mtx;
std::unordered_map<std::string, std::weak_ptr<SrtTransport>> _map;
std::mutex _handshake_mtx;
std::unordered_map<std::string, std::weak_ptr<SrtTransport>> _handshake_map;
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_TRANSPORT_H

76
srt/Statistic.cpp Normal file
View File

@ -0,0 +1,76 @@
#include <algorithm>
#include "Statistic.hpp"
namespace SRT {
void PacketRecvRateContext::inputPacket(TimePoint ts) {
if(_pkt_map.size()>100){
_pkt_map.erase(_pkt_map.begin());
}
_pkt_map.emplace(ts,ts);
}
uint32_t PacketRecvRateContext::getPacketRecvRate() {
if(_pkt_map.size()<2){
return 0;
}
auto first = _pkt_map.begin();
auto last = _pkt_map.rbegin();
double dur = DurationCountMicroseconds(last->first - first->first)/1000000.0;
double rate = _pkt_map.size()/dur;
return (uint32_t)rate;
}
void EstimatedLinkCapacityContext::inputPacket(TimePoint ts) {
if(_pkt_map.size()>16){
_pkt_map.erase(_pkt_map.begin());
}
_pkt_map.emplace(ts,ts);
}
uint32_t EstimatedLinkCapacityContext::getEstimatedLinkCapacity() {
decltype(_pkt_map.begin()) next;
std::vector<SteadyClock::duration> tmp;
for(auto it = _pkt_map.begin();it != _pkt_map.end();++it){
next = it;
++next;
if(next != _pkt_map.end()){
tmp.push_back(next->first -it->first);
}else{
break;
}
}
std::sort(tmp.begin(),tmp.end());
if(tmp.empty()){
return 0;
}
double dur =DurationCountMicroseconds(tmp[tmp.size()/2])/1e6;
return (uint32_t)(1.0/dur);
}
void RecvRateContext::inputPacket(TimePoint ts, size_t size ) {
if (_pkt_map.size() > 100) {
_pkt_map.erase(_pkt_map.begin());
}
_pkt_map.emplace(ts, size);
}
uint32_t RecvRateContext::getRecvRate() {
if(_pkt_map.size()<2){
return 0;
}
auto first = _pkt_map.begin();
auto last = _pkt_map.rbegin();
double dur = DurationCountMicroseconds(last->first - first->first)/1000000.0;
size_t bytes = 0;
for(auto it : _pkt_map){
bytes += it.second;
}
double rate = (double)bytes/dur;
return (uint32_t)rate;
}
} // namespace SRT

43
srt/Statistic.hpp Normal file
View File

@ -0,0 +1,43 @@
#ifndef ZLMEDIAKIT_SRT_STATISTIC_H
#define ZLMEDIAKIT_SRT_STATISTIC_H
#include <map>
#include "Common.hpp"
#include "Packet.hpp"
namespace SRT {
class PacketRecvRateContext {
public:
PacketRecvRateContext() = default;
~PacketRecvRateContext() = default;
void inputPacket(TimePoint ts);
uint32_t getPacketRecvRate();
private:
std::map<TimePoint,TimePoint> _pkt_map;
};
class EstimatedLinkCapacityContext {
public:
EstimatedLinkCapacityContext() = default;
~EstimatedLinkCapacityContext() = default;
void inputPacket(TimePoint ts);
uint32_t getEstimatedLinkCapacity();
private:
std::map<TimePoint,TimePoint> _pkt_map;
};
class RecvRateContext {
public:
RecvRateContext() = default;
~RecvRateContext() = default;
void inputPacket(TimePoint ts,size_t size);
uint32_t getRecvRate();
private:
std::map<TimePoint,size_t> _pkt_map;
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_STATISTIC_H