重写webrtc sdp校验逻辑,确保无有效track时抛异常:#1157

This commit is contained in:
ziyue 2021-10-16 16:46:05 +08:00
parent b10fc52384
commit 9433a0c383
3 changed files with 53 additions and 143 deletions

View File

@ -291,12 +291,10 @@ string RtcSessionSdp::toString() const {
//////////////////////////////////////////////////////////////////////////////////////////
#define SDP_THROW() throw std::invalid_argument(StrPrinter << "解析sdp " << getKey() << " 字段失败:" << str)
#define CHECK_SDP(exp) CHECK(exp, "解析sdp ", getKey(), " 字段失败:", str)
void SdpTime::parse(const string &str) {
if (sscanf(str.data(), "%" SCNu64 " %" SCNu64, &start, &stop) != 2) {
SDP_THROW();
}
CHECK_SDP(sscanf(str.data(), "%" SCNu64 " %" SCNu64, &start, &stop) == 2);
}
string SdpTime::toString() const {
@ -308,9 +306,7 @@ string SdpTime::toString() const {
void SdpOrigin::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() != 6) {
SDP_THROW();
}
CHECK_SDP(vec.size() == 6);
username = vec[0];
session_id = vec[1];
session_version = vec[2];
@ -328,9 +324,7 @@ string SdpOrigin::toString() const {
void SdpConnection::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() != 3) {
SDP_THROW();
}
CHECK_SDP(vec.size() == 3);
nettype = vec[0];
addrtype = vec[1];
address = vec[2];
@ -345,9 +339,7 @@ string SdpConnection::toString() const {
void SdpBandwidth::parse(const string &str) {
auto vec = split(str, ":");
if (vec.size() != 2) {
SDP_THROW();
}
CHECK_SDP(vec.size() == 2);
bwtype = vec[0];
bandwidth = atoi(vec[1].data());
}
@ -361,20 +353,14 @@ string SdpBandwidth::toString() const {
void SdpMedia::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() < 4) {
SDP_THROW();
}
CHECK_SDP(vec.size() >= 4);
type = getTrackType(vec[0]);
if (type == TrackInvalid) {
SDP_THROW();
}
CHECK_SDP(type != TrackInvalid);
port = atoi(vec[1].data());
proto = vec[2];
for (size_t i = 3; i < vec.size(); ++i) {
auto pt = atoi(vec[i].data());
if (type != TrackApplication && pt > 0xFF) {
SDP_THROW();
}
CHECK_SDP(type == TrackApplication || pt <= 0xFF);
fmts.emplace_back(pt);
}
}
@ -417,9 +403,7 @@ string SdpAttr::toString() const {
void SdpAttrGroup::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() < 2) {
SDP_THROW();
}
CHECK_SDP(vec.size() >= 2);
type = vec[0];
vec.erase(vec.begin());
mids = std::move(vec);
@ -438,9 +422,7 @@ string SdpAttrGroup::toString() const {
void SdpAttrMsidSemantic::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() < 1) {
SDP_THROW();
}
CHECK_SDP(vec.size() >= 1);
msid = vec[0];
token = vec.size() > 1 ? vec[1] : "";
}
@ -458,9 +440,7 @@ string SdpAttrMsidSemantic::toString() const {
void SdpAttrRtcp::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() != 4) {
SDP_THROW();
}
CHECK_SDP(vec.size() == 4);
port = atoi(vec[0].data());
nettype = vec[1];
addrtype = vec[2];
@ -503,9 +483,7 @@ string SdpAttrIceOption::toString() const{
void SdpAttrFingerprint::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() != 2) {
SDP_THROW();
}
CHECK_SDP(vec.size() == 2);
algorithm = vec[0];
hash = vec[1];
}
@ -519,9 +497,7 @@ string SdpAttrFingerprint::toString() const {
void SdpAttrSetup::parse(const string &str) {
role = getDtlsRole(str);
if (role == DtlsRole::invalid) {
SDP_THROW();
}
CHECK_SDP(role != DtlsRole::invalid);
}
string SdpAttrSetup::toString() const {
@ -535,9 +511,7 @@ void SdpAttrExtmap::parse(const string &str) {
char buf[128] = {0};
char direction_buf[32] = {0};
if (sscanf(str.data(), "%" SCNd8 "/%31[^ ] %127s", &id, direction_buf, buf) != 3) {
if (sscanf(str.data(), "%" SCNd8 " %127s", &id, buf) != 2) {
SDP_THROW();
}
CHECK_SDP(sscanf(str.data(), "%" SCNd8 " %127s", &id, buf) == 2);
direction = RtpDirection::sendrecv;
} else {
direction = getRtpDirection(direction_buf);
@ -559,9 +533,7 @@ string SdpAttrExtmap::toString() const {
void SdpAttrRtpMap::parse(const string &str) {
char buf[32] = {0};
if (sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32 "/%" SCNd32, &pt, buf, &sample_rate, &channel) != 4) {
if (sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32, &pt, buf, &sample_rate) != 3) {
SDP_THROW();
}
CHECK_SDP(sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32, &pt, buf, &sample_rate) == 3);
if (getTrackType(getCodecId(buf)) == TrackAudio) {
//未指定通道数时且为音频时那么通道数默认为1
channel = 1;
@ -584,9 +556,7 @@ string SdpAttrRtpMap::toString() const {
void SdpAttrRtcpFb::parse(const string &str_in) {
auto str = str_in + "\n";
char rtcp_type_buf[32] = {0};
if (2 != sscanf(str.data(), "%" SCNu8 " %31[^\n]", &pt, rtcp_type_buf)) {
SDP_THROW();
}
CHECK_SDP(sscanf(str.data(), "%" SCNu8 " %31[^\n]", &pt, rtcp_type_buf) == 2);
rtcp_type = rtcp_type_buf;
}
@ -599,9 +569,7 @@ string SdpAttrRtcpFb::toString() const {
void SdpAttrFmtp::parse(const string &str) {
auto pos = str.find(' ');
if (pos == string::npos) {
SDP_THROW();
}
CHECK_SDP(pos != string::npos);
pt = atoi(str.substr(0, pos).data());
auto vec = split(str.substr(pos + 1), ";");
for (auto &item : vec) {
@ -613,9 +581,7 @@ void SdpAttrFmtp::parse(const string &str) {
fmtp.emplace(std::make_pair(item.substr(0, pos), item.substr(pos + 1)));
}
}
if (fmtp.empty()) {
SDP_THROW();
}
CHECK_SDP(!fmtp.empty());
}
string SdpAttrFmtp::toString() const {
@ -640,7 +606,7 @@ void SdpAttrSSRC::parse(const string &str_in) {
} else if (2 == sscanf(str.data(), "%" SCNu32 " %31s[^\n]", &ssrc, attr_buf)) {
attribute = attr_buf;
} else {
SDP_THROW();
CHECK_SDP(0);
}
}
@ -658,15 +624,12 @@ string SdpAttrSSRC::toString() const {
void SdpAttrSSRCGroup::parse(const string &str) {
auto vec = split(str, " ");
if (vec.size() >= 3) {
CHECK_SDP(vec.size() >= 3);
type = std::move(vec[0]);
CHECK(isFID() || isSIM());
vec.erase(vec.begin());
for (auto ssrc : vec) {
ssrcs.emplace_back((uint32_t)atoll(ssrc.data()));
}
} else {
SDP_THROW();
ssrcs.emplace_back((uint32_t) atoll(ssrc.data()));
}
}
@ -685,11 +648,8 @@ string SdpAttrSSRCGroup::toString() const {
void SdpAttrSctpMap::parse(const string &str) {
char subtypes_buf[64] = {0};
if (3 == sscanf(str.data(), "%" SCNu16 " %63[^ ] %" SCNd32, &port, subtypes_buf, &streams)) {
CHECK_SDP(3 == sscanf(str.data(), "%" SCNu16 " %63[^ ] %" SCNd32, &port, subtypes_buf, &streams));
subtypes = subtypes_buf;
} else {
SDP_THROW();
}
}
string SdpAttrSctpMap::toString() const {
@ -710,10 +670,8 @@ void SdpAttrCandidate::parse(const string &str) {
char type_buf[16] = {0};
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
if (7 != sscanf(str.data(), "%32[^ ] %" SCNu32 " %15[^ ] %" SCNu32 " %31[^ ] %" SCNu16 " typ %15[^ ]",
foundation_buf, &component, transport_buf, &priority, address_buf, &port, type_buf)) {
SDP_THROW();
}
CHECK_SDP(sscanf(str.data(), "%32[^ ] %" SCNu32 " %15[^ ] %" SCNu32 " %31[^ ] %" SCNu16 " typ %15[^ ]",
foundation_buf, &component, transport_buf, &priority, address_buf, &port, type_buf) == 7);
foundation = foundation_buf;
transport = transport_buf;
address = address_buf;
@ -757,9 +715,7 @@ void SdpAttrSimulcast::parse(const string &str) {
//a=simulcast: recv h;m;l
//
auto vec = split(str, " ");
if (vec.size() != 2) {
SDP_THROW();
}
CHECK_SDP(vec.size() == 2);
direction = vec[0];
rids = split(vec[1], ";");
}
@ -794,7 +750,7 @@ string SdpAttrRid::toString() const {
return SdpItem::toString();
}
void RtcSession::loadFrom(const string &str, bool check) {
void RtcSession::loadFrom(const string &str) {
RtcSessionSdp sdp;
sdp.parse(str);
@ -808,12 +764,6 @@ void RtcSession::loadFrom(const string &str, bool check) {
msid_semantic = sdp.getItemClass<SdpAttrMsidSemantic>('a', "msid-semantic");
for (auto &media : sdp.medias) {
auto mline = media.getItemClass<SdpMedia>('m');
switch (mline.type) {
case TrackVideo:
case TrackAudio:
case TrackApplication: break;
default: throw std::invalid_argument(StrPrinter << "不识别的media类型:" << mline.toString());
}
this->media.emplace_back();
auto &rtc_media = this->media.back();
rtc_media.type = mline.type;
@ -949,19 +899,15 @@ void RtcSession::loadFrom(const string &str, bool check) {
map<uint8_t, SdpAttrFmtp &> fmtp_map;
for (auto &rtpmap : rtpmap_arr) {
if (!rtpmap_map.emplace(rtpmap.pt, rtpmap).second) {
//添加失败,有多条
throw std::invalid_argument(StrPrinter << "该pt存在多条a=rtpmap:" << rtpmap.pt);
}
CHECK(rtpmap_map.emplace(rtpmap.pt, rtpmap).second, "该pt存在多条a=rtpmap:", (int)rtpmap.pt);
}
for (auto &rtpfb : rtcpfb_arr) {
rtcpfb_map.emplace(rtpfb.pt, rtpfb);
}
for (auto &fmtp : fmtp_aar) {
if (!fmtp_map.emplace(fmtp.pt, fmtp).second) {
//添加失败,有多条
throw std::invalid_argument(StrPrinter << "该pt存在多条a=fmtp:" << fmtp.pt);
}
CHECK(fmtp_map.emplace(fmtp.pt, fmtp).second, "该pt存在多条a=fmtp:", (int)fmtp.pt);
}
for (auto &pt : mline.fmts) {
//遍历所有编码方案的pt
@ -992,9 +938,6 @@ void RtcSession::loadFrom(const string &str, bool check) {
}
group = sdp.getItemClass<SdpAttrGroup>('a', "group");
if (check) {
checkValid();
}
}
std::shared_ptr<SdpItem> wrapSdpAttr(SdpItem::Ptr item){
@ -1042,7 +985,6 @@ static void toRtsp(vector <SdpItem::Ptr> &items) {
}
string RtcSession::toRtspSdp() const{
checkValid();
RtcSession copy = *this;
copy.media.clear();
for (auto &m : media) {
@ -1269,7 +1211,6 @@ RtcSessionSdp::Ptr RtcSession::toRtcSessionSdp() const{
}
string RtcSession::toString() const{
checkValid();
return toRtcSessionSdp()->toString();
}
@ -1341,10 +1282,9 @@ void RtcMedia::checkValid() const{
CHECK(!mid.empty());
CHECK(!proto.empty());
CHECK(direction != RtpDirection::invalid || type == TrackApplication);
CHECK(!plan.empty() || type == TrackApplication );
}
CHECK(!plan.empty() || type == TrackApplication);
CHECK(type == TrackApplication || rtcp_mux, "只支持rtcp-mux模式");
void RtcMedia::checkValidSSRC() const {
bool send_rtp = (direction == RtpDirection::sendonly || direction == RtpDirection::sendrecv);
if (rtp_rids.empty() && rtp_ssrc_sim.empty()) {
//非simulcast时检查有没有指定rtp ssrc
@ -1367,16 +1307,19 @@ void RtcSession::checkValid() const{
CHECK(!session_name.empty());
CHECK(!msid_semantic.empty());
CHECK(!media.empty());
CHECK(group.mids.size() <= media.size());
CHECK(!group.mids.empty() && group.mids.size() <= media.size(), "只支持group BUNDLE模式");
bool have_active_media = false;
for (auto &item : media) {
item.checkValid();
switch (item.direction) {
case RtpDirection::sendrecv:
case RtpDirection::sendonly:
case RtpDirection::recvonly: have_active_media = true; break;
default : break;
}
}
void RtcSession::checkValidSSRC() const{
for (auto &item : media) {
item.checkValidSSRC();
}
CHECK(have_active_media, "必须确保最少有一个活跃的track");
}
const RtcMedia *RtcSession::getMedia(TrackType type) const{
@ -1388,15 +1331,6 @@ const RtcMedia *RtcSession::getMedia(TrackType type) const{
return nullptr;
}
bool RtcSession::haveSSRC() const {
for (auto &m : media) {
if (!m.rtp_rtx_ssrc.empty()) {
return true;
}
}
return false;
}
bool RtcSession::supportRtcpFb(const string &name, TrackType type) const {
auto media = getMedia(type);
if (!media) {
@ -1415,17 +1349,6 @@ bool RtcSession::supportSimulcast() const {
return false;
}
void RtcSession::checkSdp() const {
for (auto &m : media) {
if (m.type != TrackApplication && !m.rtcp_mux) {
throw std::invalid_argument("只支持rtcp-mux模式");
}
}
if (group.mids.empty()) {
throw std::invalid_argument("只支持group BUNDLE模式");
}
}
string const SdpConst::kTWCCRtcpFb = "transport-cc";
string const SdpConst::kRembRtcpFb = "goog-remb";
@ -1596,17 +1519,12 @@ void RtcConfigure::enableREMB(bool enable, TrackType type){
shared_ptr<RtcSession> RtcConfigure::createAnswer(const RtcSession &offer){
shared_ptr<RtcSession> ret = std::make_shared<RtcSession>();
ret->version = offer.version;
//todo 此处设置会话id与会话地址貌似没什么作用
ret->origin = offer.origin;
ret->session_name = offer.session_name;
ret->msid_semantic = offer.msid_semantic;
matchMedia(ret, TrackAudio, offer.media, audio);
matchMedia(ret, TrackVideo, offer.media, video);
matchMedia(ret, TrackApplication, offer.media, application);
if (ret->media.empty()) {
throw std::invalid_argument("生成的answer sdp中媒体个数为0");
}
//设置音视频端口复用
if (!offer.group.mids.empty()) {
for (auto &m : ret->media) {
@ -1678,11 +1596,8 @@ RETRY:
answer_media.type = offer_media.type;
answer_media.mid = offer_media.mid;
answer_media.proto = offer_media.proto;
//todo(此处设置rtp端口貌似没什么作用)
answer_media.port = offer_media.port;
//todo(此处设置rtp的ip地址貌似没什么作用)
answer_media.addr = offer_media.addr;
//todo(此处设置rtcp地址貌似没什么作用)
answer_media.rtcp_addr = offer_media.rtcp_addr;
answer_media.rtcp_mux = offer_media.rtcp_mux && configure.rtcp_mux;
answer_media.rtcp_rsize = offer_media.rtcp_rsize && configure.rtcp_rsize;
@ -1804,7 +1719,7 @@ RETRY:
void RtcConfigure::setPlayRtspInfo(const string &sdp){
RtcSession session;
session.loadFrom(sdp, false);
session.loadFrom(sdp);
for (auto &m : session.media) {
switch (m.type) {
case TrackVideo : {

View File

@ -641,8 +641,6 @@ public:
uint32_t sctp_port{0};
void checkValid() const;
//offer sdp,如果指定了发送rtp,那么应该指定ssrc
void checkValidSSRC() const;
const RtcCodecPlan *getPlan(uint8_t pt) const;
const RtcCodecPlan *getPlan(const char *codec) const;
const RtcCodecPlan *getRelatedRtxPlan(uint8_t pt) const;
@ -651,7 +649,7 @@ public:
bool supportSimulcast() const;
};
class RtcSession{
class RtcSession {
public:
using Ptr = std::shared_ptr<RtcSession>;
@ -666,15 +664,11 @@ public:
vector<RtcMedia> media;
SdpAttrGroup group;
void loadFrom(const string &sdp, bool check = true);
void loadFrom(const string &sdp);
void checkValid() const;
void checkSdp() const;
//offer sdp,如果指定了发送rtp,那么应该指定ssrc
void checkValidSSRC() const;
string toString() const;
string toRtspSdp() const;
const RtcMedia *getMedia(TrackType type) const;
bool haveSSRC() const;
bool supportRtcpFb(const string &name, TrackType type = TrackType::TrackVideo) const;
bool supportSimulcast() const;

View File

@ -187,6 +187,7 @@ std::string WebRtcTransport::getAnswerSdp(const string &offer){
_offer_sdp = std::make_shared<RtcSession>();
_offer_sdp->loadFrom(offer);
onCheckSdp(SdpType::offer, *_offer_sdp);
_offer_sdp->checkValid();
setRemoteDtlsFingerprint(*_offer_sdp);
//// sdp 配置 ////
@ -201,6 +202,7 @@ std::string WebRtcTransport::getAnswerSdp(const string &offer){
//// 生成answer sdp ////
_answer_sdp = configure.createAnswer(*_offer_sdp);
onCheckSdp(SdpType::answer, *_answer_sdp);
_answer_sdp->checkValid();
return _answer_sdp->toString();
} catch (exception &ex) {
onShutdown(SockException(Err_shutdown, ex.what()));
@ -457,10 +459,9 @@ void WebRtcTransportImp::onCheckAnswer(RtcSession &sdp) {
}
void WebRtcTransportImp::onCheckSdp(SdpType type, RtcSession &sdp) {
sdp.checkSdp();
switch (type) {
case SdpType::answer: onCheckAnswer(sdp); break;
case SdpType::offer: sdp.checkValidSSRC(); break;
case SdpType::offer: break;
default: /*不可达*/ assert(0); break;
}
}