diff --git a/srt/Packet.cpp b/srt/Packet.cpp index af4b3772..3cf40584 100644 --- a/srt/Packet.cpp +++ b/srt/Packet.cpp @@ -61,7 +61,35 @@ bool DataPacket::loadFromData(uint8_t *buf, size_t len) { _data->assign((char *)(buf), len); return true; } +bool DataPacket::storeToHeader(){ + if (!_data || _data->size() < HEADER_SIZE) { + WarnL << "data size less " << HEADER_SIZE; + return false; + } + uint8_t *ptr = (uint8_t *)_data->data(); + ptr[0] = packet_seq_number >> 24; + ptr[1] = (packet_seq_number >> 16) & 0xff; + ptr[2] = (packet_seq_number >> 8) & 0xff; + ptr[3] = packet_seq_number & 0xff; + ptr += 4; + + ptr[0] = PP << 6; + ptr[0] |= O << 5; + ptr[0] |= KK << 3; + ptr[0] |= R << 2; + ptr[0] |= (msg_number & 0xff000000) >> 24; + ptr[1] = (msg_number & 0xff0000) >> 16; + ptr[2] = (msg_number & 0xff00) >> 8; + ptr[3] = msg_number & 0xff; + ptr += 4; + + storeUint32(ptr, timestamp); + ptr += 4; + + storeUint32(ptr, dst_socket_id); + ptr += 4; +} bool DataPacket::storeToData(uint8_t *buf, size_t len) { _data = BufferRaw::create(); _data->setCapacity(len + HEADER_SIZE); diff --git a/srt/Packet.hpp b/srt/Packet.hpp index ddc784fd..b34aae29 100644 --- a/srt/Packet.hpp +++ b/srt/Packet.hpp @@ -43,6 +43,7 @@ public: static uint32_t getSocketID(uint8_t *buf, size_t len); bool loadFromData(uint8_t *buf, size_t len); bool storeToData(uint8_t *buf, size_t len); + bool storeToHeader(); ///////Buffer override/////// char *data() const override; diff --git a/srt/PacketQueue.cpp b/srt/PacketQueue.cpp index ea342277..20181262 100644 --- a/srt/PacketQueue.cpp +++ b/srt/PacketQueue.cpp @@ -51,12 +51,39 @@ bool PacketQueue::dropForRecv(uint32_t first,uint32_t last){ } if(_pkt_expected_seq <= last){ - _pkt_expected_seq = last+1; + for(uint32_t i =first;i<=last;++i){ + if(_pkt_map.find(i) != _pkt_map.end()){ + _pkt_map.erase(i); + } + } + _pkt_expected_seq = last+1; return true; } return false; } + +bool PacketQueue::dropForSend(uint32_t num){ + if(num <= _pkt_expected_seq){ + return false; + } + + for(uint32_t i =_pkt_expected_seq;i< num;++i){ + if(_pkt_map.find(i) != _pkt_map.end()){ + _pkt_map.erase(i); + } + } + _pkt_expected_seq = num; + return true; +} + +DataPacket::Ptr PacketQueue::findPacketBySeq(uint32_t seq){ + auto it = _pkt_map.find(seq); + if(it != _pkt_map.end()){ + return it->second; + } + return nullptr; +} uint32_t PacketQueue::timeLantency() { if (_pkt_map.empty()) { return 0; diff --git a/srt/PacketQueue.hpp b/srt/PacketQueue.hpp index 05996647..2ba8467b 100644 --- a/srt/PacketQueue.hpp +++ b/srt/PacketQueue.hpp @@ -29,6 +29,10 @@ public: uint32_t getExpectedSeq(); bool dropForRecv(uint32_t first,uint32_t last); + + bool dropForSend(uint32_t num); + + DataPacket::Ptr findPacketBySeq(uint32_t seq); private: diff --git a/srt/SrtTransport.cpp b/srt/SrtTransport.cpp index b9fef358..bc185717 100644 --- a/srt/SrtTransport.cpp +++ b/srt/SrtTransport.cpp @@ -192,6 +192,8 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad sendControlPacket(res, true); TraceL<<" buf size = "<max_flow_window_size<<" init seq ="<<_init_seq_number<<" lantency="<recv_tsbpd_delay; _recv_buf = std::make_shared(res->max_flow_window_size,_init_seq_number, req->recv_tsbpd_delay*1e6); + _send_buf = std::make_shared(res->max_flow_window_size,_init_seq_number, req->recv_tsbpd_delay*1e6); + _send_packet_seq_number = _init_seq_number; onHandShakeFinished(_stream_id,addr); } else { TraceL << getIdentifier() << " CONCLUSION handle repeate "; @@ -235,10 +237,33 @@ void SrtTransport::handleACK(uint8_t *buf, int len, struct sockaddr_storage *add pkt->timestamp = DurationCountMicroseconds(_now -_start_timestamp); pkt->ack_number = ack.ack_number; pkt->storeToData(); + _send_buf->dropForSend(ack.last_ack_pkt_seq_number); sendControlPacket(pkt,true); +} +void SrtTransport::sendMsgDropReq(uint32_t first ,uint32_t last){ + } void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr){ TraceL; + NAKPacket pkt; + pkt.loadFromData(buf,len); + bool empty = false; + + for(auto it : pkt.lost_list){ + empty = true; + for(uint32_t i=it.first;ifindPacketBySeq(i); + if(data){ + data->R = 1; + data->storeToHeader(); + sendPacket(data,true); + empty = false; + } + } + if(empty){ + sendMsgDropReq(it.first,it.second-1); + } + } } void SrtTransport::handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr){ TraceL; @@ -391,6 +416,7 @@ void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_stora void SrtTransport::sendDataPacket(DataPacket::Ptr pkt,char* buf,int len, bool flush) { pkt->storeToData((uint8_t*)buf,len); sendPacket(pkt,flush); + _send_buf->inputPacket(pkt); } void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) { sendPacket(pkt,flush); @@ -442,6 +468,48 @@ void SrtTransport::onShutdown(const SockException &ex){ } } } +size_t SrtTransport::getPayloadSize(){ + size_t ret = (_mtu - 28 -16)/188*188; + return ret; +} +void SrtTransport::onSendData(const Buffer::Ptr &buffer, bool flush){ + DataPacket::Ptr pkt; + size_t payloadSize = getPayloadSize(); + size_t size = buffer->size(); + char* ptr = buffer->data(); + char* end = buffer->data()+size; + + while(ptr < end && size >=payloadSize){ + pkt = std::make_shared(); + pkt->f = 0; + pkt->packet_seq_number = _send_packet_seq_number++; + pkt->PP = 3; + pkt->O = 0; + pkt->KK = 0; + pkt->R = 0; + pkt->msg_number = _send_msg_number++; + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); + sendDataPacket(pkt,ptr,(int)payloadSize,flush); + ptr += payloadSize; + size -= payloadSize; + } + + if(size >0 && ptr (); + pkt->f = 0; + pkt->packet_seq_number = _send_packet_seq_number++; + pkt->PP = 3; + pkt->O = 0; + pkt->KK = 0; + pkt->R = 0; + pkt->msg_number = _send_msg_number++; + pkt->dst_socket_id = _peer_socket_id; + pkt->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); + sendDataPacket(pkt,ptr,(int)size,flush); + } + +} //////////// SrtTransportManager ////////////////////////// SrtTransportManager &SrtTransportManager::Instance() { static SrtTransportManager s_instance; diff --git a/srt/SrtTransport.hpp b/srt/SrtTransport.hpp index 02ff395e..892a85c4 100644 --- a/srt/SrtTransport.hpp +++ b/srt/SrtTransport.hpp @@ -37,6 +37,7 @@ public: * @param addr 数据来源地址 */ virtual void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr); + virtual void onSendData(const Buffer::Ptr &buffer, bool flush); std::string getIdentifier(); @@ -73,6 +74,9 @@ private: void sendLightACKPacket(); void sendKeepLivePacket(); void sendShutDown(); + void sendMsgDropReq(uint32_t first ,uint32_t last); + + size_t getPayloadSize(); protected: void sendDataPacket(DataPacket::Ptr pkt,char* buf,int len,bool flush = false); void sendControlPacket(ControlPacket::Ptr pkt,bool flush = true); @@ -97,7 +101,10 @@ private: std::string _stream_id; uint32_t _sync_cookie = 0; + uint32_t _send_packet_seq_number = 0; + uint32_t _send_msg_number = 1; + PacketQueue::Ptr _send_buf; PacketQueue::Ptr _recv_buf; uint32_t _rtt = 100*1000; uint32_t _rtt_variance =50*1000; diff --git a/srt/SrtTransportImp.cpp b/srt/SrtTransportImp.cpp index 9d06659c..e2b1c903 100644 --- a/srt/SrtTransportImp.cpp +++ b/srt/SrtTransportImp.cpp @@ -145,7 +145,72 @@ void SrtTransportImp::emitOnPlay(){ } } void SrtTransportImp::doPlay(){ + //鉴权结果回调 + weak_ptr weak_self = dynamic_pointer_cast(shared_from_this()); + auto onRes = [weak_self](const string &err) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + return; + } + if (!err.empty()) { + //播放鉴权失败 + strong_self->onShutdown(SockException(Err_refused, err)); + return; + } + + //异步查找直播流 + MediaInfo info = strong_self->_media_info; + info._schema = TS_SCHEMA; + MediaSource::findAsync(info, strong_self->getSession(), [weak_self](const MediaSource::Ptr &src) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + return; + } + if (!src) { + //未找到该流 + strong_self->onShutdown(SockException(Err_shutdown)); + } else { + auto ts_src = dynamic_pointer_cast(src); + assert(ts_src); + ts_src->pause(false); + strong_self->_ts_reader = ts_src->getRing()->attach(strong_self->getPoller()); + strong_self->_ts_reader->setDetachCB([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + return; + } + strong_self->onShutdown(SockException(Err_shutdown)); + }); + strong_self->_ts_reader->setReadCB([weak_self](const TSMediaSource::RingDataType &ts_list) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + //本对象已经销毁 + return; + } + size_t i = 0; + auto size = ts_list->size(); + ts_list->for_each([&](const TSPacket::Ptr &ts) { strong_self->onSendData(ts, ++i == size); }); + }); + }; + + }); + }; + + Broadcast::AuthInvoker invoker = [weak_self, onRes](const string &err) { + if (auto strongSelf = weak_self.lock()) { + strongSelf->getPoller()->async([onRes, err]() { onRes(err); }); + } + }; + + auto flag = NoticeCenter::Instance().emitEvent(Broadcast::kBroadcastMediaPlayed, _media_info, invoker, static_cast(*this)); + if (!flag) { + //该事件无人监听,默认不鉴权 + onRes(""); + } } std::string SrtTransportImp::get_peer_ip() { if (!_addr) { diff --git a/srt/SrtTransportImp.hpp b/srt/SrtTransportImp.hpp index a1f3b65b..80106d2d 100644 --- a/srt/SrtTransportImp.hpp +++ b/srt/SrtTransportImp.hpp @@ -1,10 +1,12 @@ #ifndef ZLMEDIAKIT_SRT_TRANSPORT_IMP_H #define ZLMEDIAKIT_SRT_TRANSPORT_IMP_H - +#include #include "Common/MultiMediaSourceMuxer.h" #include "Rtp/Decoder.h" +#include "TS/TSMediaSource.h" #include "SrtTransport.hpp" + namespace SRT { using namespace toolkit; using namespace mediakit; @@ -64,7 +66,8 @@ private: uint64_t _total_bytes = 0; Ticker _alive_ticker; std::unique_ptr _addr; - + // for player + TSMediaSource::RingType::RingReader::Ptr _ts_reader; // for pusher MultiMediaSourceMuxer::Ptr _muxer; DecoderImp::Ptr _decoder;