拷贝后再修改rtp,防止修改共享数据

This commit is contained in:
xiongziliang 2021-05-08 21:16:51 +08:00
parent 47dc661bb2
commit 60a6d4af0b
4 changed files with 36 additions and 59 deletions

View File

@ -231,23 +231,11 @@ namespace RTC
} }
} }
bool SrtpSession::EncryptRtp(const uint8_t** data, size_t* len, uint8_t pt) bool SrtpSession::EncryptRtp(uint8_t* data, size_t* len)
{ {
MS_TRACE(); MS_TRACE();
// Ensure that the resulting SRTP packet fits into the encrypt buffer.
if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize)
{
MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len);
return false;
}
std::memcpy(EncryptBuffer, *data, *len);
EncryptBuffer[1] = (pt & 0x7F) | (EncryptBuffer[1] & 0x80);
srtp_err_status_t err = srtp_err_status_t err =
srtp_protect(this->session, static_cast<void*>(EncryptBuffer), reinterpret_cast<int*>(len)); srtp_protect(this->session, static_cast<void*>(data), reinterpret_cast<int*>(len));
if (DepLibSRTP::IsError(err)) if (DepLibSRTP::IsError(err))
{ {
@ -256,9 +244,6 @@ namespace RTC
return false; return false;
} }
// Update the given data pointer.
*data = (const uint8_t*)EncryptBuffer;
return true; return true;
} }
@ -279,22 +264,11 @@ namespace RTC
return true; return true;
} }
bool SrtpSession::EncryptRtcp(const uint8_t** data, size_t* len) bool SrtpSession::EncryptRtcp(uint8_t* data, size_t* len)
{ {
MS_TRACE(); MS_TRACE();
// Ensure that the resulting SRTCP packet fits into the encrypt buffer.
if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize)
{
MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len);
return false;
}
std::memcpy(EncryptBuffer, *data, *len);
srtp_err_status_t err = srtp_protect_rtcp( srtp_err_status_t err = srtp_protect_rtcp(
this->session, static_cast<void*>(EncryptBuffer), reinterpret_cast<int*>(len)); this->session, static_cast<void*>(data), reinterpret_cast<int*>(len));
if (DepLibSRTP::IsError(err)) if (DepLibSRTP::IsError(err))
{ {
@ -303,9 +277,6 @@ namespace RTC
return false; return false;
} }
// Update the given data pointer.
*data = (const uint8_t*)EncryptBuffer;
return true; return true;
} }

View File

@ -64,9 +64,9 @@ namespace RTC
~SrtpSession(); ~SrtpSession();
public: public:
bool EncryptRtp(const uint8_t** data, size_t* len, uint8_t pt); bool EncryptRtp(uint8_t* data, size_t* len);
bool DecryptSrtp(uint8_t* data, size_t* len); bool DecryptSrtp(uint8_t* data, size_t* len);
bool EncryptRtcp(const uint8_t** data, size_t* len); bool EncryptRtcp(uint8_t* data, size_t* len);
bool DecryptSrtcp(uint8_t* data, size_t* len); bool DecryptSrtcp(uint8_t* data, size_t* len);
void RemoveStream(uint32_t ssrc) void RemoveStream(uint32_t ssrc)
{ {
@ -76,9 +76,6 @@ namespace RTC
private: private:
// Allocated by this. // Allocated by this.
srtp_t session{ nullptr }; srtp_t session{ nullptr };
//rtp包最大1600
static constexpr size_t EncryptBufferSize{ 2000 };
uint8_t EncryptBuffer[EncryptBufferSize];
DepLibSRTP::Ptr _env; DepLibSRTP::Ptr _env;
}; };
} // namespace RTC } // namespace RTC

View File

@ -266,25 +266,24 @@ void WebRtcTransport::inputSockData(char *buf, size_t len, RTC::TransportTuple *
} }
} }
void WebRtcTransport::sendRtpPacket(char *buf, size_t len, bool flush, uint8_t pt) { void WebRtcTransport::sendRtpPacket(char *buf, size_t len, bool flush, TrackType type) {
const uint8_t *p = (uint8_t *) buf;
bool ret = false;
if (_srtp_session_send) { if (_srtp_session_send) {
ret = _srtp_session_send->EncryptRtp(&p, &len, pt); CHECK(len + SRTP_MAX_TRAILER_LEN <= sizeof(_srtp_buf));
memcpy(_srtp_buf, buf, len);
onBeforeEncryptRtp((char *) _srtp_buf, len, type);
if (_srtp_session_send->EncryptRtp(_srtp_buf, &len)) {
onSendSockData((char *) _srtp_buf, len, flush);
} }
if (ret) {
onSendSockData((char *) p, len, flush);
} }
} }
void WebRtcTransport::sendRtcpPacket(char *buf, size_t len, bool flush){ void WebRtcTransport::sendRtcpPacket(char *buf, size_t len, bool flush){
const uint8_t *p = (uint8_t *) buf;
bool ret = false;
if (_srtp_session_send) { if (_srtp_session_send) {
ret = _srtp_session_send->EncryptRtcp(&p, &len); CHECK(len + SRTP_MAX_TRAILER_LEN <= sizeof(_srtp_buf));
memcpy(_srtp_buf, buf, len);
if (_srtp_session_send->EncryptRtcp(_srtp_buf, &len)) {
onSendSockData((char *) _srtp_buf, len, flush);
} }
if (ret) {
onSendSockData((char *) p, len, flush);
} }
} }
@ -611,6 +610,7 @@ void WebRtcTransportImp::onRtcp(const char *buf, size_t len) {
} }
case RtcpType::RTCP_PSFB: case RtcpType::RTCP_PSFB:
case RtcpType::RTCP_RTPFB: { case RtcpType::RTCP_RTPFB: {
DebugL << "\r\n" << rtcp->dumpString();
break; break;
} }
default: break; default: break;
@ -663,8 +663,7 @@ static void setExtType(RtpExt &ext, RtpExtType tp) {
} }
template<typename Type> template<typename Type>
static void changeRtpExtId(const RtpPacket::Ptr &rtp, const Type &map) { static void changeRtpExtId(const RtpHeader *header, const Type &map) {
auto header = rtp->getHeader();
auto ext_map = RtpExt::getExtValue(header); auto ext_map = RtpExt::getExtValue(header);
for (auto &pr : ext_map) { for (auto &pr : ext_map) {
auto it = map.find((typename Type::key_type) (pr.first)); auto it = map.find((typename Type::key_type) (pr.first));
@ -675,29 +674,35 @@ static void changeRtpExtId(const RtpPacket::Ptr &rtp, const Type &map) {
} }
setExtType(pr.second, it->first); setExtType(pr.second, it->first);
setExtType(pr.second, it->second); setExtType(pr.second, it->second);
DebugL << pr.second.dumpString();
pr.second.setExtId((uint8_t) it->second); pr.second.setExtId((uint8_t) it->second);
} }
} }
void WebRtcTransportImp::onBeforeSortedRtp(const RtpPayloadInfo &info, const RtpPacket::Ptr &rtp) { void WebRtcTransportImp::onBeforeSortedRtp(const RtpPayloadInfo &info, const RtpPacket::Ptr &rtp) {
changeRtpExtId(rtp, _rtp_ext_id_to_type); changeRtpExtId(rtp->getHeader(), _rtp_ext_id_to_type);
//统计rtp收到的情况好做rr汇报 //统计rtp收到的情况好做rr汇报
info.rtcp_context_recv->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); info.rtcp_context_recv->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize);
} }
void WebRtcTransportImp::onSendRtp(const RtpPacket::Ptr &rtp, bool flush){ void WebRtcTransportImp::onSendRtp(const RtpPacket::Ptr &rtp, bool flush){
auto &pt = _send_rtp_pt[rtp->type]; auto pt = _send_rtp_pt[rtp->type];
if (pt == 0xFF) { if (pt == 0xFF) {
//忽略,对方不支持该编码类型 //忽略,对方不支持该编码类型
return; return;
} }
changeRtpExtId(rtp, _rtp_ext_type_to_id);
_bytes_usage += rtp->size() - RtpPacket::kRtpTcpHeaderSize; _bytes_usage += rtp->size() - RtpPacket::kRtpTcpHeaderSize;
sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, pt); sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, rtp->type);
//统计rtp发送情况好做sr汇报 //统计rtp发送情况好做sr汇报
_rtp_info_pt[pt].rtcp_context_send->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize); _rtp_info_pt[pt].rtcp_context_send->onRtp(rtp->getSeq(), rtp->getStampMS(), rtp->size() - RtpPacket::kRtpTcpHeaderSize);
} }
void WebRtcTransportImp::onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) {
auto header = (RtpHeader *)buf;
header->pt = _send_rtp_pt[type];
changeRtpExtId(header, _rtp_ext_type_to_id);
}
void WebRtcTransportImp::onShutdown(const SockException &ex){ void WebRtcTransportImp::onShutdown(const SockException &ex){
InfoL << ex.what(); InfoL << ex.what();
_self = nullptr; _self = nullptr;

View File

@ -60,9 +60,9 @@ public:
* @param buf rtcp内容 * @param buf rtcp内容
* @param len rtcp长度 * @param len rtcp长度
* @param flush flush socket * @param flush flush socket
* @param pt rtp payload type * @param type rtp类型
*/ */
void sendRtpPacket(char *buf, size_t len, bool flush, uint8_t pt); void sendRtpPacket(char *buf, size_t len, bool flush, TrackType type);
void sendRtcpPacket(char *buf, size_t len, bool flush); void sendRtcpPacket(char *buf, size_t len, bool flush);
const EventPoller::Ptr& getPoller() const; const EventPoller::Ptr& getPoller() const;
@ -100,6 +100,7 @@ protected:
virtual void onRtp(const char *buf, size_t len) = 0; virtual void onRtp(const char *buf, size_t len) = 0;
virtual void onRtcp(const char *buf, size_t len) = 0; virtual void onRtcp(const char *buf, size_t len) = 0;
virtual void onShutdown(const SockException &ex) = 0; virtual void onShutdown(const SockException &ex) = 0;
virtual void onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) = 0;
protected: protected:
const RtcSession& getSdp(SdpType type) const; const RtcSession& getSdp(SdpType type) const;
@ -112,6 +113,7 @@ private:
void setRemoteDtlsFingerprint(const RtcSession &remote); void setRemoteDtlsFingerprint(const RtcSession &remote);
private: private:
uint8_t _srtp_buf[2000];
EventPoller::Ptr _poller; EventPoller::Ptr _poller;
std::shared_ptr<RTC::IceServer> _ice_server; std::shared_ptr<RTC::IceServer> _ice_server;
std::shared_ptr<RTC::DtlsTransport> _dtls_transport; std::shared_ptr<RTC::DtlsTransport> _dtls_transport;
@ -150,6 +152,8 @@ protected:
void onRtp(const char *buf, size_t len) override; void onRtp(const char *buf, size_t len) override;
void onRtcp(const char *buf, size_t len) override; void onRtcp(const char *buf, size_t len) override;
void onBeforeEncryptRtp(const char *buf, size_t len, TrackType type) override;
void onShutdown(const SockException &ex) override; void onShutdown(const SockException &ex) override;
///////MediaSourceEvent override/////// ///////MediaSourceEvent override///////