diff --git a/webrtc/WebRtcTransport.cpp b/webrtc/WebRtcTransport.cpp index 38677d45..9f9a305a 100644 --- a/webrtc/WebRtcTransport.cpp +++ b/webrtc/WebRtcTransport.cpp @@ -269,21 +269,22 @@ void WebRtcTransport::inputSockData(char *buf, size_t len, RTC::TransportTuple * } } -void WebRtcTransport::sendRtpPacket(char *buf, size_t len, bool flush, TrackType type) { +void WebRtcTransport::sendRtpPacket(char *buf, size_t len, bool flush, void *ctx) { if (_srtp_session_send) { CHECK(len + SRTP_MAX_TRAILER_LEN <= sizeof(_srtp_buf)); memcpy(_srtp_buf, buf, len); - onBeforeEncryptRtp((char *) _srtp_buf, len, type); + onBeforeEncryptRtp((char *) _srtp_buf, len, ctx); if (_srtp_session_send->EncryptRtp(_srtp_buf, &len)) { onSendSockData((char *) _srtp_buf, len, flush); } } } -void WebRtcTransport::sendRtcpPacket(char *buf, size_t len, bool flush){ +void WebRtcTransport::sendRtcpPacket(char *buf, size_t len, bool flush, void *ctx){ if (_srtp_session_send) { CHECK(len + SRTP_MAX_TRAILER_LEN <= sizeof(_srtp_buf)); memcpy(_srtp_buf, buf, len); + onBeforeEncryptRtcp((char *) _srtp_buf, len, ctx); if (_srtp_session_send->EncryptRtcp(_srtp_buf, &len)) { onSendSockData((char *) _srtp_buf, len, flush); } @@ -414,10 +415,13 @@ void WebRtcTransportImp::onStartWebRTC() { auto m_with_ssrc = getSdpWithSSRC().getMedia(m.type); //获取offer端rtp的ssrc和pt相关信息 auto &ref = _rtp_info_pt[plan.pt]; - _rtp_info_ssrc[m_with_ssrc->rtp_rtx_ssrc[0].ssrc] = &ref; ref.plan = &plan; ref.media = m_with_ssrc; ref.is_common_rtp = getCodecId(plan.codec) != CodecInvalid; + if (ref.is_common_rtp) { + //rtp + _rtp_info_ssrc[m_with_ssrc->rtp_rtx_ssrc[0].ssrc] = &ref; + } ref.rtcp_context_recv = std::make_shared(ref.plan->sample_rate, true); ref.rtcp_context_send = std::make_shared(ref.plan->sample_rate, false); ref.receiver = std::make_shared([&ref, this](RtpPacket::Ptr rtp) { @@ -628,7 +632,30 @@ void WebRtcTransportImp::onRtcp(const char *buf, size_t len) { } case RtcpType::RTCP_PSFB: case RtcpType::RTCP_RTPFB: { -// DebugL << "\r\n" << rtcp->dumpString(); + RtcpFB *fb = (RtcpFB *) rtcp; + auto it = _rtp_info_ssrc.find(fb->ssrc); + if (it == _rtp_info_ssrc.end()) { + WarnL << "未识别的 rtcp包:" << rtcp->dumpString(); + return; + } + if ((RtcpType) rtcp->pt == RtcpType::RTCP_PSFB) { +// DebugL << "\r\n" << rtcp->dumpString(); + break; + } + //RTPFB + switch ((RTPFBType) rtcp->report_count) { + case RTPFBType::RTCP_RTPFB_NACK : { + auto &fci = fb->getFci(); + it->second->nack_list.for_each_nack(fci, [&](const RtpPacket::Ptr &rtp) { + //rtp重传 + onSendRtp(rtp, true, true); + }); + break; + } + default: +// DebugL << "\r\n" << rtcp->dumpString(); + break; + } break; } default: break; @@ -703,27 +730,36 @@ void WebRtcTransportImp::onBeforeSortedRtp(const RtpPayloadInfo &info, const Rtp info.rtcp_context_recv->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); } -void WebRtcTransportImp::onSendRtp(const RtpPacket::Ptr &rtp, bool flush){ +void WebRtcTransportImp::onSendRtp(const RtpPacket::Ptr &rtp, bool flush, bool rtx){ auto info = _send_rtp_info[rtp->type]; if (!info) { //忽略,对方不支持该编码类型 return; } + if (!rtx) { + //统计rtp发送情况,好做sr汇报 + info->rtcp_context_send->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); + info->nack_list.push_back(rtp); + } else { + WarnL << "重传rtp:" << rtp->getSeq(); + } + sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, info); _bytes_usage += rtp->size() - RtpPacket::kRtpTcpHeaderSize; - sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, rtp->type); - //统计rtp发送情况,好做sr汇报 - info->rtcp_context_send->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); } -void WebRtcTransportImp::onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) { +void WebRtcTransportImp::onBeforeEncryptRtp(const char *buf, size_t len, void *ctx) { + RtpPayloadInfo *info = reinterpret_cast(ctx); auto header = (RtpHeader *)buf; - auto info = _send_rtp_info[type]; //修改目标pt和ssrc header->pt = info->plan->pt; header->ssrc = htons(info->media->rtp_rtx_ssrc[0].ssrc); changeRtpExtId(header, _rtp_ext_type_to_id); } +void WebRtcTransportImp::onBeforeEncryptRtcp(const char *buf, size_t len, void *ctx) { + +} + void WebRtcTransportImp::onShutdown(const SockException &ex){ InfoL << ex.what(); _self = nullptr; diff --git a/webrtc/WebRtcTransport.h b/webrtc/WebRtcTransport.h index 9e6116c1..7b911020 100644 --- a/webrtc/WebRtcTransport.h +++ b/webrtc/WebRtcTransport.h @@ -21,6 +21,7 @@ #include "Network/Socket.h" #include "Rtsp/RtspMediaSourceImp.h" #include "Rtcp/RtcpContext.h" +#include "Rtcp/RtcpFCI.h" using namespace toolkit; using namespace mediakit; @@ -60,10 +61,10 @@ public: * @param buf rtcp内容 * @param len rtcp长度 * @param flush 是否flush socket - * @param type rtp类型 + * @param ctx 用户指针 */ - void sendRtpPacket(char *buf, size_t len, bool flush, TrackType type); - void sendRtcpPacket(char *buf, size_t len, bool flush); + void sendRtpPacket(char *buf, size_t len, bool flush, void *ctx = nullptr); + void sendRtcpPacket(char *buf, size_t len, bool flush, void *ctx = nullptr); const EventPoller::Ptr& getPoller() const; @@ -100,7 +101,8 @@ protected: virtual void onRtp(const char *buf, size_t len) = 0; virtual void onRtcp(const char *buf, size_t len) = 0; virtual void onShutdown(const SockException &ex) = 0; - virtual void onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) = 0; + virtual void onBeforeEncryptRtp(const char *buf, size_t len, void *ctx) = 0; + virtual void onBeforeEncryptRtcp(const char *buf, size_t len, void *ctx) = 0; protected: const RtcSession& getSdp(SdpType type) const; @@ -125,6 +127,69 @@ private: class RtpReceiverImp; +class NackList { +public: + void push_back(RtpPacket::Ptr rtp) { + auto seq = rtp->getSeq(); + nack_cache_seq.emplace_back(seq); + nack_cache_pkt.emplace(seq, std::move(rtp)); + while (get_cache_ms() > kMaxNackMS) { + //需要清除部分nack缓存 + pop_front(); + } + } + + template + void for_each_nack(const FCI_NACK &nack, const FUNC &func) { + auto seq = nack.getPid(); + for (auto bit : nack.getBitArray()) { + if (!bit) { + //丢包 + RtpPacket::Ptr *ptr = get_rtp(seq); + if (ptr) { + func(*ptr); + } + } + ++seq; + } + } + +private: + void pop_front() { + if (nack_cache_seq.empty()) { + return; + } + nack_cache_pkt.erase(nack_cache_seq.front()); + nack_cache_seq.pop_front(); + } + + RtpPacket::Ptr *get_rtp(uint16_t seq) { + auto it = nack_cache_pkt.find(seq); + if (it == nack_cache_pkt.end()) { + return nullptr; + } + return &it->second; + } + + uint32_t get_cache_ms() { + if (nack_cache_seq.size() < 2) { + return 0; + } + uint32_t back = nack_cache_pkt[nack_cache_seq.back()]->getStampMS(); + uint32_t front = nack_cache_pkt[nack_cache_seq.front()]->getStampMS(); + if (back > front) { + return back - front; + } + //很有可能回环了 + return back + (UINT32_MAX - front); + } + +private: + static constexpr uint32_t kMaxNackMS = 10 * 1000; + deque nack_cache_seq; + unordered_map nack_cache_pkt; +}; + class WebRtcTransportImp : public WebRtcTransport, public MediaSourceEvent, public SockInfo, public std::enable_shared_from_this{ public: using Ptr = std::shared_ptr; @@ -152,7 +217,8 @@ protected: void onRtp(const char *buf, size_t len) override; void onRtcp(const char *buf, size_t len) override; - void onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) override; + void onBeforeEncryptRtp(const char *buf, size_t len, void *ctx) override; + void onBeforeEncryptRtcp(const char *buf, size_t len, void *ctx) override; void onShutdown(const SockException &ex) override; @@ -184,7 +250,7 @@ private: WebRtcTransportImp(const EventPoller::Ptr &poller); void onCreate() override; void onDestory() override; - void onSendRtp(const RtpPacket::Ptr &rtp, bool flush); + void onSendRtp(const RtpPacket::Ptr &rtp, bool flush, bool rtx = false); SdpAttrCandidate::Ptr getIceCandidate() const; bool canSendRtp() const; bool canRecvRtp() const; @@ -198,10 +264,11 @@ private: std::shared_ptr receiver; RtcpContext::Ptr rtcp_context_recv; RtcpContext::Ptr rtcp_context_send; + NackList nack_list; }; - void onSortedRtp(const RtpPayloadInfo &info,RtpPacket::Ptr rtp); - void onBeforeSortedRtp(const RtpPayloadInfo &info,const RtpPacket::Ptr &rtp); + void onSortedRtp(const RtpPayloadInfo &info, RtpPacket::Ptr rtp); + void onBeforeSortedRtp(const RtpPayloadInfo &info, const RtpPacket::Ptr &rtp); private: //用掉的总流量