From 704421b728f417c189eea3513af8c0fd38a3e0df Mon Sep 17 00:00:00 2001 From: ziyue <1213642868@qq.com> Date: Fri, 26 Mar 2021 11:07:03 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/CMakeLists.txt | 1 - webrtc/dtls_transport.cc | 75 - webrtc/dtls_transport.h | 58 - webrtc/ice_server.cc | 661 ++++++--- webrtc/ice_server.h | 138 +- webrtc/logger.h | 2 +- webrtc/rtc_dtls_transport.cc | 2539 ++++++++++++++++++---------------- webrtc/rtc_dtls_transport.h | 341 +++-- webrtc/srtp_session.cc | 487 +++---- webrtc/srtp_session.h | 89 +- webrtc/stun_packet.cc | 161 ++- webrtc/utils.cc | 139 -- webrtc/utils.h | 202 --- webrtc/webrtc_transport.cc | 133 +- webrtc/webrtc_transport.h | 55 +- www/webrtc/index.html | 2 +- 16 files changed, 2618 insertions(+), 2465 deletions(-) delete mode 100644 webrtc/dtls_transport.cc delete mode 100644 webrtc/dtls_transport.h delete mode 100644 webrtc/utils.cc diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index c0b3a009..6e0af5e1 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -49,4 +49,3 @@ else() endif() target_link_libraries(MediaServer jsoncpp ${LINK_LIB_LIST}) -message(${LINK_LIB_LIST}) diff --git a/webrtc/dtls_transport.cc b/webrtc/dtls_transport.cc deleted file mode 100644 index 69e1f402..00000000 --- a/webrtc/dtls_transport.cc +++ /dev/null @@ -1,75 +0,0 @@ -// -// Created by xueyuegui on 19-12-7. -// - -#include "dtls_transport.h" - -#include - -DtlsTransport::DtlsTransport(bool is_server) : is_server_(is_server) { - dtls_transport_.reset(new RTC::DtlsTransport(this)); -} - -DtlsTransport::~DtlsTransport() {} - -void DtlsTransport::Start() { - if (is_server_) { - dtls_transport_->Run(RTC::DtlsTransport::Role::SERVER); - } else { - dtls_transport_->Run(RTC::DtlsTransport::Role::CLIENT); - } -} - -void DtlsTransport::Close() {} - -void DtlsTransport::OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) {} - -void DtlsTransport::OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, - RTC::CryptoSuite srtp_crypto_suite, - uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, - uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, - std::string &remoteCert) { - std::string client_key; - std::string server_key; - server_key.assign((char *) srtpLocalKey, srtpLocalKeyLen); - client_key.assign((char *) srtpRemoteKey, srtpRemoteKeyLen); - if (is_server_) { - // If we are server, we swap the keys - client_key.swap(server_key); - } - if (handshake_completed_callback_) { - handshake_completed_callback_(client_key, server_key, srtp_crypto_suite); - } -} - -void DtlsTransport::OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) { - if (handshake_failed_callback_) { - handshake_failed_callback_(); - } -} - -void DtlsTransport::OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) {} - -void DtlsTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, - const uint8_t *data, size_t len) { - if (output_callback_) { - output_callback_((char *) data, len); - } -} - -void DtlsTransport::OutputData(char *buf, size_t len) { - if (output_callback_) { - output_callback_(buf, len); - } -} - -void DtlsTransport::OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, - const uint8_t *data, size_t len) {} - -bool DtlsTransport::IsDtlsPacket(const char *buf, size_t len) { - return RTC::DtlsTransport::IsDtls((uint8_t *) buf, len); -} - -void DtlsTransport::InputData(char *buf, size_t len) { - dtls_transport_->ProcessDtlsData((uint8_t *) buf, len); -} diff --git a/webrtc/dtls_transport.h b/webrtc/dtls_transport.h deleted file mode 100644 index 41f4c65c..00000000 --- a/webrtc/dtls_transport.h +++ /dev/null @@ -1,58 +0,0 @@ -// -// Created by xueyuegui on 19-12-7. -// - -#ifndef MYWEBRTC_MYDTLSTRANSPORT_H -#define MYWEBRTC_MYDTLSTRANSPORT_H - -#include -#include - -#include "rtc_dtls_transport.h" - -class DtlsTransport : RTC::DtlsTransport::Listener { -public: - typedef std::shared_ptr Ptr; - - DtlsTransport(bool bServer); - ~DtlsTransport(); - - void Start(); - void Close(); - void InputData(char *buf, size_t len); - void OutputData(char *buf, size_t len); - static bool IsDtlsPacket(const char *buf, size_t len); - std::string GetMyFingerprint() { - auto finger_prints = dtls_transport_->GetLocalFingerprints(); - for (size_t i = 0; i < finger_prints.size(); i++) { - if (finger_prints[i].algorithm == RTC::DtlsTransport::FingerprintAlgorithm::SHA256) { - return finger_prints[i].value; - } - } - return ""; - }; - - void SetHandshakeCompletedCB(std::function cb) { - handshake_completed_callback_ = std::move(cb); - } - void SetHandshakeFailedCB(std::function cb) { handshake_failed_callback_ = std::move(cb); } - void SetOutPutCB(std::function cb) { output_callback_ = std::move(cb); } - - /* Pure virtual methods inherited from RTC::DtlsTransport::Listener. */ -public: - void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override; - void OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, RTC::CryptoSuite srtpCryptoSuite, uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) override; - void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override; - void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override; - void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data,size_t len) override; - void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; - -private: - bool is_server_ = false; - std::function handshake_failed_callback_; - std::shared_ptr dtls_transport_; - std::function output_callback_; - std::function handshake_completed_callback_; -}; - -#endif// MYWEBRTC_MYDTLSTRANSPORT_H diff --git a/webrtc/ice_server.cc b/webrtc/ice_server.cc index cf7fb24b..7dc64ec5 100644 --- a/webrtc/ice_server.cc +++ b/webrtc/ice_server.cc @@ -1,201 +1,512 @@ +#define MS_CLASS "RTC::IceServer" +// #define MS_LOG_DEV_LEVEL 3 + +#include #include "ice_server.h" -#include +namespace RTC +{ + /* Static. */ -static constexpr size_t StunSerializeBufferSize{65536}; -static uint8_t StunSerializeBuffer[StunSerializeBufferSize]; + static constexpr size_t StunSerializeBufferSize{ 65536 }; + static uint8_t StunSerializeBuffer[StunSerializeBufferSize]; -IceServer::IceServer() {} + /* Instance methods. */ -IceServer::~IceServer() {} + IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) + : listener(listener), usernameFragment(usernameFragment), password(password) + { + MS_TRACE(); + } -IceServer::IceServer(const std::string &username_fragment, const std::string &password) - : username_fragment_(username_fragment), password_(password) {} + void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) + { + MS_TRACE(); -void IceServer::ProcessStunPacket(RTC::StunPacket *packet, sockaddr_in *remote_address) { - // Must be a Binding method. - if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) { - ELOG_WARN("unknown method %#.3x in STUN Request => 400", - static_cast(packet->GetMethod())); - ELOG_WARN("unknown method %#.3x in STUN Request => 400", - static_cast(packet->GetMethod())); - // Reply 400. - RTC::StunPacket *response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - } else { - ELOG_WARN("ignoring STUN Indication or Response with unknown method %#.3x", - static_cast(packet->GetMethod())); - } - return; - } + // Must be a Binding method. + if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) + { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) + { + MS_WARN_TAG( + ice, + "unknown method %#.3x in STUN Request => 400", + static_cast(packet->GetMethod())); - // Must use FINGERPRINT (optional for ICE STUN indications). - if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) { - ELOG_WARN("STUN Binding Request without FINGERPRINT => 400"); - // Reply 400. - RTC::StunPacket *response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - } else { - ELOG_WARN("ignoring STUN Binding Response without FINGERPRINT"); - } - return; - } + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - switch (packet->GetClass()) { - case RTC::StunPacket::Class::REQUEST: { - // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. - if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || - packet->GetUsername().empty()) { - ELOG_WARN("mising required attributes in STUN Binding Request => 400"); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - // Reply 400. - RTC::StunPacket *response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - return; - } + delete response; + } + else + { + MS_WARN_TAG( + ice, + "ignoring STUN Indication or Response with unknown method %#.3x", + static_cast(packet->GetMethod())); + } - // Check authentication. - switch (packet->CheckAuthentication(this->username_fragment_, this->password_)) { - case RTC::StunPacket::Authentication::OK: { - if (!this->old_password_.empty()) { - ELOG_DEBUG("kNew ICE credentials applied"); - this->old_username_fragment_.clear(); - this->old_password_.clear(); - } - break; - } + return; + } - case RTC::StunPacket::Authentication::UNAUTHORIZED: { - // We may have changed our username_fragment_ and password_, so check - // the old ones. - // clang-format off - if (!this->old_username_fragment_.empty() && - !this->old_password_.empty() && - packet->CheckAuthentication(this->old_username_fragment_, this->old_password_) == - RTC::StunPacket::Authentication::OK) { - ELOG_DEBUG("using old ICE credentials"); - break; - } - ELOG_WARN("wrong authentication in STUN Binding Request => 401"); - // Reply 401. - RTC::StunPacket *response = packet->CreateErrorResponse(401); - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - return; - } + // Must use FINGERPRINT (optional for ICE STUN indications). + if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) + { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) + { + MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); - case RTC::StunPacket::Authentication::BAD_REQUEST: { - ELOG_WARN("cannot check authentication in STUN Binding Request => 400"); - // Reply 400. - RTC::StunPacket *response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - return; - } - } + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); + + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); + + delete response; + } + else + { + MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); + } + + return; + } + + switch (packet->GetClass()) + { + case RTC::StunPacket::Class::REQUEST: + { + // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. + if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) + { + MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); + + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); + + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); + + delete response; + + return; + } + + // Check authentication. + switch (packet->CheckAuthentication(this->usernameFragment, this->password)) + { + case RTC::StunPacket::Authentication::OK: + { + if (!this->oldPassword.empty()) + { + MS_DEBUG_TAG(ice, "new ICE credentials applied"); + + this->oldUsernameFragment.clear(); + this->oldPassword.clear(); + } + + break; + } + + case RTC::StunPacket::Authentication::UNAUTHORIZED: + { + // We may have changed our usernameFragment and password, so check + // the old ones. + // clang-format off + if ( + !this->oldUsernameFragment.empty() && + !this->oldPassword.empty() && + packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK + ) + // clang-format on + { + MS_DEBUG_TAG(ice, "using old ICE credentials"); + + break; + } + + MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); + + // Reply 401. + RTC::StunPacket* response = packet->CreateErrorResponse(401); + + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); + + delete response; + + return; + } + + case RTC::StunPacket::Authentication::BAD_REQUEST: + { + MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); + + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); + + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); + + delete response; + + return; + } + } #if 0 - // NOTE: Should be rejected with 487, but this makes Chrome happy: - // https://bugs.chromium.org/p/webrtc/issues/detail?id=7478 - // The remote peer must be ICE controlling. - if (packet->GetIceControlled()) { - MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); - // Reply 487 (Role Conflict). - RTC::StunPacket *response = packet->CreateErrorResponse(487); - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - return; - } + // The remote peer must be ICE controlling. + if (packet->GetIceControlled()) + { + MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); + + // Reply 487 (Role Conflict). + RTC::StunPacket* response = packet->CreateErrorResponse(487); + + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); + + delete response; + + return; + } + #endif - ELOG_DEBUG("processing STUN Binding Request [Priority:%d, UseCandidate:%s]", - static_cast(packet->GetPriority()), - (packet->HasUseCandidate() ? "true" : "false")); - // Create a success response. - RTC::StunPacket *response = packet->CreateSuccessResponse(); - // Add XOR-MAPPED-ADDRESS. - // response->SetXorMappedAddress(tuple->GetRemoteAddress()); - response->SetXorMappedAddress((struct sockaddr *) remote_address); - // Authenticate the response. - if (this->old_password_.empty()) { - response->Authenticate(this->password_); - } else { - response->Authenticate(this->old_password_); - } + MS_DEBUG_DEV( + "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", + static_cast(packet->GetPriority()), + packet->HasUseCandidate() ? "true" : "false"); - // Send back. - response->Serialize(StunSerializeBuffer); - if (send_callback_) { - send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); - } - delete response; - // Handle the tuple. - HandleTuple(remote_address, packet->HasUseCandidate()); - break; - } + // Create a success response. + RTC::StunPacket* response = packet->CreateSuccessResponse(); - case RTC::StunPacket::Class::INDICATION: { - ELOG_DEBUG("STUN Binding Indication processed"); - break; - } + // Add XOR-MAPPED-ADDRESS. + response->SetXorMappedAddress(tuple); - case RTC::StunPacket::Class::SUCCESS_RESPONSE: { - ELOG_DEBUG("STUN Binding Success Response processed"); - break; - } + // Authenticate the response. + if (this->oldPassword.empty()) + response->Authenticate(this->password); + else + response->Authenticate(this->oldPassword); - case RTC::StunPacket::Class::ERROR_RESPONSE: { - ELOG_DEBUG("STUN Binding Error Response processed"); - break; - } - } -} -void IceServer::HandleTuple(sockaddr_in *remote_address, bool has_use_candidate) { - remote_address_ = *remote_address; - if (has_use_candidate) { - this->state = IceState::kCompleted; - } - if (ice_server_completed_callback_) { - ice_server_completed_callback_(); - ice_server_completed_callback_ = nullptr; - } -} + // Send back. + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); -const std::string &IceServer::GetUsernameFragment() const { return this->username_fragment_; } + delete response; -const std::string &IceServer::GetPassword() const { return this->password_; } + // Handle the tuple. + HandleTuple(tuple, packet->HasUseCandidate()); -inline void IceServer::SetUsernameFragment(const std::string &username_fragment) { - this->old_username_fragment_ = this->username_fragment_; - this->username_fragment_ = username_fragment; -} + break; + } -inline void IceServer::SetPassword(const std::string &password) { - this->old_password_ = this->password_; - this->password_ = password; -} + case RTC::StunPacket::Class::INDICATION: + { + MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); -inline IceServer::IceState IceServer::GetState() const { return this->state; } \ No newline at end of file + break; + } + + case RTC::StunPacket::Class::SUCCESS_RESPONSE: + { + MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); + + break; + } + + case RTC::StunPacket::Class::ERROR_RESPONSE: + { + MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); + + break; + } + } + } + + bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const + { + MS_TRACE(); + + return HasTuple(tuple) != nullptr; + } + + void IceServer::RemoveTuple(RTC::TransportTuple* tuple) + { + MS_TRACE(); + + RTC::TransportTuple* removedTuple{ nullptr }; + + // Find the removed tuple. + auto it = this->tuples.begin(); + + for (; it != this->tuples.end(); ++it) + { + RTC::TransportTuple* storedTuple = std::addressof(*it); + + if (memcmp(storedTuple, tuple, sizeof (RTC::TransportTuple)) == 0) + { + removedTuple = storedTuple; + + break; + } + } + + // If not found, ignore. + if (!removedTuple) + return; + + // Remove from the list of tuples. + this->tuples.erase(it); + + // If this is not the selected tuple, stop here. + if (removedTuple != this->selectedTuple) + return; + + // Otherwise this was the selected tuple. + this->selectedTuple = nullptr; + + // Mark the first tuple as selected tuple (if any). + if (this->tuples.begin() != this->tuples.end()) + { + SetSelectedTuple(std::addressof(*this->tuples.begin())); + } + // Or just emit 'disconnected'. + else + { + // Update state. + this->state = IceState::DISCONNECTED; + // Notify the listener. + this->listener->OnIceServerDisconnected(this); + } + } + + void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) + { + MS_TRACE(); + + MS_ASSERT( + this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); + + auto* storedTuple = HasTuple(tuple); + + MS_ASSERT( + storedTuple, + "cannot force the selected tuple if the given tuple was not already a valid tuple"); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + } + + void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) + { + MS_TRACE(); + + switch (this->state) + { + case IceState::NEW: + { + // There should be no tuples. + MS_ASSERT( + this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); + + // There shouldn't be a selected tuple. + MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); + + if (!hasUseCandidate) + { + MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::CONNECTED; + // Notify the listener. + this->listener->OnIceServerConnected(this); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::DISCONNECTED: + { + // There should be no tuples. + MS_ASSERT( + this->tuples.empty(), + "state is 'disconnected' but there are %zu tuples", + this->tuples.size()); + + // There shouldn't be a selected tuple. + MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); + + if (!hasUseCandidate) + { + MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::CONNECTED; + // Notify the listener. + this->listener->OnIceServerConnected(this); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::CONNECTED: + { + // There should be some tuples. + MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); + + // There should be a selected tuple. + MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); + + if (!hasUseCandidate) + { + // If a new tuple store it. + if (!HasTuple(tuple)) + AddTuple(tuple); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); + + auto* storedTuple = HasTuple(tuple); + + // If a new tuple store it. + if (!storedTuple) + storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::COMPLETED: + { + // There should be some tuples. + MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); + + // There should be a selected tuple. + MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); + + if (!hasUseCandidate) + { + // If a new tuple store it. + if (!HasTuple(tuple)) + AddTuple(tuple); + } + else + { + auto* storedTuple = HasTuple(tuple); + + // If a new tuple store it. + if (!storedTuple) + storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + } + + break; + } + } + } + + inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) + { + MS_TRACE(); + + // Add the new tuple at the beginning of the list. + this->tuples.push_front(*tuple); + + auto* storedTuple = std::addressof(*this->tuples.begin()); + + // Return the address of the inserted tuple. + return storedTuple; + } + + inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const + { + MS_TRACE(); + + // If there is no selected tuple yet then we know that the tuples list + // is empty. + if (!this->selectedTuple) + return nullptr; + + // Check the current selected tuple. + if (memcmp(selectedTuple, tuple, sizeof (RTC::TransportTuple)) == 0) + return this->selectedTuple; + + // Otherwise check other stored tuples. + for (const auto& it : this->tuples) + { + auto* storedTuple = const_cast(std::addressof(it)); + + if (memcmp(storedTuple, tuple, sizeof (RTC::TransportTuple)) == 0) + return storedTuple; + } + + return nullptr; + } + + inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) + { + MS_TRACE(); + + // If already the selected tuple do nothing. + if (storedTuple == this->selectedTuple) + return; + + this->selectedTuple = storedTuple; + + // Notify the listener. + this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); + } +} // namespace RTC diff --git a/webrtc/ice_server.h b/webrtc/ice_server.h index d33f26d2..437d9d9f 100644 --- a/webrtc/ice_server.h +++ b/webrtc/ice_server.h @@ -1,40 +1,112 @@ -#pragma once +#ifndef MS_RTC_ICE_SERVER_HPP +#define MS_RTC_ICE_SERVER_HPP +#include "stun_packet.h" +#include "logger.h" +#include +#include #include #include -#include "logger.h" -#include "stun_packet.h" +namespace RTC +{ + using TransportTuple = struct sockaddr; + class IceServer + { + public: + enum class IceState + { + NEW = 1, + CONNECTED, + COMPLETED, + DISCONNECTED + }; -typedef std::function UdpSendCallback; + public: + class Listener + { + public: + virtual ~Listener() = default; -class IceServer { -public: - enum class IceState { kNew = 1, kConnect, kCompleted, kDisconnected }; - typedef std::shared_ptr Ptr; - IceServer(); - IceServer(const std::string &username_fragment, const std::string &password); - const std::string &GetUsernameFragment() const; - const std::string &GetPassword() const; - void SetUsernameFragment(const std::string &username_fragment); - void SetPassword(const std::string &password); - IceState GetState() const; - void ProcessStunPacket(RTC::StunPacket *packet, struct sockaddr_in *remote_address); - void HandleTuple(struct sockaddr_in *remote_address, bool has_use_candidate); - ~IceServer(); - void SetSendCB(UdpSendCallback send_cb) { send_callback_ = send_cb; } - void SetIceServerCompletedCB(std::function cb) { ice_server_completed_callback_ = cb; }; - struct sockaddr_in *GetSelectAddr() { - return &remote_address_; - } + public: + /** + * These callbacks are guaranteed to be called before ProcessStunPacket() + * returns, so the given pointers are still usable. + */ + virtual void OnIceServerSendStunPacket( + const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; + virtual void OnIceServerSelectedTuple( + const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; + virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; + virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; + virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; + }; -private: - UdpSendCallback send_callback_; - std::function ice_server_completed_callback_; - std::string username_fragment_; - std::string password_; - std::string old_username_fragment_; - std::string old_password_; - IceState state{IceState::kNew}; - struct sockaddr_in remote_address_; -}; + public: + IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); + + public: + void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); + const std::string& GetUsernameFragment() const + { + return this->usernameFragment; + } + const std::string& GetPassword() const + { + return this->password; + } + IceState GetState() const + { + return this->state; + } + RTC::TransportTuple* GetSelectedTuple() const + { + return this->selectedTuple; + } + void SetUsernameFragment(const std::string& usernameFragment) + { + this->oldUsernameFragment = this->usernameFragment; + this->usernameFragment = usernameFragment; + } + void SetPassword(const std::string& password) + { + this->oldPassword = this->password; + this->password = password; + } + bool IsValidTuple(const RTC::TransportTuple* tuple) const; + void RemoveTuple(RTC::TransportTuple* tuple); + // This should be just called in 'connected' or completed' state + // and the given tuple must be an already valid tuple. + void ForceSelectedTuple(const RTC::TransportTuple* tuple); + + private: + void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); + /** + * Store the given tuple and return its stored address. + */ + RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); + /** + * If the given tuple exists return its stored address, nullptr otherwise. + */ + RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; + /** + * Set the given tuple as the selected tuple. + * NOTE: The given tuple MUST be already stored within the list. + */ + void SetSelectedTuple(RTC::TransportTuple* storedTuple); + + private: + // Passed by argument. + Listener* listener{ nullptr }; + // Others. + std::string usernameFragment; + std::string password; + std::string oldUsernameFragment; + std::string oldPassword; + IceState state{ IceState::NEW }; + std::list tuples; + RTC::TransportTuple* selectedTuple{ nullptr }; + }; +} // namespace RTC + +#endif diff --git a/webrtc/logger.h b/webrtc/logger.h index a165536b..3da3984f 100644 --- a/webrtc/logger.h +++ b/webrtc/logger.h @@ -12,7 +12,7 @@ #define MS_DEBUG_2TAGS(tag1, tag2,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) #define MS_WARN_2TAGS(tag1, tag2,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__) #define MS_DEBUG_TAG(tag,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) -#define MS_ASSERT(con, log) assert(con) +#define MS_ASSERT(con, fmt, ...) do{if(!(con)) { printf("assert failed:%s" fmt "\n", #con, ##__VA_ARGS__);} assert(con); } while(false); #define MS_ABORT(fmt, ...) do{ printf("abort:" fmt "\n", ##__VA_ARGS__); abort(); } while(false); #define MS_WARN_TAG(tag,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__) #define MS_DEBUG_DEV(fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) \ No newline at end of file diff --git a/webrtc/rtc_dtls_transport.cc b/webrtc/rtc_dtls_transport.cc index 1087af4a..e66f6469 100644 --- a/webrtc/rtc_dtls_transport.cc +++ b/webrtc/rtc_dtls_transport.cc @@ -2,62 +2,56 @@ // #define MS_LOG_DEV_LEVEL 3 #include "rtc_dtls_transport.h" - +#include "logger.h" #include #include #include #include #include +#include // std::sprintf(), std::fopen() +#include // std::memcpy(), std::strcmp() -#include // std::sprintf(), std::fopen() -#include // std::memcpy(), std::strcmp() - -#include "logger.h" - -typedef struct { - long tv_sec; - long tv_usec; -} uv_timeval_t; - -#define LOG_OPENSSL_ERROR(desc) \ - do { \ - if (ERR_peek_error() == 0) \ - MS_ERROR("OpenSSL error [desc:'%s']", desc); \ - else { \ - int64_t err; \ - while ((err = ERR_get_error()) != 0) { \ - MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ - } \ - ERR_clear_error(); \ - } \ - } while (false) +#define LOG_OPENSSL_ERROR(desc) \ + do \ + { \ + if (ERR_peek_error() == 0) \ + MS_ERROR("OpenSSL error [desc:'%s']", desc); \ + else \ + { \ + int64_t err; \ + while ((err = ERR_get_error()) != 0) \ + { \ + MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ + } \ + ERR_clear_error(); \ + } \ + } while (false) /* Static methods for OpenSSL callbacks. */ -inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) { - MS_TRACE(); +inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) +{ + MS_TRACE(); - // Always valid since DTLS certificates are self-signed. - return 1; + // Always valid since DTLS certificates are self-signed. + return 1; } -inline static void onSslInfo(const SSL* ssl, int where, int ret) { - static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); +inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) +{ + if (timerUs == 0) + return 100000; + else if (timerUs >= 4000000) + return 4000000; + else + return 2 * timerUs; } -inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) { - if (timerUs == 0) - return 100000; - else if (timerUs >= 4000000) - return 4000000; - else - return 2 * timerUs; -} +namespace RTC +{ + /* Static. */ -namespace RTC { -/* Static. */ - -// clang-format off + // clang-format off static constexpr int DtlsMtu{ 1350 }; static constexpr int SslReadBufferSize{ 65536 }; // AES-HMAC: http://tools.ietf.org/html/rfc3711 @@ -71,15 +65,15 @@ namespace RTC { static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; -// clang-format on + // clang-format on -/* Class variables. */ + /* Class variables. */ -X509* DtlsTransport::certificate{nullptr}; -EVP_PKEY* DtlsTransport::privateKey{nullptr}; -SSL_CTX* DtlsTransport::sslCtx{nullptr}; -uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize]; -// clang-format off + X509* DtlsTransport::certificate{ nullptr }; + EVP_PKEY* DtlsTransport::privateKey{ nullptr }; + SSL_CTX* DtlsTransport::sslCtx{ nullptr }; + uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize]; + // clang-format off std::map DtlsTransport::string2FingerprintAlgorithm = { { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, @@ -105,1219 +99,1376 @@ uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize]; std::vector DtlsTransport::localFingerprints; std::vector DtlsTransport::srtpCryptoSuites = { - { RTC::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, - { RTC::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, - { RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, - { RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } + { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, + { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } }; -// clang-format on + // clang-format on -/* Class methods. */ + /* Class methods. */ -void DtlsTransport::ClassInit() { - MS_TRACE(); + void DtlsTransport::ClassInit() + { + MS_TRACE(); + + // Generate a X509 certificate and private key (unless PEM files are provided). + if (true /* + Settings::configuration.dtlsCertificateFile.empty() || + Settings::configuration.dtlsPrivateKeyFile.empty()*/) + { + GenerateCertificateAndPrivateKey(); + } + else + { + ReadCertificateAndPrivateKeyFromFiles(); + } + + // Create a global SSL_CTX. + CreateSslCtx(); + + // Generate certificate fingerprints. + GenerateFingerprints(); + } + + void DtlsTransport::ClassDestroy() + { + MS_TRACE(); + + if (DtlsTransport::privateKey) + EVP_PKEY_free(DtlsTransport::privateKey); + if (DtlsTransport::certificate) + X509_free(DtlsTransport::certificate); + if (DtlsTransport::sslCtx) + SSL_CTX_free(DtlsTransport::sslCtx); + } + + void DtlsTransport::GenerateCertificateAndPrivateKey() + { + MS_TRACE(); + + int ret{ 0 }; + EC_KEY* ecKey{ nullptr }; + X509_NAME* certName{ nullptr }; + std::string subject = + std::string("mediasoup") + std::to_string(rand() % 999999 + 100000); + + // Create key with curve. + ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + + if (!ecKey) + { + LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); + + goto error; + } + + EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); + + // NOTE: This can take some time. + ret = EC_KEY_generate_key(ecKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); + + goto error; + } + + // Create a private key object. + DtlsTransport::privateKey = EVP_PKEY_new(); + + if (!DtlsTransport::privateKey) + { + LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); + + goto error; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + ret = EVP_PKEY_assign_EC_KEY(DtlsTransport::privateKey, ecKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); + + goto error; + } + + // The EC key now belongs to the private key, so don't clean it up separately. + ecKey = nullptr; + + // Create the X509 certificate. + DtlsTransport::certificate = X509_new(); + + if (!DtlsTransport::certificate) + { + LOG_OPENSSL_ERROR("X509_new() failed"); + + goto error; + } + + // Set version 3 (note that 0 means version 1). + X509_set_version(DtlsTransport::certificate, 2); + + // Set serial number (avoid default 0). + ASN1_INTEGER_set( + X509_get_serialNumber(DtlsTransport::certificate), + static_cast(rand() % 999999 + 100000)); + + // Set valid period. + X509_gmtime_adj(X509_get_notBefore(DtlsTransport::certificate), -315360000); // -10 years. + X509_gmtime_adj(X509_get_notAfter(DtlsTransport::certificate), 315360000); // 10 years. + + // Set the public key for the certificate using the key. + ret = X509_set_pubkey(DtlsTransport::certificate, DtlsTransport::privateKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); + + goto error; + } + + // Set certificate fields. + certName = X509_get_subject_name(DtlsTransport::certificate); + + if (!certName) + { + LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); + + goto error; + } + + X509_NAME_add_entry_by_txt( + certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt( + certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + + // It is self-signed so set the issuer name to be the same as the subject. + ret = X509_set_issuer_name(DtlsTransport::certificate, certName); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); + + goto error; + } + + // Sign the certificate with its own private key. + ret = X509_sign(DtlsTransport::certificate, DtlsTransport::privateKey, EVP_sha1()); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_sign() failed"); + + goto error; + } + + return; + + error: + + if (ecKey) + EC_KEY_free(ecKey); + + if (DtlsTransport::privateKey) + EVP_PKEY_free(DtlsTransport::privateKey); // NOTE: This also frees the EC key. + + if (DtlsTransport::certificate) + X509_free(DtlsTransport::certificate); + + MS_THROW_ERROR("DTLS certificate and private key generation failed"); + } + + void DtlsTransport::ReadCertificateAndPrivateKeyFromFiles() + { #if 0 - // Generate a X509 certificate and private key (unless PEM files are provided). - if (Settings::configuration.dtlsCertificateFile.empty() || - Settings::configuration.dtlsPrivateKeyFile.empty()) { - GenerateCertificateAndPrivateKey(); - } else { - ReadCertificateAndPrivateKeyFromFiles(); - } -#else - GenerateCertificateAndPrivateKey(); + MS_TRACE(); + + FILE* file{ nullptr }; + + file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); + + if (!file) + { + MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); + + goto error; + } + + DtlsTransport::certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); + + if (!DtlsTransport::certificate) + { + LOG_OPENSSL_ERROR("PEM_read_X509() failed"); + + goto error; + } + + fclose(file); + + file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); + + if (!file) + { + MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); + + goto error; + } + + DtlsTransport::privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); + + if (!DtlsTransport::privateKey) + { + LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); + + goto error; + } + + fclose(file); + + return; + + error: + + MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); #endif + } - // Create a global SSL_CTX. - CreateSslCtx(); + void DtlsTransport::CreateSslCtx() + { + MS_TRACE(); - // Generate certificate fingerprints. - GenerateFingerprints(); -} + std::string dtlsSrtpCryptoSuites; + int ret; -void DtlsTransport::ClassDestroy() { - MS_TRACE(); + /* Set the global DTLS context. */ - if (DtlsTransport::privateKey) EVP_PKEY_free(DtlsTransport::privateKey); - if (DtlsTransport::certificate) X509_free(DtlsTransport::certificate); - if (DtlsTransport::sslCtx) SSL_CTX_free(DtlsTransport::sslCtx); -} + // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). + DtlsTransport::sslCtx = SSL_CTX_new(DTLS_method()); -void DtlsTransport::GenerateCertificateAndPrivateKey() { - MS_TRACE(); + if (!DtlsTransport::sslCtx) + { + LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); - int ret{0}; - EC_KEY* ecKey{nullptr}; - X509_NAME* certName{nullptr}; - std::string subject = std::string("mediasoup") + std::to_string(rand() % 999999 + 100000); - // std::string("mediasoup") + std::to_string(Utils::Crypto::GetRandomUInt(100000, 999999)); + goto error; + } - // Create key with curve. - ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + ret = SSL_CTX_use_certificate(DtlsTransport::sslCtx, DtlsTransport::certificate); - if (!ecKey) { - LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); - goto error; - } + goto error; + } - EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); + ret = SSL_CTX_use_PrivateKey(DtlsTransport::sslCtx, DtlsTransport::privateKey); - // NOTE: This can take some time. - ret = EC_KEY_generate_key(ecKey); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); - if (ret == 0) { - LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); + goto error; + } - goto error; - } + ret = SSL_CTX_check_private_key(DtlsTransport::sslCtx); - // Create a private key object. - DtlsTransport::privateKey = EVP_PKEY_new(); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); - if (!DtlsTransport::privateKey) { - LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); + goto error; + } - goto error; - } + // Set options. + SSL_CTX_set_options( + DtlsTransport::sslCtx, + SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | + SSL_OP_NO_QUERY_MTU); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) - ret = EVP_PKEY_assign_EC_KEY(DtlsTransport::privateKey, ecKey); + // Don't use sessions cache. + SSL_CTX_set_session_cache_mode(DtlsTransport::sslCtx, SSL_SESS_CACHE_OFF); - if (ret == 0) { - LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); + // Read always as much into the buffer as possible. + // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL + // versions makes this call required. + SSL_CTX_set_read_ahead(DtlsTransport::sslCtx, 1); - goto error; - } + SSL_CTX_set_verify_depth(DtlsTransport::sslCtx, 4); - // The EC key now belongs to the private key, so don't clean it up separately. - ecKey = nullptr; + // Require certificate from peer. + SSL_CTX_set_verify( + DtlsTransport::sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); - // Create the X509 certificate. - DtlsTransport::certificate = X509_new(); + // Set SSL info callback. + SSL_CTX_set_info_callback(DtlsTransport::sslCtx, [](const SSL* ssl, int where, int ret){ + static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); + }); + // Set ciphers. + ret = SSL_CTX_set_cipher_list( + DtlsTransport::sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); - if (!DtlsTransport::certificate) { - LOG_OPENSSL_ERROR("X509_new() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); - goto error; - } + goto error; + } - // Set version 3 (note that 0 means version 1). - X509_set_version(DtlsTransport::certificate, 2); + // Enable ECDH ciphers. + // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters + // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 + // NOTE: https://bugs.ruby-lang.org/issues/12324 - // Set serial number (avoid default 0). - // ASN1_INTEGER_set(X509_get_serialNumber(DtlsTransport::certificate), - // static_cast(Utils::Crypto::GetRandomUInt(1000000, 9999999))); - ASN1_INTEGER_set(X509_get_serialNumber(DtlsTransport::certificate), - static_cast(rand() % 999999 + 100000)); + // For OpenSSL >= 1.0.2. + SSL_CTX_set_ecdh_auto(DtlsTransport::sslCtx, 1); - // Set valid period. - X509_gmtime_adj(X509_get_notBefore(DtlsTransport::certificate), -315360000); // -10 years. - X509_gmtime_adj(X509_get_notAfter(DtlsTransport::certificate), 315360000); // 10 years. + // Set the "use_srtp" DTLS extension. + for (auto it = DtlsTransport::srtpCryptoSuites.begin(); + it != DtlsTransport::srtpCryptoSuites.end(); + ++it) + { + if (it != DtlsTransport::srtpCryptoSuites.begin()) + dtlsSrtpCryptoSuites += ":"; - // Set the public key for the certificate using the key. - ret = X509_set_pubkey(DtlsTransport::certificate, DtlsTransport::privateKey); + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); + dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; + } - if (ret == 0) { - LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); + MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); - goto error; - } + // NOTE: This function returns 0 on success. + ret = SSL_CTX_set_tlsext_use_srtp(DtlsTransport::sslCtx, dtlsSrtpCryptoSuites.c_str()); - // Set certificate fields. - certName = X509_get_subject_name(DtlsTransport::certificate); + if (ret != 0) + { + MS_ERROR( + "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); + LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); - if (!certName) { - LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); + goto error; + } - goto error; - } + return; - X509_NAME_add_entry_by_txt(certName, "O", MBSTRING_ASC, - reinterpret_cast(subject.c_str()), -1, -1, 0); - X509_NAME_add_entry_by_txt(certName, "CN", MBSTRING_ASC, - reinterpret_cast(subject.c_str()), -1, -1, 0); + error: - // It is self-signed so set the issuer name to be the same as the subject. - ret = X509_set_issuer_name(DtlsTransport::certificate, certName); + if (DtlsTransport::sslCtx) + { + SSL_CTX_free(DtlsTransport::sslCtx); + DtlsTransport::sslCtx = nullptr; + } - if (ret == 0) { - LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); + MS_THROW_ERROR("SSL context creation failed"); + } - goto error; - } + void DtlsTransport::GenerateFingerprints() + { + MS_TRACE(); - // Sign the certificate with its own private key. - ret = X509_sign(DtlsTransport::certificate, DtlsTransport::privateKey, EVP_sha1()); + for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) + { + const std::string& algorithmString = kv.first; + FingerprintAlgorithm algorithm = kv.second; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; - if (ret == 0) { - LOG_OPENSSL_ERROR("X509_sign() failed"); + switch (algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; - goto error; - } + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; - return; + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; -error: + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; - if (ecKey) EC_KEY_free(ecKey); + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; - if (DtlsTransport::privateKey) - EVP_PKEY_free(DtlsTransport::privateKey); // NOTE: This also frees the EC key. + default: + MS_THROW_ERROR("unknown algorithm"); + } - if (DtlsTransport::certificate) X509_free(DtlsTransport::certificate); + ret = X509_digest(DtlsTransport::certificate, hashFunction, binaryFingerprint, &size); - MS_THROW_ERROR("DTLS certificate and private key generation failed"); -} + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); + MS_THROW_ERROR("Fingerprints generation failed"); + } -void DtlsTransport::ReadCertificateAndPrivateKeyFromFiles() { + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; + + MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); + + // Store it in the vector. + DtlsTransport::Fingerprint fingerprint; + + fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); + fingerprint.value = hexFingerprint; + + DtlsTransport::localFingerprints.push_back(fingerprint); + } + } + + /* Instance methods. */ + + DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) + { + MS_TRACE(); + + /* Set SSL. */ + + this->ssl = SSL_new(DtlsTransport::sslCtx); + + if (!this->ssl) + { + LOG_OPENSSL_ERROR("SSL_new() failed"); + + goto error; + } + + // Set this as custom data. + SSL_set_ex_data(this->ssl, 0, static_cast(this)); + + this->sslBioFromNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioFromNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + SSL_free(this->ssl); + + goto error; + } + + this->sslBioToNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioToNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + BIO_free(this->sslBioFromNetwork); + SSL_free(this->ssl); + + goto error; + } + + SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); + + // Set the MTU so that we don't send packets that are too large with no fragmentation. + SSL_set_mtu(this->ssl, DtlsMtu); + DTLS_set_link_mtu(this->ssl, DtlsMtu); + + // Set callback handler for setting DTLS timer interval. + DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); + + return; + + error: + + // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as + // well. + if (this->sslBioFromNetwork) + BIO_free(this->sslBioFromNetwork); + + if (this->sslBioToNetwork) + BIO_free(this->sslBioToNetwork); + + if (this->ssl) + SSL_free(this->ssl); + + // NOTE: If this is not catched by the caller the program will abort, but + // this should never happen. + MS_THROW_ERROR("DtlsTransport instance creation failed"); + } + + DtlsTransport::~DtlsTransport() + { + MS_TRACE(); + + if (IsRunning()) + { + // Send close alert to the peer. + SSL_shutdown(this->ssl); + SendPendingOutgoingDtlsData(); + } + + if (this->ssl) + { + SSL_free(this->ssl); + + this->ssl = nullptr; + this->sslBioFromNetwork = nullptr; + this->sslBioToNetwork = nullptr; + } + + // Close the DTLS timer. + this->timer = nullptr; + } + + void DtlsTransport::Dump() const + { + MS_TRACE(); + + std::string state{ "new" }; + std::string role{ "none " }; + + switch (this->state) + { + case DtlsState::CONNECTING: + state = "connecting"; + break; + case DtlsState::CONNECTED: + state = "connected"; + break; + case DtlsState::FAILED: + state = "failed"; + break; + case DtlsState::CLOSED: + state = "closed"; + break; + default:; + } + + switch (this->localRole) + { + case Role::AUTO: + role = "auto"; + break; + case Role::SERVER: + role = "server"; + break; + case Role::CLIENT: + role = "client"; + break; + default:; + } + + MS_DUMP(""); + MS_DUMP(" state : %s", state.c_str()); + MS_DUMP(" role : %s", role.c_str()); + MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); + MS_DUMP(""); + } + + void DtlsTransport::Run(Role localRole) + { + MS_TRACE(); + + MS_ASSERT( + localRole == Role::CLIENT || localRole == Role::SERVER, + "local DTLS role must be 'client' or 'server'"); + + Role previousLocalRole = this->localRole; + + if (localRole == previousLocalRole) + { + MS_ERROR("same local DTLS role provided, doing nothing"); + + return; + } + + // If the previous local DTLS role was 'client' or 'server' do reset. + if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) + { + MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); + + Reset(); + } + + // Update local role. + this->localRole = localRole; + + // Set state and notify the listener. + this->state = DtlsState::CONNECTING; + this->listener->OnDtlsTransportConnecting(this); + + switch (this->localRole) + { + case Role::CLIENT: + { + MS_DEBUG_TAG(dtls, "running [role:client]"); + + SSL_set_connect_state(this->ssl); + SSL_do_handshake(this->ssl); + SendPendingOutgoingDtlsData(); + SetTimeout(); + + break; + } + + case Role::SERVER: + { + MS_DEBUG_TAG(dtls, "running [role:server]"); + + SSL_set_accept_state(this->ssl); + SSL_do_handshake(this->ssl); + + break; + } + + default: + { + MS_ABORT("invalid local DTLS role"); + } + } + } + + bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) + { + MS_TRACE(); + + MS_ASSERT( + fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); + + this->remoteFingerprint = fingerprint; + + // The remote fingerpring may have been set after DTLS handshake was done, + // so we may need to process it now. + if (this->handshakeDone && this->state != DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); + + return ProcessHandshake(); + } + + return true; + } + + void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) + { + MS_TRACE(); + + int written; + int read; + + if (!IsRunning()) + { + MS_ERROR("cannot process data while not running"); + + return; + } + + // Write the received DTLS data into the sslBioFromNetwork. + written = + BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); + + if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, + "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", + static_cast(written), + len); + } + + // Must call SSL_read() to process received DTLS data. + read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); + + // Send data if it's ready. + SendPendingOutgoingDtlsData(); + + // Check SSL status and return if it is bad/closed. + if (!CheckStatus(read)) + return; + + // Set/update the DTLS timeout. + if (!SetTimeout()) + return; + + // Application data received. Notify to the listener. + if (read > 0) + { + // It is allowed to receive DTLS data even before validating remote fingerprint. + if (!this->handshakeDone) + { + MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); + + return; + } + + // Notify the listener. + this->listener->OnDtlsTransportApplicationDataReceived( + this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); + } + } + + void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) + { + MS_TRACE(); + + // We cannot send data to the peer if its remote fingerprint is not validated. + if (this->state != DtlsState::CONNECTED) + { + MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); + + return; + } + + if (len == 0) + { + MS_WARN_TAG(dtls, "ignoring 0 length data"); + + return; + } + + int written; + + written = SSL_write(this->ssl, static_cast(data), static_cast(len)); + + if (written < 0) + { + LOG_OPENSSL_ERROR("SSL_write() failed"); + + if (!CheckStatus(written)) + return; + } + else if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); + } + + // Send data. + SendPendingOutgoingDtlsData(); + } + + void DtlsTransport::Reset() + { + MS_TRACE(); + + int ret; + + if (!IsRunning()) + return; + + MS_WARN_TAG(dtls, "resetting DTLS transport"); + + // Stop the DTLS timer. + this->timer = nullptr; + + // We need to reset the SSL instance so we need to "shutdown" it, but we + // don't want to send a Close Alert to the peer, so just don't call + // SendPendingOutgoingDTLSData(). + SSL_shutdown(this->ssl); + + this->localRole = Role::NONE; + this->state = DtlsState::NEW; + this->handshakeDone = false; + this->handshakeDoneNow = false; + + // Reset SSL status. + // NOTE: For this to properly work, SSL_shutdown() must be called before. + // NOTE: This may fail if not enough DTLS handshake data has been received, + // but we don't care so just clear the error queue. + ret = SSL_clear(this->ssl); + + if (ret == 0) + ERR_clear_error(); + } + + inline bool DtlsTransport::CheckStatus(int returnCode) + { + MS_TRACE(); + + int err; + bool wasHandshakeDone = this->handshakeDone; + + err = SSL_get_error(this->ssl, returnCode); + + switch (err) + { + case SSL_ERROR_NONE: + break; + + case SSL_ERROR_SSL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); + break; + + case SSL_ERROR_WANT_READ: + break; + + case SSL_ERROR_WANT_WRITE: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); + break; + + case SSL_ERROR_WANT_X509_LOOKUP: + MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); + break; + + case SSL_ERROR_SYSCALL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); + break; + + case SSL_ERROR_ZERO_RETURN: + break; + + case SSL_ERROR_WANT_CONNECT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); + break; + + case SSL_ERROR_WANT_ACCEPT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); + break; + + default: + MS_WARN_TAG(dtls, "SSL status: unknown error"); + } + + // Check if the handshake (or re-handshake) has been done right now. + if (this->handshakeDoneNow) + { + this->handshakeDoneNow = false; + this->handshakeDone = true; + + // Stop the timer. + this->timer = nullptr; + + // Process the handshake just once (ignore if DTLS renegotiation). + if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) + return ProcessHandshake(); + + return true; + } + // Check if the peer sent close alert or a fatal error happened. + else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) + { + if (this->state == DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "disconnected"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::CLOSED; + this->listener->OnDtlsTransportClosed(this); + } + else + { + MS_WARN_TAG(dtls, "connection failed"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + } + + return false; + } + else + { + return true; + } + } + + inline void DtlsTransport::SendPendingOutgoingDtlsData() + { + MS_TRACE(); + + if (BIO_eof(this->sslBioToNetwork)) + return; + + int64_t read; + char* data{ nullptr }; + + read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT + + if (read <= 0) + return; + + MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); + + // Notify the listener. + this->listener->OnDtlsTransportSendData( + this, reinterpret_cast(data), static_cast(read)); + + // Clear the BIO buffer. + // NOTE: the (void) avoids the -Wunused-value warning. + (void)BIO_reset(this->sslBioToNetwork); + } + + inline bool DtlsTransport::SetTimeout() + { + MS_TRACE(); + + MS_ASSERT( + this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, + "invalid DTLS state"); + + int64_t ret; + struct timeval dtlsTimeout{ 0, 0 }; + uint64_t timeoutMs; + + // NOTE: If ret == 0 then ignore the value in dtlsTimeout. + // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. + ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT + + if (ret == 0) + return true; + + timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); + + if (timeoutMs == 0) + { + return true; + } + else if (timeoutMs < 30000) + { + MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); + + weak_ptr weak_self = shared_from_this(); + this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ + auto strong_self = weak_self.lock(); + if(strong_self){ + strong_self->OnTimer(); + } + return true; + }, this->poller); + + return true; + } + // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. + else + { + MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } + } + + inline bool DtlsTransport::ProcessHandshake() + { + MS_TRACE(); + + MS_ASSERT(this->handshakeDone, "handshake not done yet"); + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + + // Validate the remote fingerprint. + if (!CheckRemoteFingerprint()) + { + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } + + // Get the negotiated SRTP crypto suite. + RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); + + if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) + { + // Extract the SRTP keys (will notify the listener with them). + ExtractSrtpKeys(srtpCryptoSuite); + + return true; + } + + // NOTE: We assume that "use_srtp" DTLS extension is required even if + // there is no audio/video. + MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } + + inline bool DtlsTransport::CheckRemoteFingerprint() + { + MS_TRACE(); + + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + + X509* certificate; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; + + certificate = SSL_get_peer_certificate(this->ssl); + + if (!certificate) + { + MS_WARN_TAG(dtls, "no certificate was provided by the peer"); + + return false; + } + + switch (this->remoteFingerprint.algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; + + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; + + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; + + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; + + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; + + default: + MS_ABORT("unknown algorithm"); + } + + // Compare the remote fingerprint with the value given via signaling. + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); + + X509_free(certificate); + + return false; + } + + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; + + if (this->remoteFingerprint.value != hexFingerprint) + { + MS_WARN_TAG( + dtls, + "fingerprint in the remote certificate (%s) does not match the announced one (%s)", + hexFingerprint, + this->remoteFingerprint.value.c_str()); + + //todo 先屏蔽检查客户端签名 #if 0 - MS_TRACE(); + X509_free(certificate); - FILE* file{nullptr}; - - file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); - - if (!file) { - MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); - - goto error; - } - - DtlsTransport::certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); - - if (!DtlsTransport::certificate) { - LOG_OPENSSL_ERROR("PEM_read_X509() failed"); - - goto error; - } - - fclose(file); - - file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); - - if (!file) { - MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); - - goto error; - } - - DtlsTransport::privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); - - if (!DtlsTransport::privateKey) { - LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); - - goto error; - } - - fclose(file); - - return; - - error: - - MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); + return false; #endif -} + } -void DtlsTransport::CreateSslCtx() { - MS_TRACE(); + MS_DEBUG_TAG(dtls, "valid remote fingerprint"); - std::string dtlsSrtpCryptoSuites; - int ret; + // Get the remote certificate in PEM format. - /* Set the global DTLS context. */ + BIO* bio = BIO_new(BIO_s_mem()); - // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). - DtlsTransport::sslCtx = SSL_CTX_new(DTLS_method()); + // Ensure the underlying BUF_MEM structure is also freed. + // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since + // BIO_set_close() always returns 1. + (void)BIO_set_close(bio, BIO_CLOSE); - if (!DtlsTransport::sslCtx) { - LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); + ret = PEM_write_bio_X509(bio, certificate); - goto error; - } + if (ret != 1) + { + LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); - ret = SSL_CTX_use_certificate(DtlsTransport::sslCtx, DtlsTransport::certificate); + X509_free(certificate); + BIO_free(bio); - if (ret == 0) { - LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); + return false; + } - goto error; - } + BUF_MEM* mem; - ret = SSL_CTX_use_PrivateKey(DtlsTransport::sslCtx, DtlsTransport::privateKey); + BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] - if (ret == 0) { - LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); + if (!mem || !mem->data || mem->length == 0u) + { + LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); - goto error; - } + X509_free(certificate); + BIO_free(bio); - ret = SSL_CTX_check_private_key(DtlsTransport::sslCtx); + return false; + } - if (ret == 0) { - LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); + this->remoteCert = std::string(mem->data, mem->length); - goto error; - } - - // Set options. - SSL_CTX_set_options(DtlsTransport::sslCtx, SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | - SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_QUERY_MTU); - - // Don't use sessions cache. - SSL_CTX_set_session_cache_mode(DtlsTransport::sslCtx, SSL_SESS_CACHE_OFF); - - // Read always as much into the buffer as possible. - // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL - // versions makes this call required. - SSL_CTX_set_read_ahead(DtlsTransport::sslCtx, 1); - - SSL_CTX_set_verify_depth(DtlsTransport::sslCtx, 4); - - // Require certificate from peer. - SSL_CTX_set_verify(DtlsTransport::sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, - onSslCertificateVerify); - - // Set SSL info callback. - SSL_CTX_set_info_callback(DtlsTransport::sslCtx, onSslInfo); - - // Set ciphers. - ret = SSL_CTX_set_cipher_list(DtlsTransport::sslCtx, - "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); - - if (ret == 0) { - LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); - - goto error; - } - - // Enable ECDH ciphers. - // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters - // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 - // NOTE: https://bugs.ruby-lang.org/issues/12324 - - // For OpenSSL >= 1.0.2. - SSL_CTX_set_ecdh_auto(DtlsTransport::sslCtx, 1); - - // Set the "use_srtp" DTLS extension. - for (auto it = DtlsTransport::srtpCryptoSuites.begin(); - it != DtlsTransport::srtpCryptoSuites.end(); ++it) { - if (it != DtlsTransport::srtpCryptoSuites.begin()) dtlsSrtpCryptoSuites += ":"; - - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); - dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; - } - - MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", - dtlsSrtpCryptoSuites.c_str()); - - // NOTE: This function returns 0 on success. - ret = SSL_CTX_set_tlsext_use_srtp(DtlsTransport::sslCtx, dtlsSrtpCryptoSuites.c_str()); - - if (ret != 0) { - MS_ERROR("SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", - dtlsSrtpCryptoSuites.c_str()); - LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); - - goto error; - } - - return; - -error: - - if (DtlsTransport::sslCtx) { - SSL_CTX_free(DtlsTransport::sslCtx); - DtlsTransport::sslCtx = nullptr; - } - - MS_THROW_ERROR("SSL context creation failed"); -} - -void DtlsTransport::GenerateFingerprints() { - MS_TRACE(); - - for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) { - const std::string& algorithmString = kv.first; - FingerprintAlgorithm algorithm = kv.second; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{0}; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; - - switch (algorithm) { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; - - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; - - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; - - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; - - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; - - default: - MS_THROW_ERROR("unknown algorithm"); - } - - ret = X509_digest(DtlsTransport::certificate, hashFunction, binaryFingerprint, &size); - - if (ret == 0) { - MS_ERROR("X509_digest() failed"); - MS_THROW_ERROR("Fingerprints generation failed"); - } - - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{0}; i < size; ++i) { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; - - MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); - // Store it in the vector. - DtlsTransport::Fingerprint fingerprint; - - fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); - fingerprint.value = hexFingerprint; - - DtlsTransport::localFingerprints.push_back(fingerprint); - } -} - -/* Instance methods. */ - -DtlsTransport::DtlsTransport(Listener* listener) : listener(listener) { - MS_TRACE(); - - /* Set SSL. */ - - this->ssl = SSL_new(DtlsTransport::sslCtx); - - if (!this->ssl) { - LOG_OPENSSL_ERROR("SSL_new() failed"); - - goto error; - } - - // Set this as custom data. - SSL_set_ex_data(this->ssl, 0, static_cast(this)); - - this->sslBioFromNetwork = BIO_new(BIO_s_mem()); - - if (!this->sslBioFromNetwork) { - LOG_OPENSSL_ERROR("BIO_new() failed"); - - SSL_free(this->ssl); - - goto error; - } - - this->sslBioToNetwork = BIO_new(BIO_s_mem()); - - if (!this->sslBioToNetwork) { - LOG_OPENSSL_ERROR("BIO_new() failed"); - - BIO_free(this->sslBioFromNetwork); - SSL_free(this->ssl); - - goto error; - } - - SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); - - // Set the MTU so that we don't send packets that are too large with no fragmentation. - SSL_set_mtu(this->ssl, DtlsMtu); - DTLS_set_link_mtu(this->ssl, DtlsMtu); - - // Set callback handler for setting DTLS timer interval. - DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); - - // Set the DTLS timer. - // this->timer = new Timer(this); - - return; - -error: - - // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as - // well. - if (this->sslBioFromNetwork) BIO_free(this->sslBioFromNetwork); - - if (this->sslBioToNetwork) BIO_free(this->sslBioToNetwork); - - if (this->ssl) SSL_free(this->ssl); - - // NOTE: If this is not catched by the caller the program will abort, but - // this should never happen. - MS_THROW_ERROR("DtlsTransport instance creation failed"); -} - -DtlsTransport::~DtlsTransport() { - MS_TRACE(); - - if (IsRunning()) { - // Send close alert to the peer. - SSL_shutdown(this->ssl); - SendPendingOutgoingDtlsData(); - } - - if (this->ssl) { - SSL_free(this->ssl); - - this->ssl = nullptr; - this->sslBioFromNetwork = nullptr; - this->sslBioToNetwork = nullptr; - } - - // Close the DTLS timer. - // delete this->timer; -} - -void DtlsTransport::Dump() const { - MS_TRACE(); - - std::string state{"new"}; - std::string role{"none "}; - - switch (this->state) { - case DtlsState::CONNECTING: - state = "connecting"; - break; - case DtlsState::CONNECTED: - state = "connected"; - break; - case DtlsState::FAILED: - state = "failed"; - break; - case DtlsState::CLOSED: - state = "closed"; - break; - default:; - } - - switch (this->localRole) { - case Role::AUTO: - role = "auto"; - break; - case Role::SERVER: - role = "server"; - break; - case Role::CLIENT: - role = "client"; - break; - default:; - } - - MS_DUMP(""); - MS_DUMP(" state : %s", state.c_str()); - MS_DUMP(" role : %s", role.c_str()); - MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); - MS_DUMP(""); -} - -void DtlsTransport::Run(Role localRole) { - MS_TRACE(); - - MS_ASSERT(localRole == Role::CLIENT || localRole == Role::SERVER, - "local DTLS role must be 'client' or 'server'"); - - Role previousLocalRole = this->localRole; - - if (localRole == previousLocalRole) { - MS_ERROR("same local DTLS role provided, doing nothing"); - - return; - } - - // If the previous local DTLS role was 'client' or 'server' do reset. - if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) { - MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); - - Reset(); - } - - // Update local role. - this->localRole = localRole; - - // Set state and notify the listener. - this->state = DtlsState::CONNECTING; - this->listener->OnDtlsTransportConnecting(this); - - switch (this->localRole) { - case Role::CLIENT: { - MS_DEBUG_TAG(dtls, "running [role:client]"); - - SSL_set_connect_state(this->ssl); - SSL_do_handshake(this->ssl); - SendPendingOutgoingDtlsData(); - SetTimeout(); - - break; - } - - case Role::SERVER: { - MS_DEBUG_TAG(dtls, "running [role:server]"); - - SSL_set_accept_state(this->ssl); - SSL_do_handshake(this->ssl); - - break; - } - - default: { - MS_ABORT("invalid local DTLS role"); - } - } -} - -bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) { - MS_TRACE(); - - MS_ASSERT(fingerprint.algorithm != FingerprintAlgorithm::NONE, - "no fingerprint algorithm provided"); - - this->remoteFingerprint = fingerprint; - - // The remote fingerpring may have been set after DTLS handshake was done, - // so we may need to process it now. - if (this->handshakeDone && this->state != DtlsState::CONNECTED) { - MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); - - return ProcessHandshake(); - } - - return true; -} - -void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) { - MS_TRACE(); - - int written; - int read; - - if (!IsRunning()) { - MS_ERROR("cannot process data while not running"); - - return; - } - - // Write the received DTLS data into the sslBioFromNetwork. - written = - BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); - - if (written != static_cast(len)) { - MS_WARN_TAG(dtls, "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", - static_cast(written), len); - } - - // Must call SSL_read() to process received DTLS data. - read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); - - // Send data if it's ready. - SendPendingOutgoingDtlsData(); - - // Check SSL status and return if it is bad/closed. - if (!CheckStatus(read)) return; - - // Set/update the DTLS timeout. - if (!SetTimeout()) return; - - // Application data received. Notify to the listener. - if (read > 0) { - // It is allowed to receive DTLS data even before validating remote fingerprint. - if (!this->handshakeDone) { - MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); - - return; - } - - // Notify the listener. - this->listener->OnDtlsTransportApplicationDataReceived( - this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); - } -} - -void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) { - MS_TRACE(); - - // We cannot send data to the peer if its remote fingerprint is not validated. - if (this->state != DtlsState::CONNECTED) { - MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); - - return; - } - - if (len == 0) { - MS_WARN_TAG(dtls, "ignoring 0 length data"); - - return; - } - - int written; - - written = SSL_write(this->ssl, static_cast(data), static_cast(len)); - - if (written < 0) { - LOG_OPENSSL_ERROR("SSL_write() failed"); - - if (!CheckStatus(written)) return; - } else if (written != static_cast(len)) { - MS_WARN_TAG(dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", - written, len); - } - - // Send data. - SendPendingOutgoingDtlsData(); -} - -void DtlsTransport::Reset() { - MS_TRACE(); - - int ret; - - if (!IsRunning()) return; - - MS_WARN_TAG(dtls, "resetting DTLS transport"); - - // Stop the DTLS timer. - // this->timer->Stop(); - - // We need to reset the SSL instance so we need to "shutdown" it, but we - // don't want to send a Close Alert to the peer, so just don't call - // SendPendingOutgoingDTLSData(). - SSL_shutdown(this->ssl); - - this->localRole = Role::NONE; - this->state = DtlsState::NEW; - this->handshakeDone = false; - this->handshakeDoneNow = false; - - // Reset SSL status. - // NOTE: For this to properly work, SSL_shutdown() must be called before. - // NOTE: This may fail if not enough DTLS handshake data has been received, - // but we don't care so just clear the error queue. - ret = SSL_clear(this->ssl); - - if (ret == 0) ERR_clear_error(); -} - -inline bool DtlsTransport::CheckStatus(int returnCode) { - MS_TRACE(); - - int err; - bool wasHandshakeDone = this->handshakeDone; - - err = SSL_get_error(this->ssl, returnCode); - - switch (err) { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_SSL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); - break; - - case SSL_ERROR_WANT_READ: - break; - - case SSL_ERROR_WANT_WRITE: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); - break; - - case SSL_ERROR_WANT_X509_LOOKUP: - MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); - break; - - case SSL_ERROR_SYSCALL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); - break; - - case SSL_ERROR_ZERO_RETURN: - break; - - case SSL_ERROR_WANT_CONNECT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); - break; - - case SSL_ERROR_WANT_ACCEPT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); - break; - - default: - MS_WARN_TAG(dtls, "SSL status: unknown error"); - } - - // Check if the handshake (or re-handshake) has been done right now. - if (this->handshakeDoneNow) { - this->handshakeDoneNow = false; - this->handshakeDone = true; - - // Stop the timer. - // this->timer->Stop(); - - // Process the handshake just once (ignore if DTLS renegotiation). - // if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) - // return ProcessHandshake(); - if (!wasHandshakeDone) { - return ProcessHandshake(); - } - - return true; - } - // Check if the peer sent close alert or a fatal error happened. - else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || - err == SSL_ERROR_SYSCALL) { - if (this->state == DtlsState::CONNECTED) { - MS_DEBUG_TAG(dtls, "disconnected"); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::CLOSED; - this->listener->OnDtlsTransportClosed(this); - } else { - MS_WARN_TAG(dtls, "connection failed"); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - } - - return false; - } else { - return true; - } -} - -inline void DtlsTransport::SendPendingOutgoingDtlsData() { - MS_TRACE(); - - if (BIO_eof(this->sslBioToNetwork)) return; - - int64_t read; - char* data{nullptr}; - - read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT - - if (read <= 0) return; - - MS_DEBUG_DEV("%ld bytes of DTLS data ready to sent to the peer", read); - - // Notify the listener. - this->listener->OnDtlsTransportSendData(this, reinterpret_cast(data), - static_cast(read)); - - // Clear the BIO buffer. - // NOTE: the (void) avoids the -Wunused-value warning. - (void)BIO_reset(this->sslBioToNetwork); -} - -inline bool DtlsTransport::SetTimeout() { - MS_TRACE(); - - MS_ASSERT(this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, - "invalid DTLS state"); - - int64_t ret; - uv_timeval_t dtlsTimeout{0, 0}; - uint64_t timeoutMs; - - // NOTE: If ret == 0 then ignore the value in dtlsTimeout. - // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. - ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT - - if (ret == 0) return true; - - timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); - - if (timeoutMs == 0) { - return true; - } else if (timeoutMs < 30000) { - MS_DEBUG_DEV("DTLS timer set in %lu ms", timeoutMs); - - // this->timer->Start(timeoutMs); - - return true; - } - // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. - else { - MS_WARN_TAG(dtls, "DTLS timeout too high (%lu ms), resetting DLTS", timeoutMs); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - - return false; - } -} - -inline bool DtlsTransport::ProcessHandshake() { - MS_TRACE(); - - MS_ASSERT(this->handshakeDone, "handshake not done yet"); -// MS_ASSERT(this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, -// "remote fingerprint not set"); - - // Validate the remote fingerprint. - // if (!CheckRemoteFingerprint()) { - // Reset(); - - // // Set state and notify the listener. - // this->state = DtlsState::FAILED; - // this->listener->OnDtlsTransportFailed(this); - - // return false; - // } - - // Get the negotiated SRTP crypto suite. - RTC::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); - - if (srtpCryptoSuite != RTC::CryptoSuite::NONE) { - // Extract the SRTP keys (will notify the listener with them). - ExtractSrtpKeys(srtpCryptoSuite); - - return true; - } - - // NOTE: We assume that "use_srtp" DTLS extension is required even if - // there is no audio/video. - MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - - return false; -} - -inline bool DtlsTransport::CheckRemoteFingerprint() { - MS_TRACE(); - - MS_ASSERT(this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, - "remote fingerprint not set"); - - X509* certificate; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{0}; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; - - certificate = SSL_get_peer_certificate(this->ssl); - - if (!certificate) { - MS_WARN_TAG(dtls, "no certificate was provided by the peer"); - - return false; - } - - switch (this->remoteFingerprint.algorithm) { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; - - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; - - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; - - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; - - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; - - default: - MS_ABORT("unknown algorithm"); - } - - // Compare the remote fingerprint with the value given via signaling. - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - - if (ret == 0) { - MS_ERROR("X509_digest() failed"); - - X509_free(certificate); - - return false; - } - - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{0}; i < size; ++i) { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; - - if (this->remoteFingerprint.value != hexFingerprint) { - MS_WARN_TAG(dtls, - "fingerprint in the remote certificate (%s) does not match the announced one (%s)", - hexFingerprint, this->remoteFingerprint.value.c_str()); - - X509_free(certificate); - - return false; - } - - MS_DEBUG_TAG(dtls, "valid remote fingerprint"); - - // Get the remote certificate in PEM format. - - BIO* bio = BIO_new(BIO_s_mem()); - - // Ensure the underlying BUF_MEM structure is also freed. - // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since - // BIO_set_close() always returns 1. - (void)BIO_set_close(bio, BIO_CLOSE); - - ret = PEM_write_bio_X509(bio, certificate); - - if (ret != 1) { - LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); - - X509_free(certificate); - BIO_free(bio); - - return false; - } - - BUF_MEM* mem; - - BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] - - if (!mem || !mem->data || mem->length == 0u) { - LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); - - X509_free(certificate); - BIO_free(bio); - - return false; - } - - this->remoteCert = std::string(mem->data, mem->length); - - X509_free(certificate); - BIO_free(bio); - - return true; -} - -inline void DtlsTransport::ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite) { - MS_TRACE(); - - size_t srtpKeyLength{0}; - size_t srtpSaltLength{0}; - size_t srtpMasterLength{0}; - - switch (srtpCryptoSuite) { - case RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_80: - case RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_32: { - srtpKeyLength = SrtpMasterKeyLength; - srtpSaltLength = SrtpMasterSaltLength; - srtpMasterLength = SrtpMasterLength; - - break; - } - - case RTC::CryptoSuite::AEAD_AES_256_GCM: { - srtpKeyLength = SrtpAesGcm256MasterKeyLength; - srtpSaltLength = SrtpAesGcm256MasterSaltLength; - srtpMasterLength = SrtpAesGcm256MasterLength; - - break; - } - - case RTC::CryptoSuite::AEAD_AES_128_GCM: { - srtpKeyLength = SrtpAesGcm128MasterKeyLength; - srtpSaltLength = SrtpAesGcm128MasterSaltLength; - srtpMasterLength = SrtpAesGcm128MasterLength; - - break; - } - - default: { - MS_ABORT("unknown SRTP crypto suite"); - } - } - - auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; - uint8_t* srtpLocalKey{nullptr}; - uint8_t* srtpLocalSalt{nullptr}; - uint8_t* srtpRemoteKey{nullptr}; - uint8_t* srtpRemoteSalt{nullptr}; - auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; - auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; - int ret; - - ret = SSL_export_keying_material(this->ssl, srtpMaterial, srtpMasterLength * 2, - "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); - - MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); - - switch (this->localRole) { - case Role::SERVER: { - srtpRemoteKey = srtpMaterial; - srtpLocalKey = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; - - break; - } - - case Role::CLIENT: { - srtpLocalKey = srtpMaterial; - srtpRemoteKey = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; - - break; - } - - default: { - MS_ABORT("no DTLS role set"); - } - } - - // Create the SRTP local master key. - std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); - std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); - // Create the SRTP remote master key. - std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); - std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); - - // Set state and notify the listener. - this->state = DtlsState::CONNECTED; - this->listener->OnDtlsTransportConnected(this, srtpCryptoSuite, srtpLocalMasterKey, - srtpMasterLength, srtpRemoteMasterKey, srtpMasterLength, - this->remoteCert); - - delete[] srtpMaterial; - delete[] srtpLocalMasterKey; - delete[] srtpRemoteMasterKey; -} - -inline RTC::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() { - MS_TRACE(); - - RTC::CryptoSuite negotiatedSrtpCryptoSuite = RTC::CryptoSuite::NONE; - - // Ensure that the SRTP crypto suite has been negotiated. - // NOTE: This is a OpenSSL type. - SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); - - if (!sslSrtpCryptoSuite) return negotiatedSrtpCryptoSuite; - - // Get the negotiated SRTP crypto suite. - for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) { - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); - - if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) { - MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); - - negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; - } - } - - MS_ASSERT(negotiatedSrtpCryptoSuite != RTC::CryptoSuite::NONE, - "chosen SRTP crypto suite is not an available one"); - - return negotiatedSrtpCryptoSuite; -} - -inline void DtlsTransport::OnSslInfo(int where, int ret) { - MS_TRACE(); - - int w = where & -SSL_ST_MASK; - const char* role; - - if ((w & SSL_ST_CONNECT) != 0) - role = "client"; - else if ((w & SSL_ST_ACCEPT) != 0) - role = "server"; - else - role = "undefined"; - - if ((where & SSL_CB_LOOP) != 0) { - MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); - } else if ((where & SSL_CB_ALERT) != 0) { - const char* alertType; - - switch (*SSL_alert_type_string(ret)) { - case 'W': - alertType = "warning"; - break; - - case 'F': - alertType = "fatal"; - break; - - default: - alertType = "undefined"; - } - - if ((where & SSL_CB_READ) != 0) { - MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } else if ((where & SSL_CB_WRITE) != 0) { - MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } else { - MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - } else if ((where & SSL_CB_EXIT) != 0) { - if (ret == 0) - MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); - else if (ret < 0) - MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); - } else if ((where & SSL_CB_HANDSHAKE_START) != 0) { - MS_DEBUG_TAG(dtls, "DTLS handshake start"); - } else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) { - MS_DEBUG_TAG(dtls, "DTLS handshake done"); - - this->handshakeDoneNow = true; - } - - // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon - // receipt of a close alert does not work (the flag is set after this callback). -} - -inline void DtlsTransport::OnTimer() { - MS_TRACE(); - - // Workaround for https://github.com/openssl/openssl/issues/7998. - if (this->handshakeDone) { - MS_DEBUG_DEV("handshake is done so return"); - - return; - } - - DTLSv1_handle_timeout(this->ssl); - - // If required, send DTLS data. - SendPendingOutgoingDtlsData(); - - // Set the DTLS timer again. - SetTimeout(); -} -} // namespace RTC + X509_free(certificate); + BIO_free(bio); + + return true; + } + + inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) + { + MS_TRACE(); + + size_t srtpKeyLength{ 0 }; + size_t srtpSaltLength{ 0 }; + size_t srtpMasterLength{ 0 }; + + switch (srtpCryptoSuite) + { + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: + { + srtpKeyLength = SrtpMasterKeyLength; + srtpSaltLength = SrtpMasterSaltLength; + srtpMasterLength = SrtpMasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: + { + srtpKeyLength = SrtpAesGcm256MasterKeyLength; + srtpSaltLength = SrtpAesGcm256MasterSaltLength; + srtpMasterLength = SrtpAesGcm256MasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: + { + srtpKeyLength = SrtpAesGcm128MasterKeyLength; + srtpSaltLength = SrtpAesGcm128MasterSaltLength; + srtpMasterLength = SrtpAesGcm128MasterLength; + + break; + } + + default: + { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; + uint8_t* srtpLocalKey{ nullptr }; + uint8_t* srtpLocalSalt{ nullptr }; + uint8_t* srtpRemoteKey{ nullptr }; + uint8_t* srtpRemoteSalt{ nullptr }; + auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; + auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; + int ret; + + ret = SSL_export_keying_material( + this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); + + MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); + + switch (this->localRole) + { + case Role::SERVER: + { + srtpRemoteKey = srtpMaterial; + srtpLocalKey = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; + + break; + } + + case Role::CLIENT: + { + srtpLocalKey = srtpMaterial; + srtpRemoteKey = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; + + break; + } + + default: + { + MS_ABORT("no DTLS role set"); + } + } + + // Create the SRTP local master key. + std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); + std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); + // Create the SRTP remote master key. + std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); + std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); + + // Set state and notify the listener. + this->state = DtlsState::CONNECTED; + this->listener->OnDtlsTransportConnected( + this, + srtpCryptoSuite, + srtpLocalMasterKey, + srtpMasterLength, + srtpRemoteMasterKey, + srtpMasterLength, + this->remoteCert); + + delete[] srtpMaterial; + delete[] srtpLocalMasterKey; + delete[] srtpRemoteMasterKey; + } + + inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() + { + MS_TRACE(); + + RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; + + // Ensure that the SRTP crypto suite has been negotiated. + // NOTE: This is a OpenSSL type. + SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); + + if (!sslSrtpCryptoSuite) + return negotiatedSrtpCryptoSuite; + + // Get the negotiated SRTP crypto suite. + for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) + { + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); + + if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) + { + MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); + + negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; + } + } + + MS_ASSERT( + negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, + "chosen SRTP crypto suite is not an available one"); + + return negotiatedSrtpCryptoSuite; + } + + inline void DtlsTransport::OnSslInfo(int where, int ret) + { + MS_TRACE(); + + int w = where & -SSL_ST_MASK; + const char* role; + + if ((w & SSL_ST_CONNECT) != 0) + role = "client"; + else if ((w & SSL_ST_ACCEPT) != 0) + role = "server"; + else + role = "undefined"; + + if ((where & SSL_CB_LOOP) != 0) + { + MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_ALERT) != 0) + { + const char* alertType; + + switch (*SSL_alert_type_string(ret)) + { + case 'W': + alertType = "warning"; + break; + + case 'F': + alertType = "fatal"; + break; + + default: + alertType = "undefined"; + } + + if ((where & SSL_CB_READ) != 0) + { + MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else if ((where & SSL_CB_WRITE) != 0) + { + MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else + { + MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + } + else if ((where & SSL_CB_EXIT) != 0) + { + if (ret == 0) + MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); + else if (ret < 0) + MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_HANDSHAKE_START) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake start"); + } + else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake done"); + + this->handshakeDoneNow = true; + } + + // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon + // receipt of a close alert does not work (the flag is set after this callback). + } + + inline void DtlsTransport::OnTimer() + { + MS_TRACE(); + + // Workaround for https://github.com/openssl/openssl/issues/7998. + if (this->handshakeDone) + { + MS_DEBUG_DEV("handshake is done so return"); + + return; + } + + DTLSv1_handle_timeout(this->ssl); + + // If required, send DTLS data. + SendPendingOutgoingDtlsData(); + + // Set the DTLS timer again. + SetTimeout(); + } +} // namespace RTC diff --git a/webrtc/rtc_dtls_transport.h b/webrtc/rtc_dtls_transport.h index 628f966b..b3e58b1d 100644 --- a/webrtc/rtc_dtls_transport.h +++ b/webrtc/rtc_dtls_transport.h @@ -1,187 +1,224 @@ #ifndef MS_RTC_DTLS_TRANSPORT_HPP #define MS_RTC_DTLS_TRANSPORT_HPP +#include "srtp_session.h" #include #include #include - #include #include #include +#include "Poller/Timer.h" +#include "Poller/EventPoller.h" +using namespace toolkit; -namespace RTC { -enum class CryptoSuite { - NONE = 0, - AES_CM_128_HMAC_SHA1_80 = 1, - AES_CM_128_HMAC_SHA1_32, - AEAD_AES_256_GCM, - AEAD_AES_128_GCM -}; -class DtlsTransport { - public: - enum class DtlsState { NEW = 1, CONNECTING, CONNECTED, FAILED, CLOSED }; +namespace RTC +{ +class DtlsTransport : public std::enable_shared_from_this + { + public: + enum class DtlsState + { + NEW = 1, + CONNECTING, + CONNECTED, + FAILED, + CLOSED + }; - public: - enum class Role { NONE = 0, AUTO = 1, CLIENT, SERVER }; + public: + enum class Role + { + NONE = 0, + AUTO = 1, + CLIENT, + SERVER + }; - public: - enum class FingerprintAlgorithm { NONE = 0, SHA1 = 1, SHA224, SHA256, SHA384, SHA512 }; + public: + enum class FingerprintAlgorithm + { + NONE = 0, + SHA1 = 1, + SHA224, + SHA256, + SHA384, + SHA512 + }; - public: - struct Fingerprint { - FingerprintAlgorithm algorithm{FingerprintAlgorithm::NONE}; - std::string value; - }; + public: + struct Fingerprint + { + FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; + std::string value; + }; - private: - struct SrtpCryptoSuiteMapEntry { - RTC::CryptoSuite cryptoSuite; - const char* name; - }; + private: + struct SrtpCryptoSuiteMapEntry + { + RTC::SrtpSession::CryptoSuite cryptoSuite; + const char* name; + }; - public: - class Listener { - public: - // DTLS is in the process of negotiating a secure connection. Incoming - // media can flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; - // DTLS has completed negotiation of a secure connection (including DTLS-SRTP - // and remote fingerprint verification). Outgoing media can now flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnected(const RTC::DtlsTransport* dtlsTransport, - RTC::CryptoSuite srtpCryptoSuite, uint8_t* srtpLocalKey, - size_t srtpLocalKeyLen, uint8_t* srtpRemoteKey, - size_t srtpRemoteKeyLen, std::string& remoteCert) = 0; - // The DTLS connection has been closed as the result of an error (such as a - // DTLS alert or a failure to validate the remote fingerprint). - virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; - // The DTLS connection has been closed due to receipt of a close_notify alert. - virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; - // Need to send DTLS data to the peer. - virtual void OnDtlsTransportSendData(const RTC::DtlsTransport* dtlsTransport, - const uint8_t* data, size_t len) = 0; - // DTLS application data received. - virtual void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport* dtlsTransport, - const uint8_t* data, size_t len) = 0; - }; + public: + class Listener + { + public: + // DTLS is in the process of negotiating a secure connection. Incoming + // media can flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; + // DTLS has completed negotiation of a secure connection (including DTLS-SRTP + // and remote fingerprint verification). Outgoing media can now flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnected( + const RTC::DtlsTransport* dtlsTransport, + RTC::SrtpSession::CryptoSuite srtpCryptoSuite, + uint8_t* srtpLocalKey, + size_t srtpLocalKeyLen, + uint8_t* srtpRemoteKey, + size_t srtpRemoteKeyLen, + std::string& remoteCert) = 0; + // The DTLS connection has been closed as the result of an error (such as a + // DTLS alert or a failure to validate the remote fingerprint). + virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; + // The DTLS connection has been closed due to receipt of a close_notify alert. + virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; + // Need to send DTLS data to the peer. + virtual void OnDtlsTransportSendData( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + // DTLS application data received. + virtual void OnDtlsTransportApplicationDataReceived( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + }; - public: - static void ClassInit(); - static void ClassDestroy(); - static Role StringToRole(const std::string& role) { - auto it = DtlsTransport::string2Role.find(role); + public: + static void ClassInit(); + static void ClassDestroy(); + static Role StringToRole(const std::string& role) + { + auto it = DtlsTransport::string2Role.find(role); - if (it != DtlsTransport::string2Role.end()) - return it->second; - else - return DtlsTransport::Role::NONE; - } - static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) { - auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); + if (it != DtlsTransport::string2Role.end()) + return it->second; + else + return DtlsTransport::Role::NONE; + } + static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) + { + auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); - if (it != DtlsTransport::string2FingerprintAlgorithm.end()) - return it->second; - else - return DtlsTransport::FingerprintAlgorithm::NONE; - } - static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) { - auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); + if (it != DtlsTransport::string2FingerprintAlgorithm.end()) + return it->second; + else + return DtlsTransport::FingerprintAlgorithm::NONE; + } + static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) + { + auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); - return it->second; - } - static bool IsDtls(const uint8_t* data, size_t len) { - // clang-format off + return it->second; + } + static bool IsDtls(const uint8_t* data, size_t len) + { + // clang-format off return ( // Minimum DTLS record length is 13 bytes. (len >= 13) && // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes (data[0] > 19 && data[0] < 64) ); - // clang-format on - } + // clang-format on + } - private: - static void GenerateCertificateAndPrivateKey(); - static void ReadCertificateAndPrivateKeyFromFiles(); - static void CreateSslCtx(); - static void GenerateFingerprints(); + private: + static void GenerateCertificateAndPrivateKey(); + static void ReadCertificateAndPrivateKeyFromFiles(); + static void CreateSslCtx(); + static void GenerateFingerprints(); - private: - static X509* certificate; - static EVP_PKEY* privateKey; - static SSL_CTX* sslCtx; - static uint8_t sslReadBuffer[]; - static std::map string2Role; - static std::map string2FingerprintAlgorithm; - static std::map fingerprintAlgorithm2String; - static std::vector localFingerprints; - static std::vector srtpCryptoSuites; + private: + static X509* certificate; + static EVP_PKEY* privateKey; + static SSL_CTX* sslCtx; + static uint8_t sslReadBuffer[]; + static std::map string2Role; + static std::map string2FingerprintAlgorithm; + static std::map fingerprintAlgorithm2String; + static std::vector localFingerprints; + static std::vector srtpCryptoSuites; - public: - explicit DtlsTransport(Listener* listener); - ~DtlsTransport(); + public: + DtlsTransport(EventPoller::Ptr poller, Listener* listener); + ~DtlsTransport(); - public: - void Dump() const; - void Run(Role localRole); - std::vector& GetLocalFingerprints() const { - return DtlsTransport::localFingerprints; - } - bool SetRemoteFingerprint(Fingerprint fingerprint); - void ProcessDtlsData(const uint8_t* data, size_t len); - DtlsState GetState() const { return this->state; } - Role GetLocalRole() const { return this->localRole; } - void SendApplicationData(const uint8_t* data, size_t len); + public: + void Dump() const; + void Run(Role localRole); + std::vector& GetLocalFingerprints() const + { + return DtlsTransport::localFingerprints; + } + bool SetRemoteFingerprint(Fingerprint fingerprint); + void ProcessDtlsData(const uint8_t* data, size_t len); + DtlsState GetState() const + { + return this->state; + } + Role GetLocalRole() const + { + return this->localRole; + } + void SendApplicationData(const uint8_t* data, size_t len); - private: - bool IsRunning() const { - switch (this->state) { - case DtlsState::NEW: - return false; - case DtlsState::CONNECTING: - case DtlsState::CONNECTED: - return true; - case DtlsState::FAILED: - case DtlsState::CLOSED: - return false; - } + private: + bool IsRunning() const + { + switch (this->state) + { + case DtlsState::NEW: + return false; + case DtlsState::CONNECTING: + case DtlsState::CONNECTED: + return true; + case DtlsState::FAILED: + case DtlsState::CLOSED: + return false; + } - // Make GCC 4.9 happy. - return false; - } - void Reset(); - bool CheckStatus(int returnCode); - void SendPendingOutgoingDtlsData(); - bool SetTimeout(); - bool ProcessHandshake(); - bool CheckRemoteFingerprint(); - void ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite); - RTC::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + // Make GCC 4.9 happy. + return false; + } + void Reset(); + bool CheckStatus(int returnCode); + void SendPendingOutgoingDtlsData(); + bool SetTimeout(); + bool ProcessHandshake(); + bool CheckRemoteFingerprint(); + void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); + RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); - /* Callbacks fired by OpenSSL events. */ - public: - void OnSslInfo(int where, int ret); + private: + void OnSslInfo(int where, int ret); + void OnTimer(); - /* Pure virtual methods inherited from Timer::Listener. */ - public: - void OnTimer(); - - private: - // Passed by argument. - Listener* listener{nullptr}; - // Allocated by this. - SSL* ssl{nullptr}; - BIO* sslBioFromNetwork{nullptr}; // The BIO from which ssl reads. - BIO* sslBioToNetwork{nullptr}; // The BIO in which ssl writes. - // Others. - DtlsState state{DtlsState::NEW}; - Role localRole{Role::NONE}; - Fingerprint remoteFingerprint; - bool handshakeDone{false}; - bool handshakeDoneNow{false}; - std::string remoteCert; -}; -} // namespace RTC + private: + EventPoller::Ptr poller; + // Passed by argument. + Listener* listener{ nullptr }; + // Allocated by this. + SSL* ssl{ nullptr }; + BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. + BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. + Timer::Ptr timer; + // Others. + DtlsState state{ DtlsState::NEW }; + Role localRole{ Role::NONE }; + Fingerprint remoteFingerprint; + bool handshakeDone{ false }; + bool handshakeDoneNow{ false }; + std::string remoteCert; + }; +} // namespace RTC #endif diff --git a/webrtc/srtp_session.cc b/webrtc/srtp_session.cc index b21a83cd..fd8de70d 100644 --- a/webrtc/srtp_session.cc +++ b/webrtc/srtp_session.cc @@ -2,268 +2,287 @@ // #define MS_LOG_DEV_LEVEL 3 #include "srtp_session.h" - -#include // std::memset(), std::memcpy() -#include - +#include // std::memset(), std::memcpy() #include "logger.h" -namespace RTC { -/* Static. */ +namespace RTC +{ + /* Static. */ -static constexpr size_t EncryptBufferSize{65536}; -static uint8_t EncryptBuffer[EncryptBufferSize]; + static constexpr size_t EncryptBufferSize{ 65536 }; + static uint8_t EncryptBuffer[EncryptBufferSize]; -/* Class methods. */ - -std::vector DepLibSRTP::errors = { - // From 0 (srtp_err_status_ok) to 24 (srtp_err_status_pfkey_err). - "success (srtp_err_status_ok)", - "unspecified failure (srtp_err_status_fail)", - "unsupported parameter (srtp_err_status_bad_param)", - "couldn't allocate memory (srtp_err_status_alloc_fail)", - "couldn't deallocate memory (srtp_err_status_dealloc_fail)", - "couldn't initialize (srtp_err_status_init_fail)", - "can’t process as much data as requested (srtp_err_status_terminus)", - "authentication failure (srtp_err_status_auth_fail)", - "cipher failure (srtp_err_status_cipher_fail)", - "replay check failed (bad index) (srtp_err_status_replay_fail)", - "replay check failed (index too old) (srtp_err_status_replay_old)", - "algorithm failed test routine (srtp_err_status_algo_fail)", - "unsupported operation (srtp_err_status_no_such_op)", - "no appropriate context found (srtp_err_status_no_ctx)", - "unable to perform desired validation (srtp_err_status_cant_check)", - "can’t use key any more (srtp_err_status_key_expired)", - "error in use of socket (srtp_err_status_socket_err)", - "error in use POSIX signals (srtp_err_status_signal_err)", - "nonce check failed (srtp_err_status_nonce_bad)", - "couldn’t read data (srtp_err_status_read_fail)", - "couldn’t write data (srtp_err_status_write_fail)", - "error parsing data (srtp_err_status_parse_err)", - "error encoding data (srtp_err_status_encode_err)", - "error while using semaphores (srtp_err_status_semaphore_err)", - "error while using pfkey (srtp_err_status_pfkey_err)"}; + std::vector DepLibSRTP::errors = { + // From 0 (srtp_err_status_ok) to 24 (srtp_err_status_pfkey_err). + "success (srtp_err_status_ok)", + "unspecified failure (srtp_err_status_fail)", + "unsupported parameter (srtp_err_status_bad_param)", + "couldn't allocate memory (srtp_err_status_alloc_fail)", + "couldn't deallocate memory (srtp_err_status_dealloc_fail)", + "couldn't initialize (srtp_err_status_init_fail)", + "can’t process as much data as requested (srtp_err_status_terminus)", + "authentication failure (srtp_err_status_auth_fail)", + "cipher failure (srtp_err_status_cipher_fail)", + "replay check failed (bad index) (srtp_err_status_replay_fail)", + "replay check failed (index too old) (srtp_err_status_replay_old)", + "algorithm failed test routine (srtp_err_status_algo_fail)", + "unsupported operation (srtp_err_status_no_such_op)", + "no appropriate context found (srtp_err_status_no_ctx)", + "unable to perform desired validation (srtp_err_status_cant_check)", + "can’t use key any more (srtp_err_status_key_expired)", + "error in use of socket (srtp_err_status_socket_err)", + "error in use POSIX signals (srtp_err_status_signal_err)", + "nonce check failed (srtp_err_status_nonce_bad)", + "couldn’t read data (srtp_err_status_read_fail)", + "couldn’t write data (srtp_err_status_write_fail)", + "error parsing data (srtp_err_status_parse_err)", + "error encoding data (srtp_err_status_encode_err)", + "error while using semaphores (srtp_err_status_semaphore_err)", + "error while using pfkey (srtp_err_status_pfkey_err)"}; // clang-format on /* Static methods. */ -void DepLibSRTP::ClassInit() { - MS_TRACE(); + void DepLibSRTP::ClassInit() { + MS_TRACE(); - MS_DEBUG_TAG(info, "libsrtp version: \"%s\"", srtp_get_version_string()); + MS_DEBUG_TAG(info, "libsrtp version: \"%s\"", srtp_get_version_string()); - srtp_err_status_t err = srtp_init(); - - if (DepLibSRTP::IsError(err)) - MS_THROW_ERROR("srtp_init() failed: %s", DepLibSRTP::GetErrorString(err)); -} - -void DepLibSRTP::ClassDestroy() { - MS_TRACE(); - - srtp_shutdown(); -} - -void SrtpSession::ClassInit() { - // Set libsrtp event handler. - srtp_err_status_t err = - srtp_install_event_handler(static_cast(OnSrtpEvent)); - if (DepLibSRTP::IsError(err)) { - MS_THROW_ERROR("srtp_install_event_handler() failed: %s", DepLibSRTP::GetErrorString(err)); - std::cout << "srtp_install_event_handler() failed :" << DepLibSRTP::GetErrorString(err); - } -} - -void SrtpSession::OnSrtpEvent(srtp_event_data_t *data) { - MS_TRACE(); - - switch (data->event) { - case event_ssrc_collision: - MS_WARN_TAG(srtp, "SSRC collision occurred"); - break; - - case event_key_soft_limit: - MS_WARN_TAG(srtp, "stream reached the soft key usage limit and will expire soon"); - break; - - case event_key_hard_limit: - MS_WARN_TAG(srtp, "stream reached the hard key usage limit and has expired"); - break; - - case event_packet_index_limit: - MS_WARN_TAG(srtp, "stream reached the hard packet limit (2^48 packets)"); - break; - } -} - -/* Instance methods. */ - -SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen) { - MS_TRACE(); - - srtp_policy_t policy;// NOLINT(cppcoreguidelines-pro-type-member-init) - - // Set all policy fields to 0. - std::memset(&policy, 0, sizeof(srtp_policy_t)); - - switch (cryptoSuite) { - case CryptoSuite::AES_CM_128_HMAC_SHA1_80: { - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - - break; - } - - case CryptoSuite::AES_CM_128_HMAC_SHA1_32: { - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); - // NOTE: Must be 80 for RTCP. - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - - break; - } - - case CryptoSuite::AEAD_AES_256_GCM: { - srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); - srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); - - break; - } - - case CryptoSuite::AEAD_AES_128_GCM: { - srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); - srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); - - break; - } - - default: { - MS_ABORT("unknown SRTP crypto suite"); - } - } - - MS_ASSERT((int) keyLen == policy.rtp.cipher_key_len, - "given keyLen does not match policy.rtp.cipher_keyLen"); - - switch (type) { - case Type::INBOUND: - policy.ssrc.type = ssrc_any_inbound; - break; - - case Type::OUTBOUND: - policy.ssrc.type = ssrc_any_outbound; - break; - } - - policy.ssrc.value = 0; - policy.key = key; - // Required for sending RTP retransmission without RTX. - policy.allow_repeat_tx = 1; - policy.window_size = 1024; - policy.next = nullptr; - - // Set the SRTP session. - srtp_err_status_t err = srtp_create(&this->session, &policy); - if (DepLibSRTP::IsError(err)) { - is_init = false; - MS_THROW_ERROR("srtp_create() failed: %s", DepLibSRTP::GetErrorString(err)); - } else { - is_init = true; - } -} - -SrtpSession::~SrtpSession() { - MS_TRACE(); - - if (this->session != nullptr) { - srtp_err_status_t err = srtp_dealloc(this->session); + srtp_err_status_t err = srtp_init(); if (DepLibSRTP::IsError(err)) - MS_ABORT("srtp_dealloc() failed: %s", DepLibSRTP::GetErrorString(err)); - } -} - -bool SrtpSession::EncryptRtp(const uint8_t **data, size_t *len) { - MS_TRACE(); - if (!is_init) { - return false; - } - // Ensure that the resulting SRTP packet fits into the encrypt buffer. - if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) { - MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len); - - return false; - } - std::memcpy(EncryptBuffer, *data, *len); - - srtp_err_status_t err = - srtp_protect(this->session, static_cast(EncryptBuffer), reinterpret_cast(len)); - - if (DepLibSRTP::IsError(err)) { - MS_WARN_TAG(srtp, "srtp_protect() failed: %s", DepLibSRTP::GetErrorString(err)); - - return false; + MS_THROW_ERROR("srtp_init() failed: %s", DepLibSRTP::GetErrorString(err)); } - // Update the given data pointer. - *data = (const uint8_t *) EncryptBuffer; + void DepLibSRTP::ClassDestroy() { + MS_TRACE(); - return true; -} - -bool SrtpSession::DecryptSrtp(uint8_t *data, size_t *len) { - MS_TRACE(); - - srtp_err_status_t err = - srtp_unprotect(this->session, static_cast(data), reinterpret_cast(len)); - - if (DepLibSRTP::IsError(err)) { - MS_DEBUG_TAG(srtp, "srtp_unprotect() failed: %s", DepLibSRTP::GetErrorString(err)); - - return false; + srtp_shutdown(); } - return true; -} + /* Class methods. */ -bool SrtpSession::EncryptRtcp(const uint8_t **data, size_t *len) { - MS_TRACE(); + void SrtpSession::ClassInit() + { + // Set libsrtp event handler. + srtp_err_status_t err = + srtp_install_event_handler(static_cast(OnSrtpEvent)); - // Ensure that the resulting SRTCP packet fits into the encrypt buffer. - if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) { - MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len); + if (DepLibSRTP::IsError(err)) + { + MS_THROW_ERROR("srtp_install_event_handler() failed: %s", DepLibSRTP::GetErrorString(err)); + } + } - return false; - } + void SrtpSession::OnSrtpEvent(srtp_event_data_t* data) + { + MS_TRACE(); - std::memcpy(EncryptBuffer, *data, *len); + switch (data->event) + { + case event_ssrc_collision: + MS_WARN_TAG(srtp, "SSRC collision occurred"); + break; - srtp_err_status_t err = srtp_protect_rtcp(this->session, static_cast(EncryptBuffer), - reinterpret_cast(len)); + case event_key_soft_limit: + MS_WARN_TAG(srtp, "stream reached the soft key usage limit and will expire soon"); + break; - if (DepLibSRTP::IsError(err)) { - MS_WARN_TAG(srtp, "srtp_protect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); + case event_key_hard_limit: + MS_WARN_TAG(srtp, "stream reached the hard key usage limit and has expired"); + break; - return false; - } + case event_packet_index_limit: + MS_WARN_TAG(srtp, "stream reached the hard packet limit (2^48 packets)"); + break; + } + } - // Update the given data pointer. - *data = (const uint8_t *) EncryptBuffer; + /* Instance methods. */ - return true; -} + SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t* key, size_t keyLen) + { + MS_TRACE(); -bool SrtpSession::DecryptSrtcp(uint8_t *data, size_t *len) { - MS_TRACE(); + srtp_policy_t policy; // NOLINT(cppcoreguidelines-pro-type-member-init) - srtp_err_status_t err = - srtp_unprotect_rtcp(this->session, static_cast(data), reinterpret_cast(len)); + // Set all policy fields to 0. + std::memset(&policy, 0, sizeof(srtp_policy_t)); - if (DepLibSRTP::IsError(err)) { - MS_DEBUG_TAG(srtp, "srtp_unprotect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); + switch (cryptoSuite) + { + case CryptoSuite::AES_CM_128_HMAC_SHA1_80: + { + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - return false; - } + break; + } - return true; -} -}// namespace RTC + case CryptoSuite::AES_CM_128_HMAC_SHA1_32: + { + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); + // NOTE: Must be 80 for RTCP. + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); + + break; + } + + case CryptoSuite::AEAD_AES_256_GCM: + { + srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); + srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); + + break; + } + + case CryptoSuite::AEAD_AES_128_GCM: + { + srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); + srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); + + break; + } + + default: + { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + MS_ASSERT( + (int)keyLen == policy.rtp.cipher_key_len, + "given keyLen does not match policy.rtp.cipher_keyLen"); + + switch (type) + { + case Type::INBOUND: + policy.ssrc.type = ssrc_any_inbound; + break; + + case Type::OUTBOUND: + policy.ssrc.type = ssrc_any_outbound; + break; + } + + policy.ssrc.value = 0; + policy.key = key; + // Required for sending RTP retransmission without RTX. + policy.allow_repeat_tx = 1; + policy.window_size = 1024; + policy.next = nullptr; + + // Set the SRTP session. + srtp_err_status_t err = srtp_create(&this->session, &policy); + + if (DepLibSRTP::IsError(err)) + MS_THROW_ERROR("srtp_create() failed: %s", DepLibSRTP::GetErrorString(err)); + } + + SrtpSession::~SrtpSession() + { + MS_TRACE(); + + if (this->session != nullptr) + { + srtp_err_status_t err = srtp_dealloc(this->session); + + if (DepLibSRTP::IsError(err)) + MS_ABORT("srtp_dealloc() failed: %s", DepLibSRTP::GetErrorString(err)); + } + } + + bool SrtpSession::EncryptRtp(const uint8_t** data, size_t* len) + { + MS_TRACE(); + + // Ensure that the resulting SRTP packet fits into the encrypt buffer. + if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) + { + MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len); + + return false; + } + + std::memcpy(EncryptBuffer, *data, *len); + + srtp_err_status_t err = + srtp_protect(this->session, static_cast(EncryptBuffer), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) + { + MS_WARN_TAG(srtp, "srtp_protect() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + // Update the given data pointer. + *data = (const uint8_t*)EncryptBuffer; + + return true; + } + + bool SrtpSession::DecryptSrtp(uint8_t* data, size_t* len) + { + MS_TRACE(); + + srtp_err_status_t err = + srtp_unprotect(this->session, static_cast(data), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) + { + MS_DEBUG_TAG(srtp, "srtp_unprotect() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + return true; + } + + bool SrtpSession::EncryptRtcp(const uint8_t** data, size_t* len) + { + MS_TRACE(); + + // Ensure that the resulting SRTCP packet fits into the encrypt buffer. + if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) + { + MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len); + + return false; + } + + std::memcpy(EncryptBuffer, *data, *len); + + srtp_err_status_t err = srtp_protect_rtcp( + this->session, static_cast(EncryptBuffer), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) + { + MS_WARN_TAG(srtp, "srtp_protect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + // Update the given data pointer. + *data = (const uint8_t*)EncryptBuffer; + + return true; + } + + bool SrtpSession::DecryptSrtcp(uint8_t* data, size_t* len) + { + MS_TRACE(); + + srtp_err_status_t err = + srtp_unprotect_rtcp(this->session, static_cast(data), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) + { + MS_DEBUG_TAG(srtp, "srtp_unprotect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + return true; + } +} // namespace RTC diff --git a/webrtc/srtp_session.h b/webrtc/srtp_session.h index 95da5fbf..a2cc30de 100644 --- a/webrtc/srtp_session.h +++ b/webrtc/srtp_session.h @@ -1,54 +1,69 @@ #ifndef MS_RTC_SRTP_SESSION_HPP #define MS_RTC_SRTP_SESSION_HPP -#include "rtc_dtls_transport.h" #include "utils.h" #include #include -namespace RTC { +namespace RTC +{ + class DepLibSRTP { + public: + static void ClassInit(); + static void ClassDestroy(); + static bool IsError(srtp_err_status_t code) { return (code != srtp_err_status_ok); } + static const char *GetErrorString(srtp_err_status_t code) { + // This throws out_of_range if the given index is not in the vector. + return DepLibSRTP::errors.at(code); + } -class DepLibSRTP { -public: - static void ClassInit(); - static void ClassDestroy(); - static bool IsError(srtp_err_status_t code) { return (code != srtp_err_status_ok); } - static const char *GetErrorString(srtp_err_status_t code) { - // This throws out_of_range if the given index is not in the vector. - return DepLibSRTP::errors.at(code); - } + private: + static std::vector errors; + }; -private: - static std::vector errors; -}; + class SrtpSession + { + public: + enum class CryptoSuite + { + NONE = 0, + AES_CM_128_HMAC_SHA1_80 = 1, + AES_CM_128_HMAC_SHA1_32, + AEAD_AES_256_GCM, + AEAD_AES_128_GCM + }; -class SrtpSession { -public: -public: - enum class Type { INBOUND = 1, OUTBOUND }; + public: + enum class Type + { + INBOUND = 1, + OUTBOUND + }; -public: - static void ClassInit(); + public: + static void ClassInit(); -private: - static void OnSrtpEvent(srtp_event_data_t *data); + private: + static void OnSrtpEvent(srtp_event_data_t* data); -public: - SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen); - ~SrtpSession(); + public: + SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t* key, size_t keyLen); + ~SrtpSession(); -public: - bool EncryptRtp(const uint8_t **data, size_t *len); - bool DecryptSrtp(uint8_t *data, size_t *len); - bool EncryptRtcp(const uint8_t **data, size_t *len); - bool DecryptSrtcp(uint8_t *data, size_t *len); - void RemoveStream(uint32_t ssrc) { srtp_remove_stream(this->session, uint32_t{htonl(ssrc)}); } + public: + bool EncryptRtp(const uint8_t** data, size_t* len); + bool DecryptSrtp(uint8_t* data, size_t* len); + bool EncryptRtcp(const uint8_t** data, size_t* len); + bool DecryptSrtcp(uint8_t* data, size_t* len); + void RemoveStream(uint32_t ssrc) + { + srtp_remove_stream(this->session, uint32_t{ htonl(ssrc) }); + } -private: - bool is_init = false; - // Allocated by this. - srtp_t session{nullptr}; -}; -}// namespace RTC + private: + // Allocated by this. + srtp_t session{ nullptr }; + }; +} // namespace RTC #endif diff --git a/webrtc/stun_packet.cc b/webrtc/stun_packet.cc index 926f863b..1977f7f1 100644 --- a/webrtc/stun_packet.cc +++ b/webrtc/stun_packet.cc @@ -6,10 +6,79 @@ #include // std::snprintf() #include // std::memcmp(), std::memcpy() -#include "utils.h" - namespace RTC { +static const uint32_t crc32Table[] = +{ + 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, + 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, + 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, + 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, + 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, + 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, + 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, + 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, + 0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, + 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01, + 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, + 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, + 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, + 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, + 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, + 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, + 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, + 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, + 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, + 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, + 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, + 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, + 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, + 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, + 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713, + 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, + 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, + 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, + 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, + 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, + 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, + 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d +}; + +inline uint32_t GetCRC32(const uint8_t *data, size_t size) { + uint32_t crc{0xFFFFFFFF}; + const uint8_t *p = data; + + while (size--) { + crc = crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8); + } + + return crc ^ ~0U; +} + +static std::string openssl_HMACsha1(const void *key, size_t key_len, const void *data, size_t data_len){ + std::string str; + str.resize(20); + unsigned int out_len; +#if defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) + //openssl 1.1.0新增api,老版本api作废 + HMAC_CTX *ctx = HMAC_CTX_new(); + HMAC_CTX_reset(ctx); + HMAC_Init_ex(ctx, key, (int)key_len, EVP_sha1(), NULL); + HMAC_Update(ctx, (unsigned char*)data, data_len); + HMAC_Final(ctx, (unsigned char *)str.data(), &out_len); + HMAC_CTX_reset(ctx); + HMAC_CTX_free(ctx); +#else + HMAC_CTX ctx; + HMAC_CTX_init(&ctx); + HMAC_Init_ex(&ctx, key, key_len, EVP_sha1(), NULL); + HMAC_Update(&ctx, (unsigned char*)data, data_len); + HMAC_Final(&ctx, (unsigned char *)str.data(), &out_len); + HMAC_CTX_cleanup(&ctx); +#endif //defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) + return str; +} + /* Class variables. */ const uint8_t StunPacket::kMagicCookie[] = {0x21, 0x12, 0xA4, 0x42}; @@ -258,7 +327,7 @@ StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) { if (hasFingerprint) { // Compute the CRC32 of the received packet up to (but excluding) the // FINGERPRINT attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = Utils::Crypto::GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; + uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; // Compare with the FINGERPRINT value in the packet. if (fingerprint != computedFingerprint) { @@ -290,79 +359,6 @@ StunPacket::~StunPacket() { // MS_TRACE(); } -void StunPacket::Dump() const { - // MS_TRACE(); - - // MS_DUMP(""); - - std::string klass; - switch (this->klass) { - case Class::REQUEST: - klass = "Request"; - break; - case Class::INDICATION: - klass = "Indication"; - break; - case Class::SUCCESS_RESPONSE: - klass = "SuccessResponse"; - break; - case Class::ERROR_RESPONSE: - klass = "ErrorResponse"; - break; - } - if (this->method == Method::BINDING) { - // MS_DUMP(" Binding %s", klass.c_str()); - } else { - // This prints the unknown method number. Example: TURN Allocate => 0x003. - // MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), - // static_cast(this->method)); - } - // MS_DUMP(" size: %zu bytes", this->size); - - static char transactionId[25]; - - for (int i{0}; i < 12; ++i) { - // NOTE: n must be 3 because snprintf adds a \0 after printed chars. - std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); - } - // MS_DUMP(" transactionId: %s", transactionId); - if (this->errorCode != 0u) - // MS_DUMP(" errorCode: %" PRIu16, this->errorCode); - if (!this->username.empty()) - // MS_DUMP(" username: %s", this->username.c_str()); - if (this->priority != 0u) - // MS_DUMP(" priority: %" PRIu32, this->priority); - if (this->iceControlling != 0u) - // MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); - if (this->iceControlled != 0u) - // MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); - if (this->hasUseCandidate) - // MS_DUMP(" useCandidate"); - if (this->xorMappedAddress != nullptr) { - int family; - uint16_t port; - std::string ip; - - Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); - - // MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); - } - if (this->messageIntegrity != nullptr) { - static char messageIntegrity[41]; - - for (int i{0}; i < 20; ++i) { - std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); - } - - // MS_DUMP(" messageIntegrity: %s", messageIntegrity); - } - if (this->hasFingerprint) { - } - // MS_DUMP(" has fingerprint"); - - // MS_DUMP(""); -} - StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& localUsername, const std::string& localPassword) { // MS_TRACE(); @@ -402,13 +398,13 @@ StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& lo Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. - const uint8_t* computedMessageIntegrity = Utils::Crypto::GetHmacShA1( - localPassword, this->data, (this->messageIntegrity - 4) - this->data); + auto computedMessageIntegrity = openssl_HMACsha1( + localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data); Authentication result; // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. - if (std::memcmp(this->messageIntegrity, computedMessageIntegrity, 20) == 0) + if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) result = Authentication::OK; else result = Authentication::UNAUTHORIZED; @@ -670,12 +666,11 @@ void StunPacket::Serialize(uint8_t* buffer) { Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. - const uint8_t* computedMessageIntegrity = - Utils::Crypto::GetHmacShA1(this->password, buffer, pos); + auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos); Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::MESSAGE_INTEGRITY)); Utils::Byte::Set2Bytes(buffer, pos + 2, 20); - std::memcpy(buffer + pos + 4, computedMessageIntegrity, 20); + std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); // Update the pointer. this->messageIntegrity = buffer + pos + 4; @@ -692,7 +687,7 @@ void StunPacket::Serialize(uint8_t* buffer) { if (addFingerprint) { // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT // attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = Utils::Crypto::GetCRC32(buffer, pos) ^ 0x5354554e; + uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); Utils::Byte::Set2Bytes(buffer, pos + 2, 4); diff --git a/webrtc/utils.cc b/webrtc/utils.cc deleted file mode 100644 index ab9f1fda..00000000 --- a/webrtc/utils.cc +++ /dev/null @@ -1,139 +0,0 @@ -#define MS_CLASS "Utils::Crypto" -// #define MS_LOG_DEV - -#include "utils.h" - -#include "openssl/sha.h" - -namespace Utils { -/* Static variables. */ - -uint32_t Crypto::seed; -HMAC_CTX *Crypto::hmacSha1Ctx{nullptr}; -uint8_t Crypto::hmacSha1Buffer[20];// SHA-1 result is 20 bytes long. -// clang-format off -const uint32_t Crypto::crc32Table[] = -{ - 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, - 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, - 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, - 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, - 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, - 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, - 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, - 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, - 0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, - 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01, - 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, - 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, - 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, - 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, - 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, - 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, - 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, - 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, - 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, - 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, - 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, - 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, - 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, - 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, - 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713, - 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, - 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, - 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, - 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, - 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, - 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, - 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d -}; -// clang-format on - -/* Static methods. */ - -void Crypto::ClassInit() { - // MS_TRACE(); - - // Init the vrypto seed with a random number taken from the address - // of the seed variable itself (which is random). - Crypto::seed = static_cast(reinterpret_cast(std::addressof(Crypto::seed))); - - // Create an OpenSSL HMAC_CTX context for HMAC SHA1 calculation. - // Crypto::hmacSha1Ctx = HMAC_CTX_new(); - if (Crypto::hmacSha1Ctx == nullptr) { - Crypto::hmacSha1Ctx = HMAC_CTX_new(); - } -} - -void Crypto::ClassDestroy() { - // MS_TRACE(); - - if (Crypto::hmacSha1Ctx != nullptr) { - HMAC_CTX_free(Crypto::hmacSha1Ctx); - } -} - -const uint8_t *Crypto::GetHmacShA1(const std::string &key, const uint8_t *data, size_t len) { - // MS_TRACE(); - - size_t ret; - - ret = HMAC_Init_ex(Crypto::hmacSha1Ctx, key.c_str(), key.length(), EVP_sha1(), nullptr); - - // MS_ASSERT(ret == 1, "OpenSSL HMAC_Init_ex() failed with key '%s'", key.c_str()); - - ret = HMAC_Update(Crypto::hmacSha1Ctx, data, static_cast(len)); - /* - MS_ASSERT( - ret == 1, - "OpenSSL HMAC_Update() failed with key '%s' and data length %zu bytes", - key.c_str(), - len); - */ - uint32_t resultLen; - - ret = HMAC_Final(Crypto::hmacSha1Ctx, (uint8_t *) Crypto::hmacSha1Buffer, &resultLen); - - /* - MS_ASSERT( - ret == 1, "OpenSSL HMAC_Final() failed with key '%s' and data length %zu bytes", key.c_str(), - len); MS_ASSERT(resultLen == 20, "OpenSSL HMAC_Final() resultLen is %u instead of 20", resultLen); - */ - return Crypto::hmacSha1Buffer; -} -}// namespace Utils - -namespace Utils { - -static std::string inet_ntoa(struct in_addr in) { - char buf[20]; - unsigned char *p = (unsigned char *) &(in); - snprintf(buf, sizeof(buf), "%u.%u.%u.%u", p[0], p[1], p[2], p[3]); - return buf; -} - -void IP::GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip, uint16_t &port) { - char ipBuffer[INET6_ADDRSTRLEN + 1]; - - switch (addr->sa_family) { - case AF_INET: { - ip = Utils::inet_ntoa(reinterpret_cast(addr)->sin_addr); - port = static_cast(ntohs(reinterpret_cast(addr)->sin_port)); - break; - } - - case AF_INET6: { - port = static_cast(ntohs(reinterpret_cast(addr)->sin6_port)); - break; - } - - default: { - // MS_ABORT("unknown network family: %d", static_cast(addr->sa_family)); - } - } - - family = addr->sa_family; - ip.assign(ipBuffer); -} - -}// namespace Utils \ No newline at end of file diff --git a/webrtc/utils.h b/webrtc/utils.h index 1cd6ba9d..6b33b9f5 100644 --- a/webrtc/utils.h +++ b/webrtc/utils.h @@ -30,76 +30,6 @@ #include namespace Utils { -class IP { -public: - static int GetFamily(const char *ip, size_t ipLen); - static int GetFamily(const std::string &ip); - static void GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip, - uint16_t &port); - static bool CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2); - static struct sockaddr_storage CopyAddress(const struct sockaddr *addr); - static void NormalizeIp(std::string &ip); -}; - -/* Inline static methods. */ - -inline int IP::GetFamily(const std::string &ip) { return GetFamily(ip.c_str(), ip.size()); } - -inline bool IP::CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2) { - // Compare family. - if (addr1->sa_family != addr2->sa_family || - (addr1->sa_family != AF_INET && addr1->sa_family != AF_INET6)) { - return false; - } - - // Compare port. - if (reinterpret_cast(addr1)->sin_port != - reinterpret_cast(addr2)->sin_port) { - return false; - } - - // Compare IP. - switch (addr1->sa_family) { - case AF_INET: { - return (reinterpret_cast(addr1)->sin_addr.s_addr == - reinterpret_cast(addr2)->sin_addr.s_addr); - } - - case AF_INET6: { - return (std::memcmp( - std::addressof(reinterpret_cast(addr1)->sin6_addr), - std::addressof(reinterpret_cast(addr2)->sin6_addr), - 16) == 0 - ? true - : false); - } - - default: { - return false; - } - } -} - -inline struct sockaddr_storage IP::CopyAddress(const struct sockaddr *addr) { - struct sockaddr_storage copiedAddr; - - switch (addr->sa_family) { - case AF_INET: - std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in)); - break; - - case AF_INET6: - std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in6)); - break; - } - - return copiedAddr; -} - -class File { -public: - static void CheckFile(const char *file); -}; class Byte { public: @@ -181,138 +111,6 @@ inline uint16_t Byte::PadTo4Bytes(uint16_t size) { return size; } -inline uint32_t Byte::PadTo4Bytes(uint32_t size) { - // If size is not multiple of 32 bits then pad it. - if (size & 0x03) - return (size & 0xFFFFFFFC) + 4; - else - return size; -} - -class Bits { -public: - static size_t CountSetBits(const uint16_t mask); -}; - -/* Inline static methods. */ - -class Crypto { -public: - static void ClassInit(); - static void ClassDestroy(); - static uint32_t GetRandomUInt(uint32_t min, uint32_t max); - static const std::string GetRandomString(size_t len); - static uint32_t GetCRC32(const uint8_t *data, size_t size); - static const uint8_t *GetHmacShA1(const std::string &key, const uint8_t *data, size_t len); - -private: - static uint32_t seed; - static HMAC_CTX *hmacSha1Ctx; - static uint8_t hmacSha1Buffer[]; - static const uint32_t crc32Table[256]; -}; - -/* Inline static methods. */ - -inline uint32_t Crypto::GetRandomUInt(uint32_t min, uint32_t max) { - // NOTE: This is the original, but produces very small values. - // Crypto::seed = (214013 * Crypto::seed) + 2531011; - // return (((Crypto::seed>>16)&0x7FFF) % (max - min + 1)) + min; - - // This seems to produce better results. - Crypto::seed = uint32_t{((214013 * Crypto::seed) + 2531011)}; - - return (((Crypto::seed >> 4) & 0x7FFF7FFF) % (max - min + 1)) + min; -} - -inline const std::string Crypto::GetRandomString(size_t len) { - static char buffer[64]; - static const char chars[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', - 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'}; - - if (len > 64) len = 64; - - for (size_t i{0}; i < len; ++i) { - buffer[i] = chars[GetRandomUInt(0, sizeof(chars) - 1)]; - } - - return std::string(buffer, len); -} - -inline uint32_t Crypto::GetCRC32(const uint8_t *data, size_t size) { - uint32_t crc{0xFFFFFFFF}; - const uint8_t *p = data; - - while (size--) { - crc = Crypto::crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8); - } - - return crc ^ ~0U; -} - -class String { -public: - static void ToLowerCase(std::string &str); -}; - -inline void String::ToLowerCase(std::string &str) { - std::transform(str.begin(), str.end(), str.begin(), ::tolower); -} - -class Time { - // Seconds from Jan 1, 1900 to Jan 1, 1970. - static constexpr uint32_t UnixNtpOffset{0x83AA7E80}; - // NTP fractional unit. - static constexpr uint64_t NtpFractionalUnit{1LL << 32}; - -public: - struct Ntp { - uint32_t seconds; - uint32_t fractions; - }; - - static Time::Ntp TimeMs2Ntp(uint64_t ms); - static uint64_t Ntp2TimeMs(Time::Ntp ntp); - static bool IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp); - static uint32_t LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2); -}; - -inline Time::Ntp Time::TimeMs2Ntp(uint64_t ms) { - Time::Ntp ntp;// NOLINT(cppcoreguidelines-pro-type-member-init) - - ntp.seconds = uint32_t(ms / 1000); - ntp.fractions = - static_cast((static_cast(ms % 1000) / 1000) * NtpFractionalUnit); - - return ntp; -} - -inline uint64_t Time::Ntp2TimeMs(Time::Ntp ntp) { - // clang-format off - return ( - static_cast(ntp.seconds) * 1000 + - static_cast(std::round((static_cast(ntp.fractions) * 1000) / NtpFractionalUnit)) - ); - // clang-format on -} - -inline bool Time::IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp) { - // Distinguish between elements that are exactly 0x80000000 apart. - // If t1>t2 and |t1-t2| = 0x80000000: IsNewer(t1,t2)=true, - // IsNewer(t2,t1)=false - // rather than having IsNewer(t1,t2) = IsNewer(t2,t1) = false. - if (static_cast(timestamp - prevTimestamp) == 0x80000000) - return timestamp > prevTimestamp; - - return timestamp != prevTimestamp && - static_cast(timestamp - prevTimestamp) < 0x80000000; -} - -inline uint32_t Time::LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2) { - return IsNewerTimestamp(timestamp1, timestamp2) ? timestamp1 : timestamp2; -} - }// namespace Utils #endif diff --git a/webrtc/webrtc_transport.cc b/webrtc/webrtc_transport.cc index a475f2ac..f064b6db 100644 --- a/webrtc/webrtc_transport.cc +++ b/webrtc/webrtc_transport.cc @@ -4,31 +4,81 @@ WebRtcTransport::WebRtcTransport() { static onceToken token([](){ - Utils::Crypto::ClassInit(); RTC::DtlsTransport::ClassInit(); RTC::DepLibSRTP::ClassInit(); RTC::SrtpSession::ClassInit(); }); - ice_server_ = std::make_shared(Utils::Crypto::GetRandomString(4), Utils::Crypto::GetRandomString(24)); - ice_server_->SetIceServerCompletedCB([this]() { - this->OnIceServerCompleted(); - }); - ice_server_->SetSendCB([this](char *buf, size_t len, struct sockaddr_in *remote_address) { - this->WritePacket(buf, len, remote_address); - }); - - // todo dtls服务器或客户端模式 - dtls_transport_ = std::make_shared(true); - dtls_transport_->SetHandshakeCompletedCB([this](std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) { - this->OnDtlsCompleted(client_key, server_key, srtp_crypto_suite); - }); - dtls_transport_->SetOutPutCB([this](char *buf, size_t len) { this->WritePacket(buf, len); }); + dtls_transport_ = std::make_shared(EventPollerPool::Instance().getFirstPoller(), this); + ice_server_ = std::make_shared(this, makeRandStr(4), makeRandStr(24)); } WebRtcTransport::~WebRtcTransport() {} +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void WebRtcTransport::OnIceServerSendStunPacket(const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) { + onWrite((char *)packet->GetData(), packet->GetSize(), (struct sockaddr_in *)tuple); +} + +void WebRtcTransport::OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) { + InfoL; +} + +void WebRtcTransport::OnIceServerConnected(const RTC::IceServer *iceServer) { + InfoL; + dtls_transport_->Run(RTC::DtlsTransport::Role::SERVER); +} + +void WebRtcTransport::OnIceServerCompleted(const RTC::IceServer *iceServer) { + InfoL; +} + +void WebRtcTransport::OnIceServerDisconnected(const RTC::IceServer *iceServer) { + InfoL; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void WebRtcTransport::OnDtlsTransportConnected( + const RTC::DtlsTransport *dtlsTransport, + RTC::SrtpSession::CryptoSuite srtpCryptoSuite, + uint8_t *srtpLocalKey, + size_t srtpLocalKeyLen, + uint8_t *srtpRemoteKey, + size_t srtpRemoteKeyLen, + std::string &remoteCert) { + InfoL; + srtp_session_ = std::make_shared(RTC::SrtpSession::Type::OUTBOUND, srtpCryptoSuite, srtpLocalKey, srtpLocalKeyLen); + onDtlsConnected(); +} + +void WebRtcTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) { + onWrite((char *)data, len); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void WebRtcTransport::onWrite(const char *buf, size_t len){ + auto tuple = ice_server_->GetSelectedTuple(); + assert(tuple); + onWrite(buf, len, (struct sockaddr_in *)tuple); +} + std::string WebRtcTransport::GetLocalSdp() { + RTC::DtlsTransport::Fingerprint remote_fingerprint; + remote_fingerprint.algorithm = RTC::DtlsTransport::GetFingerprintAlgorithm("sha-256"); + remote_fingerprint.value = ""; + dtls_transport_->SetRemoteFingerprint(remote_fingerprint); + + string finger_print_sha256; + auto finger_prints = dtls_transport_->GetLocalFingerprints(); + for (size_t i = 0; i < finger_prints.size(); i++) { + if (finger_prints[i].algorithm == RTC::DtlsTransport::FingerprintAlgorithm::SHA256) { + finger_print_sha256 = finger_prints[i].value; + } + } + char sdp[1024 * 10] = {0}; auto ssrc = getSSRC(); auto ip = getIP(); @@ -60,22 +110,10 @@ std::string WebRtcTransport::GetLocalSdp() { "a=candidate:%s 1 udp %u %s %u typ %s\r\n", ip.c_str(), port, pt, ip.c_str(), ice_server_->GetUsernameFragment().c_str(),ice_server_->GetPassword().c_str(), - dtls_transport_->GetMyFingerprint().c_str(), pt, ssrc, ssrc, ssrc, ssrc, "4", ssrc, ip.c_str(), port, "host"); + finger_print_sha256.c_str(), pt, ssrc, ssrc, ssrc, ssrc, "4", ssrc, ip.c_str(), port, "host"); return sdp; } -void WebRtcTransport::OnIceServerCompleted() { - InfoL; - dtls_transport_->Start(); - onIceConnected(); -} - -void WebRtcTransport::OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) { - InfoL << client_key << " " << server_key << " " << (int)srtp_crypto_suite; - srtp_session_ = std::make_shared(RTC::SrtpSession::Type::OUTBOUND, srtp_crypto_suite, (uint8_t *) client_key.c_str(), client_key.size()); - onDtlsCompleted(); -} - bool is_dtls(char *buf) { return ((*buf > 19) && (*buf < 64)); } @@ -90,25 +128,23 @@ bool is_rtcp(char *buf) { return ((header->pt >= 64) && (header->pt < 96)); } -void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address) { +void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, RTC::TransportTuple *tuple) { if (RTC::StunPacket::IsStun((const uint8_t *) buf, len)) { - InfoL << "stun:" << hexdump(buf, len); RTC::StunPacket *packet = RTC::StunPacket::Parse((const uint8_t *) buf, len); if (packet == nullptr) { WarnL << "parse stun error" << std::endl; return; } - ice_server_->ProcessStunPacket(packet, remote_address); + ice_server_->ProcessStunPacket(packet, tuple); return; } - if (DtlsTransport::IsDtlsPacket(buf, len)) { - InfoL << "dtls:" << hexdump(buf, len); - dtls_transport_->InputData(buf, len); + if (is_dtls(buf)) { + dtls_transport_->ProcessDtlsData((uint8_t *)buf, len); return; } if (is_rtp(buf)) { RtpHeader *header = (RtpHeader *) buf; - InfoL << "rtp:" << header->dumpString(len); +// InfoL << "rtp:" << header->dumpString(len); return; } if (is_rtcp(buf)) { @@ -118,10 +154,6 @@ void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_i } } -void WebRtcTransport::WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address) { - onWrite(buf, len, remote_address ? remote_address : (ice_server_ ? ice_server_->GetSelectAddr() : nullptr)); -} - void WebRtcTransport::WritRtpPacket(char *buf, size_t len) { const uint8_t *p = (uint8_t *) buf; bool ret = false; @@ -129,7 +161,7 @@ void WebRtcTransport::WritRtpPacket(char *buf, size_t len) { ret = srtp_session_->EncryptRtp(&p, &len); } if (ret) { - onWrite((char *) p, len, ice_server_->GetSelectAddr()); + onWrite((char *) p, len); } } @@ -139,8 +171,8 @@ WebRtcTransportImp::WebRtcTransportImp(const EventPoller::Ptr &poller) { _socket = Socket::createSocket(poller, false); //随机端口,绑定全部网卡 _socket->bindUdpSock(0); - _socket->setOnRead([this](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len){ - OnInputDataPacket(buf->data(), buf->size(), (struct sockaddr_in*)addr); + _socket->setOnRead([this](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) mutable { + OnInputDataPacket(buf->data(), buf->size(), addr); }); } @@ -149,7 +181,7 @@ void WebRtcTransportImp::attach(const RtspMediaSource::Ptr &src) { _src = src; } -void WebRtcTransportImp::onDtlsCompleted() { +void WebRtcTransportImp::onDtlsConnected() { _reader = _src->getRing()->attach(_socket->getPoller(), true); weak_ptr weak_self = shared_from_this(); _reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt){ @@ -167,14 +199,9 @@ void WebRtcTransportImp::onDtlsCompleted() { }); } -void WebRtcTransportImp::onIceConnected(){ - -} - void WebRtcTransportImp::onWrite(const char *buf, size_t len, struct sockaddr_in *dst) { auto ptr = BufferRaw::create(); ptr->assign(buf, len); -// InfoL << len << " " << SockUtil::inet_ntoa(dst->sin_addr) << " " << ntohs(dst->sin_port); _socket->send(ptr, (struct sockaddr *)(dst), sizeof(struct sockaddr)); } @@ -201,15 +228,5 @@ std::string WebRtcTransportImp::getIP() const { /////////////////////////////////////////////////////////////////// -INSTANCE_IMP(WebRtcManager) - -WebRtcManager::WebRtcManager() { - -} - -WebRtcManager::~WebRtcManager() { - -} - diff --git a/webrtc/webrtc_transport.h b/webrtc/webrtc_transport.h index 2289e5f1..b9fcf347 100644 --- a/webrtc/webrtc_transport.h +++ b/webrtc/webrtc_transport.h @@ -3,12 +3,12 @@ #include #include -#include "dtls_transport.h" +#include "rtc_dtls_transport.h" #include "ice_server.h" #include "srtp_session.h" #include "stun_packet.h" -class WebRtcTransport { +class WebRtcTransport : public RTC::DtlsTransport::Listener, public RTC::IceServer::Listener { public: using Ptr = std::shared_ptr; WebRtcTransport(); @@ -22,13 +22,38 @@ public: /// \param buf /// \param len /// \param remote_address - void OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address); + void OnInputDataPacket(char *buf, size_t len, RTC::TransportTuple *tuple); /// 发送rtp /// \param buf /// \param len void WritRtpPacket(char *buf, size_t len); +protected: + // dtls相关的回调 + void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override {}; + void OnDtlsTransportConnected( + const RTC::DtlsTransport *dtlsTransport, + RTC::SrtpSession::CryptoSuite srtpCryptoSuite, + uint8_t *srtpLocalKey, + size_t srtpLocalKeyLen, + uint8_t *srtpRemoteKey, + size_t srtpRemoteKeyLen, + std::string &remoteCert) override; + + void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override {}; + void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override {}; + void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; + void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override {}; + +protected: + //ice相关的回调 + void OnIceServerSendStunPacket(const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) override; + void OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) override; + void OnIceServerConnected(const RTC::IceServer *iceServer) override; + void OnIceServerCompleted(const RTC::IceServer *iceServer) override; + void OnIceServerDisconnected(const RTC::IceServer *iceServer) override; + protected: /// 输出udp数据 /// \param buf @@ -39,17 +64,14 @@ protected: virtual uint16_t getPort() const = 0; virtual std::string getIP() const = 0; virtual int getPayloadType() const = 0; - virtual void onIceConnected() = 0; - virtual void onDtlsCompleted() = 0; + virtual void onDtlsConnected() = 0; private: - void OnIceServerCompleted(); - void OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite); - void WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address = nullptr); + void onWrite(const char *buf, size_t len); private: - IceServer::Ptr ice_server_; - DtlsTransport::Ptr dtls_transport_; + std::shared_ptr ice_server_; + std::shared_ptr dtls_transport_; std::shared_ptr srtp_session_; }; @@ -74,8 +96,7 @@ protected: uint32_t getSSRC() const override; uint16_t getPort() const override; std::string getIP() const override; - void onIceConnected() override; - void onDtlsCompleted() override; + void onDtlsConnected() override; private: Socket::Ptr _socket; @@ -83,16 +104,6 @@ private: RtspMediaSource::RingType::RingReader::Ptr _reader; }; -class WebRtcManager : public std::enable_shared_from_this { -public: - ~WebRtcManager(); - static WebRtcManager& Instance(); - -private: - WebRtcManager(); - -}; - diff --git a/www/webrtc/index.html b/www/webrtc/index.html index a9603957..fbc4c2d2 100644 --- a/www/webrtc/index.html +++ b/www/webrtc/index.html @@ -22,7 +22,7 @@

ip_address

- +