From f49aed7a32d19b5e556ebbf2b31d61033f620103 Mon Sep 17 00:00:00 2001 From: johzzy Date: Sat, 2 Mar 2024 06:25:32 -0400 Subject: [PATCH] srt optimization code for query poller (#3334) - add querySrtTransport, improve code. - update SrtTransportManager key - fix some warning --- src/Common/strCoding.cpp | 4 +-- srt/SrtSession.cpp | 78 +++++----------------------------------- srt/SrtTransport.cpp | 34 ++++++++---------- srt/SrtTransport.hpp | 20 +++++------ srt/SrtTransportImp.cpp | 26 ++++++++++++++ 5 files changed, 61 insertions(+), 101 deletions(-) diff --git a/src/Common/strCoding.cpp b/src/Common/strCoding.cpp index 90fc7096..59a0b7e4 100644 --- a/src/Common/strCoding.cpp +++ b/src/Common/strCoding.cpp @@ -79,7 +79,7 @@ string strCoding::UrlEncodePath(const string &str) { out.push_back(ch); } else { char buf[4]; - sprintf(buf, "%%%X%X", (uint8_t) ch >> 4, (uint8_t) ch & 0x0F); + snprintf(buf, 4, "%%%X%X", (uint8_t) ch >> 4, (uint8_t) ch & 0x0F); out.append(buf); } } @@ -96,7 +96,7 @@ string strCoding::UrlEncodeComponent(const string &str) { out.push_back(ch); } else { char buf[4]; - sprintf(buf, "%%%X%X", (uint8_t) ch >> 4, (uint8_t) ch & 0x0F); + snprintf(buf, 4, "%%%X%X", (uint8_t) ch >> 4, (uint8_t) ch & 0x0F); out.append(buf); } } diff --git a/srt/SrtSession.cpp b/srt/SrtSession.cpp index 5211fd4c..8a671c96 100644 --- a/srt/SrtSession.cpp +++ b/srt/SrtSession.cpp @@ -16,41 +16,17 @@ SrtSession::SrtSession(const Socket::Ptr &sock) // TraceL<<"after addr len "<data(); - size_t size = buffer->size(); - - if (DataPacket::isDataPacket(data, size)) { - uint32_t socket_id = DataPacket::getSocketID(data, size); - auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); - return trans ? trans->getPoller() : nullptr; - } - - if (HandshakePacket::isHandshakePacket(data, size)) { - auto type = HandshakePacket::getHandshakeType(data, size); - if (type == HandshakePacket::HS_TYPE_INDUCTION) { - // 握手第一阶段 - return nullptr; - } else if (type == HandshakePacket::HS_TYPE_CONCLUSION) { - // 握手第二阶段 - uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); - auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); - return trans ? trans->getPoller() : nullptr; - } else { - WarnL << " not reach there"; - } - } else { - uint32_t socket_id = ControlPacket::getSocketID(data, size); - auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); - return trans ? trans->getPoller() : nullptr; - } - return nullptr; -} - void SrtSession::attachServer(const toolkit::Server &server) { SockUtil::setRecvBuf(getSock()->rawFD(), 1024 * 1024); } +extern SrtTransport::Ptr querySrtTransport(uint8_t *data, size_t size, const EventPoller::Ptr& poller); + +EventPoller::Ptr SrtSession::queryPoller(const Buffer::Ptr &buffer) { + auto transport = querySrtTransport((uint8_t *)buffer->data(), buffer->size(), nullptr); + return transport ? transport->getPoller() : nullptr; +} + void SrtSession::onRecv(const Buffer::Ptr &buffer) { uint8_t *data = (uint8_t *)buffer->data(); size_t size = buffer->size(); @@ -58,45 +34,7 @@ void SrtSession::onRecv(const Buffer::Ptr &buffer) { if (_find_transport) { //只允许寻找一次transport _find_transport = false; - - if (DataPacket::isDataPacket(data, size)) { - uint32_t socket_id = DataPacket::getSocketID(data, size); - auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); - if (trans) { - _transport = std::move(trans); - } else { - WarnL << " data packet not find transport "; - } - } - - if (HandshakePacket::isHandshakePacket(data, size)) { - auto type = HandshakePacket::getHandshakeType(data, size); - if (type == HandshakePacket::HS_TYPE_INDUCTION) { - // 握手第一阶段 - _transport = std::make_shared(getPoller()); - - } else if (type == HandshakePacket::HS_TYPE_CONCLUSION) { - // 握手第二阶段 - uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); - auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); - if (trans) { - _transport = std::move(trans); - } else { - WarnL << " hanshake packet not find transport "; - } - } else { - WarnL << " not reach there"; - } - } else { - uint32_t socket_id = ControlPacket::getSocketID(data, size); - auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); - if (trans) { - _transport = std::move(trans); - } else { - WarnL << " not find transport"; - } - } - + _transport = querySrtTransport(data, size, getPoller()); if (_transport) { _transport->setSession(static_pointer_cast(shared_from_this())); } diff --git a/srt/SrtTransport.cpp b/srt/SrtTransport.cpp index 7388f3a0..f88cf428 100644 --- a/srt/SrtTransport.cpp +++ b/srt/SrtTransport.cpp @@ -61,7 +61,7 @@ void SrtTransport::switchToOtherTransport(uint8_t *buf, int len, uint32_t socket BufferRaw::Ptr tmp = BufferRaw::create(); struct sockaddr_storage tmp_addr = *addr; tmp->assign((char *)buf, len); - auto trans = SrtTransportManager::Instance().getItem(std::to_string(socketid)); + auto trans = SrtTransportManager::Instance().getItem(socketid); if (trans) { trans->getPoller()->async([tmp, tmp_addr, trans] { trans->inputSockData((uint8_t *)tmp->data(), tmp->size(), (struct sockaddr_storage *)&tmp_addr); @@ -700,30 +700,30 @@ void SrtTransport::sendPacket(Buffer::Ptr pkt, bool flush) { } } -std::string SrtTransport::getIdentifier() { +std::string SrtTransport::getIdentifier() const { return _selected_session ? _selected_session->getIdentifier() : ""; } void SrtTransport::registerSelfHandshake() { - SrtTransportManager::Instance().addHandshakeItem(std::to_string(_sync_cookie), shared_from_this()); + SrtTransportManager::Instance().addHandshakeItem(_sync_cookie, shared_from_this()); } void SrtTransport::unregisterSelfHandshake() { if (_sync_cookie == 0) { return; } - SrtTransportManager::Instance().removeHandshakeItem(std::to_string(_sync_cookie)); + SrtTransportManager::Instance().removeHandshakeItem(_sync_cookie); } void SrtTransport::registerSelf() { if (_socket_id == 0) { return; } - SrtTransportManager::Instance().addItem(std::to_string(_socket_id), shared_from_this()); + SrtTransportManager::Instance().addItem(_socket_id, shared_from_this()); } void SrtTransport::unregisterSelf() { - SrtTransportManager::Instance().removeItem(std::to_string(_socket_id)); + SrtTransportManager::Instance().removeItem(_socket_id); } void SrtTransport::onShutdown(const SockException &ex) { @@ -739,7 +739,7 @@ void SrtTransport::onShutdown(const SockException &ex) { } } -size_t SrtTransport::getPayloadSize() { +size_t SrtTransport::getPayloadSize() const { size_t ret = (_mtu - 28 - 16) / 188 * 188; return ret; } @@ -792,15 +792,13 @@ SrtTransportManager &SrtTransportManager::Instance() { return s_instance; } -void SrtTransportManager::addItem(const std::string &key, const SrtTransport::Ptr &ptr) { +void SrtTransportManager::addItem(const uint32_t key, const SrtTransport::Ptr &ptr) { std::lock_guard lck(_mtx); _map[key] = ptr; } -SrtTransport::Ptr SrtTransportManager::getItem(const std::string &key) { - if (key.empty()) { - return nullptr; - } +SrtTransport::Ptr SrtTransportManager::getItem(const uint32_t key) { + assert(key > 0); std::lock_guard lck(_mtx); auto it = _map.find(key); if (it == _map.end()) { @@ -809,25 +807,23 @@ SrtTransport::Ptr SrtTransportManager::getItem(const std::string &key) { return it->second.lock(); } -void SrtTransportManager::removeItem(const std::string &key) { +void SrtTransportManager::removeItem(const uint32_t key) { std::lock_guard lck(_mtx); _map.erase(key); } -void SrtTransportManager::addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr) { +void SrtTransportManager::addHandshakeItem(const uint32_t key, const SrtTransport::Ptr &ptr) { std::lock_guard lck(_handshake_mtx); _handshake_map[key] = ptr; } -void SrtTransportManager::removeHandshakeItem(const std::string &key) { +void SrtTransportManager::removeHandshakeItem(const uint32_t key) { std::lock_guard lck(_handshake_mtx); _handshake_map.erase(key); } -SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const std::string &key) { - if (key.empty()) { - return nullptr; - } +SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const uint32_t key) { + assert(key > 0); std::lock_guard lck(_handshake_mtx); auto it = _handshake_map.find(key); if (it == _handshake_map.end()) { diff --git a/srt/SrtTransport.hpp b/srt/SrtTransport.hpp index fe3dfe69..56b75b4c 100644 --- a/srt/SrtTransport.hpp +++ b/srt/SrtTransport.hpp @@ -45,7 +45,7 @@ public: virtual void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr); virtual void onSendTSData(const Buffer::Ptr &buffer, bool flush); - std::string getIdentifier(); + std::string getIdentifier() const; void unregisterSelf(); void unregisterSelfHandshake(); @@ -89,7 +89,7 @@ private: void sendShutDown(); void sendMsgDropReq(uint32_t first, uint32_t last); - size_t getPayloadSize(); + size_t getPayloadSize() const; void createTimerForCheckAlive(); @@ -164,23 +164,23 @@ private: class SrtTransportManager { public: static SrtTransportManager &Instance(); - SrtTransport::Ptr getItem(const std::string &key); - void addItem(const std::string &key, const SrtTransport::Ptr &ptr); - void removeItem(const std::string &key); + SrtTransport::Ptr getItem(const uint32_t key); + void addItem(const uint32_t key, const SrtTransport::Ptr &ptr); + void removeItem(const uint32_t key); - void addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr); - void removeHandshakeItem(const std::string &key); - SrtTransport::Ptr getHandshakeItem(const std::string &key); + void addHandshakeItem(const uint32_t key, const SrtTransport::Ptr &ptr); + void removeHandshakeItem(const uint32_t key); + SrtTransport::Ptr getHandshakeItem(const uint32_t key); private: SrtTransportManager() = default; private: std::mutex _mtx; - std::unordered_map> _map; + std::unordered_map> _map; std::mutex _handshake_mtx; - std::unordered_map> _handshake_map; + std::unordered_map> _handshake_map; }; } // namespace SRT diff --git a/srt/SrtTransportImp.cpp b/srt/SrtTransportImp.cpp index 087bb4ae..8f818655 100644 --- a/srt/SrtTransportImp.cpp +++ b/srt/SrtTransportImp.cpp @@ -24,6 +24,32 @@ SrtTransportImp::~SrtTransportImp() { } } + +SrtTransport::Ptr querySrtTransport(uint8_t *data, size_t size, const EventPoller::Ptr& poller) { + if (DataPacket::isDataPacket(data, size)) { + uint32_t socket_id = DataPacket::getSocketID(data, size); + return SrtTransportManager::Instance().getItem(socket_id); + } + + if (HandshakePacket::isHandshakePacket(data, size)) { + auto type = HandshakePacket::getHandshakeType(data, size); + if (type == HandshakePacket::HS_TYPE_INDUCTION) { + // 握手第一阶段 + return poller ? std::make_shared(poller) : nullptr; + } + + if (type == HandshakePacket::HS_TYPE_CONCLUSION) { + // 握手第二阶段 + uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); + return SrtTransportManager::Instance().getHandshakeItem(sync_cookie); + } + } + + uint32_t socket_id = ControlPacket::getSocketID(data, size); + return SrtTransportManager::Instance().getItem(socket_id); +} + + void SrtTransportImp::onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) { SrtTransport::onHandShakeFinished(streamid,addr); // TODO parse stream id like this zlmediakit.com/live/test?token=1213444&type=push