diff --git a/webrtc/SrtpSession.cpp b/webrtc/SrtpSession.cpp index 4591e466..7aa109c1 100644 --- a/webrtc/SrtpSession.cpp +++ b/webrtc/SrtpSession.cpp @@ -231,23 +231,11 @@ namespace RTC } } - bool SrtpSession::EncryptRtp(const uint8_t** data, size_t* len, uint8_t pt) + bool SrtpSession::EncryptRtp(uint8_t* data, size_t* len) { MS_TRACE(); - - // Ensure that the resulting SRTP packet fits into the encrypt buffer. - if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) - { - MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len); - - return false; - } - - std::memcpy(EncryptBuffer, *data, *len); - EncryptBuffer[1] = (pt & 0x7F) | (EncryptBuffer[1] & 0x80); - srtp_err_status_t err = - srtp_protect(this->session, static_cast(EncryptBuffer), reinterpret_cast(len)); + srtp_protect(this->session, static_cast(data), reinterpret_cast(len)); if (DepLibSRTP::IsError(err)) { @@ -256,9 +244,6 @@ namespace RTC return false; } - // Update the given data pointer. - *data = (const uint8_t*)EncryptBuffer; - return true; } @@ -279,22 +264,11 @@ namespace RTC return true; } - bool SrtpSession::EncryptRtcp(const uint8_t** data, size_t* len) + bool SrtpSession::EncryptRtcp(uint8_t* data, size_t* len) { MS_TRACE(); - - // Ensure that the resulting SRTCP packet fits into the encrypt buffer. - if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) - { - MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len); - - return false; - } - - std::memcpy(EncryptBuffer, *data, *len); - srtp_err_status_t err = srtp_protect_rtcp( - this->session, static_cast(EncryptBuffer), reinterpret_cast(len)); + this->session, static_cast(data), reinterpret_cast(len)); if (DepLibSRTP::IsError(err)) { @@ -303,9 +277,6 @@ namespace RTC return false; } - // Update the given data pointer. - *data = (const uint8_t*)EncryptBuffer; - return true; } diff --git a/webrtc/SrtpSession.hpp b/webrtc/SrtpSession.hpp index d9d73fbe..a50684e5 100644 --- a/webrtc/SrtpSession.hpp +++ b/webrtc/SrtpSession.hpp @@ -64,9 +64,9 @@ namespace RTC ~SrtpSession(); public: - bool EncryptRtp(const uint8_t** data, size_t* len, uint8_t pt); + bool EncryptRtp(uint8_t* data, size_t* len); bool DecryptSrtp(uint8_t* data, size_t* len); - bool EncryptRtcp(const uint8_t** data, size_t* len); + bool EncryptRtcp(uint8_t* data, size_t* len); bool DecryptSrtcp(uint8_t* data, size_t* len); void RemoveStream(uint32_t ssrc) { @@ -76,9 +76,6 @@ namespace RTC private: // Allocated by this. srtp_t session{ nullptr }; - //rtp包最大1600 - static constexpr size_t EncryptBufferSize{ 2000 }; - uint8_t EncryptBuffer[EncryptBufferSize]; DepLibSRTP::Ptr _env; }; } // namespace RTC diff --git a/webrtc/WebRtcTransport.cpp b/webrtc/WebRtcTransport.cpp index 38fecd1f..badd921c 100644 --- a/webrtc/WebRtcTransport.cpp +++ b/webrtc/WebRtcTransport.cpp @@ -266,25 +266,24 @@ void WebRtcTransport::inputSockData(char *buf, size_t len, RTC::TransportTuple * } } -void WebRtcTransport::sendRtpPacket(char *buf, size_t len, bool flush, uint8_t pt) { - const uint8_t *p = (uint8_t *) buf; - bool ret = false; +void WebRtcTransport::sendRtpPacket(char *buf, size_t len, bool flush, TrackType type) { if (_srtp_session_send) { - ret = _srtp_session_send->EncryptRtp(&p, &len, pt); - } - if (ret) { - onSendSockData((char *) p, len, flush); + CHECK(len + SRTP_MAX_TRAILER_LEN <= sizeof(_srtp_buf)); + memcpy(_srtp_buf, buf, len); + onBeforeEncryptRtp((char *) _srtp_buf, len, type); + if (_srtp_session_send->EncryptRtp(_srtp_buf, &len)) { + onSendSockData((char *) _srtp_buf, len, flush); + } } } void WebRtcTransport::sendRtcpPacket(char *buf, size_t len, bool flush){ - const uint8_t *p = (uint8_t *) buf; - bool ret = false; if (_srtp_session_send) { - ret = _srtp_session_send->EncryptRtcp(&p, &len); - } - if (ret) { - onSendSockData((char *) p, len, flush); + CHECK(len + SRTP_MAX_TRAILER_LEN <= sizeof(_srtp_buf)); + memcpy(_srtp_buf, buf, len); + if (_srtp_session_send->EncryptRtcp(_srtp_buf, &len)) { + onSendSockData((char *) _srtp_buf, len, flush); + } } } @@ -611,6 +610,7 @@ void WebRtcTransportImp::onRtcp(const char *buf, size_t len) { } case RtcpType::RTCP_PSFB: case RtcpType::RTCP_RTPFB: { + DebugL << "\r\n" << rtcp->dumpString(); break; } default: break; @@ -663,8 +663,7 @@ static void setExtType(RtpExt &ext, RtpExtType tp) { } template -static void changeRtpExtId(const RtpPacket::Ptr &rtp, const Type &map) { - auto header = rtp->getHeader(); +static void changeRtpExtId(const RtpHeader *header, const Type &map) { auto ext_map = RtpExt::getExtValue(header); for (auto &pr : ext_map) { auto it = map.find((typename Type::key_type) (pr.first)); @@ -675,29 +674,35 @@ static void changeRtpExtId(const RtpPacket::Ptr &rtp, const Type &map) { } setExtType(pr.second, it->first); setExtType(pr.second, it->second); + DebugL << pr.second.dumpString(); pr.second.setExtId((uint8_t) it->second); } } void WebRtcTransportImp::onBeforeSortedRtp(const RtpPayloadInfo &info, const RtpPacket::Ptr &rtp) { - changeRtpExtId(rtp, _rtp_ext_id_to_type); + changeRtpExtId(rtp->getHeader(), _rtp_ext_id_to_type); //统计rtp收到的情况,好做rr汇报 info.rtcp_context_recv->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); } void WebRtcTransportImp::onSendRtp(const RtpPacket::Ptr &rtp, bool flush){ - auto &pt = _send_rtp_pt[rtp->type]; + auto pt = _send_rtp_pt[rtp->type]; if (pt == 0xFF) { //忽略,对方不支持该编码类型 return; } - changeRtpExtId(rtp, _rtp_ext_type_to_id); _bytes_usage += rtp->size() - RtpPacket::kRtpTcpHeaderSize; - sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, pt); + sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, rtp->type); //统计rtp发送情况,好做sr汇报 _rtp_info_pt[pt].rtcp_context_send->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); } +void WebRtcTransportImp::onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) { + auto header = (RtpHeader *)buf; + header->pt = _send_rtp_pt[type]; + changeRtpExtId(header, _rtp_ext_type_to_id); +} + void WebRtcTransportImp::onShutdown(const SockException &ex){ InfoL << ex.what(); _self = nullptr; diff --git a/webrtc/WebRtcTransport.h b/webrtc/WebRtcTransport.h index 84fdbde5..44fdf973 100644 --- a/webrtc/WebRtcTransport.h +++ b/webrtc/WebRtcTransport.h @@ -60,9 +60,9 @@ public: * @param buf rtcp内容 * @param len rtcp长度 * @param flush 是否flush socket - * @param pt rtp payload type + * @param type rtp类型 */ - void sendRtpPacket(char *buf, size_t len, bool flush, uint8_t pt); + void sendRtpPacket(char *buf, size_t len, bool flush, TrackType type); void sendRtcpPacket(char *buf, size_t len, bool flush); const EventPoller::Ptr& getPoller() const; @@ -100,6 +100,7 @@ protected: virtual void onRtp(const char *buf, size_t len) = 0; virtual void onRtcp(const char *buf, size_t len) = 0; virtual void onShutdown(const SockException &ex) = 0; + virtual void onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) = 0; protected: const RtcSession& getSdp(SdpType type) const; @@ -112,6 +113,7 @@ private: void setRemoteDtlsFingerprint(const RtcSession &remote); private: + uint8_t _srtp_buf[2000]; EventPoller::Ptr _poller; std::shared_ptr _ice_server; std::shared_ptr _dtls_transport; @@ -150,6 +152,8 @@ protected: void onRtp(const char *buf, size_t len) override; void onRtcp(const char *buf, size_t len) override; + void onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) override; + void onShutdown(const SockException &ex) override; ///////MediaSourceEvent override///////