From 65e470e0607bc6c24aa089cf5b64b5ba9301f752 Mon Sep 17 00:00:00 2001 From: ziyue <1213642868@qq.com> Date: Wed, 24 Mar 2021 16:52:41 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E6=B7=BB=E5=8A=A0rtsp?= =?UTF-8?q?=E8=BD=ACwebrtc=E7=9B=B8=E5=85=B3=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 6 + cmake/FindSRTP.cmake | 55 ++ server/CMakeLists.txt | 1 + server/WebApi.cpp | 24 + webrtc/CMakeLists.txt | 16 + webrtc/dtls_transport.cc | 75 ++ webrtc/dtls_transport.h | 58 ++ webrtc/ice_server.cc | 201 ++++++ webrtc/ice_server.h | 40 + webrtc/logger.h | 18 + webrtc/rtc_dtls_transport.cc | 1323 ++++++++++++++++++++++++++++++++++ webrtc/rtc_dtls_transport.h | 187 +++++ webrtc/srtp_session.cc | 269 +++++++ webrtc/srtp_session.h | 54 ++ webrtc/stun_packet.cc | 710 ++++++++++++++++++ webrtc/stun_packet.h | 179 +++++ webrtc/utils.cc | 139 ++++ webrtc/utils.h | 318 ++++++++ webrtc/webrtc_transport.cc | 215 ++++++ webrtc/webrtc_transport.h | 112 +++ 20 files changed, 4000 insertions(+) create mode 100644 cmake/FindSRTP.cmake create mode 100644 webrtc/CMakeLists.txt create mode 100644 webrtc/dtls_transport.cc create mode 100644 webrtc/dtls_transport.h create mode 100644 webrtc/ice_server.cc create mode 100644 webrtc/ice_server.h create mode 100644 webrtc/logger.h create mode 100644 webrtc/rtc_dtls_transport.cc create mode 100644 webrtc/rtc_dtls_transport.h create mode 100644 webrtc/srtp_session.cc create mode 100644 webrtc/srtp_session.h create mode 100644 webrtc/stun_packet.cc create mode 100644 webrtc/stun_packet.h create mode 100644 webrtc/utils.cc create mode 100644 webrtc/utils.h create mode 100644 webrtc/webrtc_transport.cc create mode 100644 webrtc/webrtc_transport.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 0db8dcea..36a93a12 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ option(ENABLE_TESTS "Enable Tests" true) option(ENABLE_SERVER "Enable Server" true) option(ENABLE_MEM_DEBUG "Enable Memory Debug" false) option(ENABLE_ASAN "Enable Address Sanitize" false) +option(ENABLE_WEBRTC "Enable WebRTC" true) if (ENABLE_MEM_DEBUG) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-wrap,free -Wl,-wrap,malloc -Wl,-wrap,realloc -Wl,-wrap,calloc") @@ -224,6 +225,11 @@ if(ENABLE_API) add_subdirectory(api) endif() +if(ENABLE_WEBRTC) + add_definitions(-DENABLE_WEBRTC) + add_subdirectory(webrtc) +endif() + if (NOT IOS) #测试程序 if(ENABLE_TESTS) diff --git a/cmake/FindSRTP.cmake b/cmake/FindSRTP.cmake new file mode 100644 index 00000000..3046020e --- /dev/null +++ b/cmake/FindSRTP.cmake @@ -0,0 +1,55 @@ +############################################################################ +# FindSRTP.txt +# Copyright (C) 2014 Belledonne Communications, Grenoble France +# +############################################################################ +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. +# +############################################################################ +# +# - Find the SRTP include file and library +# +# SRTP_FOUND - system has SRTP +# SRTP_INCLUDE_DIRS - the SRTP include directory +# SRTP_LIBRARIES - The libraries needed to use SRTP + +set(_SRTP_ROOT_PATHS + ${CMAKE_INSTALL_PREFIX} + ) + +find_path(SRTP_INCLUDE_DIRS + NAMES srtp2/srtp.h + HINTS _SRTP_ROOT_PATHS + PATH_SUFFIXES include + ) + +if(SRTP_INCLUDE_DIRS) + set(HAVE_SRTP_SRTP_H 1) +endif() + +find_library(SRTP_LIBRARIES + NAMES srtp2 + HINTS ${_SRTP_ROOT_PATHS} + PATH_SUFFIXES bin lib + ) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(SRTP + DEFAULT_MSG + SRTP_INCLUDE_DIRS SRTP_LIBRARIES HAVE_SRTP_SRTP_H + ) + +mark_as_advanced(SRTP_INCLUDE_DIRS SRTP_LIBRARIES HAVE_SRTP_SRTP_H) \ No newline at end of file diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 6e0af5e1..c0b3a009 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -49,3 +49,4 @@ else() endif() target_link_libraries(MediaServer jsoncpp ${LINK_LIB_LIST}) +message(${LINK_LIB_LIST}) diff --git a/server/WebApi.cpp b/server/WebApi.cpp index a70b12d1..4d0dba28 100755 --- a/server/WebApi.cpp +++ b/server/WebApi.cpp @@ -35,6 +35,9 @@ #if defined(ENABLE_RTPPROXY) #include "Rtp/RtpServer.h" #endif +#ifdef ENABLE_WEBRTC +#include "../webrtc/webrtc_transport.h" +#endif using namespace toolkit; using namespace mediakit; @@ -1049,6 +1052,27 @@ void installWebApi() { #endif }); +#ifdef ENABLE_WEBRTC + static list rtcs; + api_regist("/webrtc",[](API_ARGS_MAP_ASYNC){ + CHECK_ARGS("app", "stream"); + auto src = dynamic_pointer_cast(MediaSource::find(RTSP_SCHEMA, DEFAULT_VHOST, allArgs["app"], allArgs["stream"])); + if (!src) { + throw ApiRetException("流不存在", API::NotFound); + } + headerOut["Content-Type"] = "text/plain"; + headerOut["Access-Control-Allow-Origin"] = "*"; + auto poller = EventPollerPool::Instance().getFirstPoller(); + auto rtc = std::make_shared(poller); + poller->async([invoker, rtc, headerOut, src]() { + rtc->attach(src); + auto sdp = rtc->GetLocalSdp(); + invoker(200, headerOut, sdp); + rtcs.emplace_back(rtc); + }); + }); +#endif + ////////////以下是注册的Hook API//////////// api_regist("/index/hook/on_publish",[](API_ARGS_MAP){ //开始推流事件 diff --git a/webrtc/CMakeLists.txt b/webrtc/CMakeLists.txt new file mode 100644 index 00000000..03039aea --- /dev/null +++ b/webrtc/CMakeLists.txt @@ -0,0 +1,16 @@ +list(APPEND LINK_LIB_LIST webrtc) +#查找srtp是否安装 +find_package(SRTP QUIET) +if (SRTP_FOUND) + message(STATUS "found library:${SRTP_LIBRARIES}") + include_directories(${SRTP_INCLUDE_DIRS}) + list(APPEND LINK_LIB_LIST ${SRTP_LIBRARIES}) +else () + message(FATAL_ERROR "srtp未找到!") +endif () + +include_directories(./) +file(GLOB SRC_LIST ./*.*) +add_library(webrtc ${SRC_LIST}) +set(LINK_LIB_LIST ${LINK_LIB_LIST} PARENT_SCOPE) + diff --git a/webrtc/dtls_transport.cc b/webrtc/dtls_transport.cc new file mode 100644 index 00000000..69e1f402 --- /dev/null +++ b/webrtc/dtls_transport.cc @@ -0,0 +1,75 @@ +// +// Created by xueyuegui on 19-12-7. +// + +#include "dtls_transport.h" + +#include + +DtlsTransport::DtlsTransport(bool is_server) : is_server_(is_server) { + dtls_transport_.reset(new RTC::DtlsTransport(this)); +} + +DtlsTransport::~DtlsTransport() {} + +void DtlsTransport::Start() { + if (is_server_) { + dtls_transport_->Run(RTC::DtlsTransport::Role::SERVER); + } else { + dtls_transport_->Run(RTC::DtlsTransport::Role::CLIENT); + } +} + +void DtlsTransport::Close() {} + +void DtlsTransport::OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) {} + +void DtlsTransport::OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, + RTC::CryptoSuite srtp_crypto_suite, + uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, + uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, + std::string &remoteCert) { + std::string client_key; + std::string server_key; + server_key.assign((char *) srtpLocalKey, srtpLocalKeyLen); + client_key.assign((char *) srtpRemoteKey, srtpRemoteKeyLen); + if (is_server_) { + // If we are server, we swap the keys + client_key.swap(server_key); + } + if (handshake_completed_callback_) { + handshake_completed_callback_(client_key, server_key, srtp_crypto_suite); + } +} + +void DtlsTransport::OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) { + if (handshake_failed_callback_) { + handshake_failed_callback_(); + } +} + +void DtlsTransport::OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) {} + +void DtlsTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, + const uint8_t *data, size_t len) { + if (output_callback_) { + output_callback_((char *) data, len); + } +} + +void DtlsTransport::OutputData(char *buf, size_t len) { + if (output_callback_) { + output_callback_(buf, len); + } +} + +void DtlsTransport::OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, + const uint8_t *data, size_t len) {} + +bool DtlsTransport::IsDtlsPacket(const char *buf, size_t len) { + return RTC::DtlsTransport::IsDtls((uint8_t *) buf, len); +} + +void DtlsTransport::InputData(char *buf, size_t len) { + dtls_transport_->ProcessDtlsData((uint8_t *) buf, len); +} diff --git a/webrtc/dtls_transport.h b/webrtc/dtls_transport.h new file mode 100644 index 00000000..41f4c65c --- /dev/null +++ b/webrtc/dtls_transport.h @@ -0,0 +1,58 @@ +// +// Created by xueyuegui on 19-12-7. +// + +#ifndef MYWEBRTC_MYDTLSTRANSPORT_H +#define MYWEBRTC_MYDTLSTRANSPORT_H + +#include +#include + +#include "rtc_dtls_transport.h" + +class DtlsTransport : RTC::DtlsTransport::Listener { +public: + typedef std::shared_ptr Ptr; + + DtlsTransport(bool bServer); + ~DtlsTransport(); + + void Start(); + void Close(); + void InputData(char *buf, size_t len); + void OutputData(char *buf, size_t len); + static bool IsDtlsPacket(const char *buf, size_t len); + std::string GetMyFingerprint() { + auto finger_prints = dtls_transport_->GetLocalFingerprints(); + for (size_t i = 0; i < finger_prints.size(); i++) { + if (finger_prints[i].algorithm == RTC::DtlsTransport::FingerprintAlgorithm::SHA256) { + return finger_prints[i].value; + } + } + return ""; + }; + + void SetHandshakeCompletedCB(std::function cb) { + handshake_completed_callback_ = std::move(cb); + } + void SetHandshakeFailedCB(std::function cb) { handshake_failed_callback_ = std::move(cb); } + void SetOutPutCB(std::function cb) { output_callback_ = std::move(cb); } + + /* Pure virtual methods inherited from RTC::DtlsTransport::Listener. */ +public: + void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override; + void OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, RTC::CryptoSuite srtpCryptoSuite, uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) override; + void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override; + void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override; + void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data,size_t len) override; + void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; + +private: + bool is_server_ = false; + std::function handshake_failed_callback_; + std::shared_ptr dtls_transport_; + std::function output_callback_; + std::function handshake_completed_callback_; +}; + +#endif// MYWEBRTC_MYDTLSTRANSPORT_H diff --git a/webrtc/ice_server.cc b/webrtc/ice_server.cc new file mode 100644 index 00000000..cf7fb24b --- /dev/null +++ b/webrtc/ice_server.cc @@ -0,0 +1,201 @@ +#include "ice_server.h" + +#include + +static constexpr size_t StunSerializeBufferSize{65536}; +static uint8_t StunSerializeBuffer[StunSerializeBufferSize]; + +IceServer::IceServer() {} + +IceServer::~IceServer() {} + +IceServer::IceServer(const std::string &username_fragment, const std::string &password) + : username_fragment_(username_fragment), password_(password) {} + +void IceServer::ProcessStunPacket(RTC::StunPacket *packet, sockaddr_in *remote_address) { + // Must be a Binding method. + if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) { + ELOG_WARN("unknown method %#.3x in STUN Request => 400", + static_cast(packet->GetMethod())); + ELOG_WARN("unknown method %#.3x in STUN Request => 400", + static_cast(packet->GetMethod())); + // Reply 400. + RTC::StunPacket *response = packet->CreateErrorResponse(400); + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + } else { + ELOG_WARN("ignoring STUN Indication or Response with unknown method %#.3x", + static_cast(packet->GetMethod())); + } + return; + } + + // Must use FINGERPRINT (optional for ICE STUN indications). + if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) { + ELOG_WARN("STUN Binding Request without FINGERPRINT => 400"); + // Reply 400. + RTC::StunPacket *response = packet->CreateErrorResponse(400); + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + } else { + ELOG_WARN("ignoring STUN Binding Response without FINGERPRINT"); + } + return; + } + + switch (packet->GetClass()) { + case RTC::StunPacket::Class::REQUEST: { + // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. + if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || + packet->GetUsername().empty()) { + ELOG_WARN("mising required attributes in STUN Binding Request => 400"); + + // Reply 400. + RTC::StunPacket *response = packet->CreateErrorResponse(400); + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + return; + } + + // Check authentication. + switch (packet->CheckAuthentication(this->username_fragment_, this->password_)) { + case RTC::StunPacket::Authentication::OK: { + if (!this->old_password_.empty()) { + ELOG_DEBUG("kNew ICE credentials applied"); + this->old_username_fragment_.clear(); + this->old_password_.clear(); + } + break; + } + + case RTC::StunPacket::Authentication::UNAUTHORIZED: { + // We may have changed our username_fragment_ and password_, so check + // the old ones. + // clang-format off + if (!this->old_username_fragment_.empty() && + !this->old_password_.empty() && + packet->CheckAuthentication(this->old_username_fragment_, this->old_password_) == + RTC::StunPacket::Authentication::OK) { + ELOG_DEBUG("using old ICE credentials"); + break; + } + ELOG_WARN("wrong authentication in STUN Binding Request => 401"); + // Reply 401. + RTC::StunPacket *response = packet->CreateErrorResponse(401); + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + return; + } + + case RTC::StunPacket::Authentication::BAD_REQUEST: { + ELOG_WARN("cannot check authentication in STUN Binding Request => 400"); + // Reply 400. + RTC::StunPacket *response = packet->CreateErrorResponse(400); + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + return; + } + } + +#if 0 + // NOTE: Should be rejected with 487, but this makes Chrome happy: + // https://bugs.chromium.org/p/webrtc/issues/detail?id=7478 + // The remote peer must be ICE controlling. + if (packet->GetIceControlled()) { + MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); + // Reply 487 (Role Conflict). + RTC::StunPacket *response = packet->CreateErrorResponse(487); + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + return; + } +#endif + + ELOG_DEBUG("processing STUN Binding Request [Priority:%d, UseCandidate:%s]", + static_cast(packet->GetPriority()), + (packet->HasUseCandidate() ? "true" : "false")); + // Create a success response. + RTC::StunPacket *response = packet->CreateSuccessResponse(); + // Add XOR-MAPPED-ADDRESS. + // response->SetXorMappedAddress(tuple->GetRemoteAddress()); + response->SetXorMappedAddress((struct sockaddr *) remote_address); + // Authenticate the response. + if (this->old_password_.empty()) { + response->Authenticate(this->password_); + } else { + response->Authenticate(this->old_password_); + } + + // Send back. + response->Serialize(StunSerializeBuffer); + if (send_callback_) { + send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); + } + delete response; + // Handle the tuple. + HandleTuple(remote_address, packet->HasUseCandidate()); + break; + } + + case RTC::StunPacket::Class::INDICATION: { + ELOG_DEBUG("STUN Binding Indication processed"); + break; + } + + case RTC::StunPacket::Class::SUCCESS_RESPONSE: { + ELOG_DEBUG("STUN Binding Success Response processed"); + break; + } + + case RTC::StunPacket::Class::ERROR_RESPONSE: { + ELOG_DEBUG("STUN Binding Error Response processed"); + break; + } + } +} +void IceServer::HandleTuple(sockaddr_in *remote_address, bool has_use_candidate) { + remote_address_ = *remote_address; + if (has_use_candidate) { + this->state = IceState::kCompleted; + } + if (ice_server_completed_callback_) { + ice_server_completed_callback_(); + ice_server_completed_callback_ = nullptr; + } +} + +const std::string &IceServer::GetUsernameFragment() const { return this->username_fragment_; } + +const std::string &IceServer::GetPassword() const { return this->password_; } + +inline void IceServer::SetUsernameFragment(const std::string &username_fragment) { + this->old_username_fragment_ = this->username_fragment_; + this->username_fragment_ = username_fragment; +} + +inline void IceServer::SetPassword(const std::string &password) { + this->old_password_ = this->password_; + this->password_ = password; +} + +inline IceServer::IceState IceServer::GetState() const { return this->state; } \ No newline at end of file diff --git a/webrtc/ice_server.h b/webrtc/ice_server.h new file mode 100644 index 00000000..d33f26d2 --- /dev/null +++ b/webrtc/ice_server.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include "logger.h" +#include "stun_packet.h" + +typedef std::function UdpSendCallback; + +class IceServer { +public: + enum class IceState { kNew = 1, kConnect, kCompleted, kDisconnected }; + typedef std::shared_ptr Ptr; + IceServer(); + IceServer(const std::string &username_fragment, const std::string &password); + const std::string &GetUsernameFragment() const; + const std::string &GetPassword() const; + void SetUsernameFragment(const std::string &username_fragment); + void SetPassword(const std::string &password); + IceState GetState() const; + void ProcessStunPacket(RTC::StunPacket *packet, struct sockaddr_in *remote_address); + void HandleTuple(struct sockaddr_in *remote_address, bool has_use_candidate); + ~IceServer(); + void SetSendCB(UdpSendCallback send_cb) { send_callback_ = send_cb; } + void SetIceServerCompletedCB(std::function cb) { ice_server_completed_callback_ = cb; }; + struct sockaddr_in *GetSelectAddr() { + return &remote_address_; + } + +private: + UdpSendCallback send_callback_; + std::function ice_server_completed_callback_; + std::string username_fragment_; + std::string password_; + std::string old_username_fragment_; + std::string old_password_; + IceState state{IceState::kNew}; + struct sockaddr_in remote_address_; +}; diff --git a/webrtc/logger.h b/webrtc/logger.h new file mode 100644 index 00000000..a165536b --- /dev/null +++ b/webrtc/logger.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include + +#define ELOG_DEBUG(fmt, ...) printf(fmt "\n", ##__VA_ARGS__) +#define ELOG_WARN(fmt, ...) printf(fmt "\n", ##__VA_ARGS__) + +#define MS_TRACE() +#define MS_ERROR(fmt, ...) printf("error:" fmt "\n", ##__VA_ARGS__) +#define MS_THROW_ERROR(fmt, ...) do{ printf("throw:" fmt "\n", ##__VA_ARGS__); throw std::runtime_error("error"); } while(false); +#define MS_DUMP(fmt, ...) printf("dump:" fmt "\n", ##__VA_ARGS__) +#define MS_DEBUG_2TAGS(tag1, tag2,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) +#define MS_WARN_2TAGS(tag1, tag2,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__) +#define MS_DEBUG_TAG(tag,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) +#define MS_ASSERT(con, log) assert(con) +#define MS_ABORT(fmt, ...) do{ printf("abort:" fmt "\n", ##__VA_ARGS__); abort(); } while(false); +#define MS_WARN_TAG(tag,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__) +#define MS_DEBUG_DEV(fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) \ No newline at end of file diff --git a/webrtc/rtc_dtls_transport.cc b/webrtc/rtc_dtls_transport.cc new file mode 100644 index 00000000..1087af4a --- /dev/null +++ b/webrtc/rtc_dtls_transport.cc @@ -0,0 +1,1323 @@ +#define MS_CLASS "RTC::DtlsTransport" +// #define MS_LOG_DEV_LEVEL 3 + +#include "rtc_dtls_transport.h" + +#include +#include +#include +#include +#include + +#include // std::sprintf(), std::fopen() +#include // std::memcpy(), std::strcmp() + +#include "logger.h" + +typedef struct { + long tv_sec; + long tv_usec; +} uv_timeval_t; + +#define LOG_OPENSSL_ERROR(desc) \ + do { \ + if (ERR_peek_error() == 0) \ + MS_ERROR("OpenSSL error [desc:'%s']", desc); \ + else { \ + int64_t err; \ + while ((err = ERR_get_error()) != 0) { \ + MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ + } \ + ERR_clear_error(); \ + } \ + } while (false) + +/* Static methods for OpenSSL callbacks. */ + +inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) { + MS_TRACE(); + + // Always valid since DTLS certificates are self-signed. + return 1; +} + +inline static void onSslInfo(const SSL* ssl, int where, int ret) { + static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); +} + +inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) { + if (timerUs == 0) + return 100000; + else if (timerUs >= 4000000) + return 4000000; + else + return 2 * timerUs; +} + +namespace RTC { +/* Static. */ + +// clang-format off + static constexpr int DtlsMtu{ 1350 }; + static constexpr int SslReadBufferSize{ 65536 }; + // AES-HMAC: http://tools.ietf.org/html/rfc3711 + static constexpr size_t SrtpMasterKeyLength{ 16 }; + static constexpr size_t SrtpMasterSaltLength{ 14 }; + static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; + // AES-GCM: http://tools.ietf.org/html/rfc7714 + static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; + static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; + static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; + static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; +// clang-format on + +/* Class variables. */ + +X509* DtlsTransport::certificate{nullptr}; +EVP_PKEY* DtlsTransport::privateKey{nullptr}; +SSL_CTX* DtlsTransport::sslCtx{nullptr}; +uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize]; +// clang-format off + std::map DtlsTransport::string2FingerprintAlgorithm = + { + { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, + { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, + { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, + { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, + { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } + }; + std::map DtlsTransport::fingerprintAlgorithm2String = + { + { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, + { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, + { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, + { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, + { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } + }; + std::map DtlsTransport::string2Role = + { + { "auto", DtlsTransport::Role::AUTO }, + { "client", DtlsTransport::Role::CLIENT }, + { "server", DtlsTransport::Role::SERVER } + }; + std::vector DtlsTransport::localFingerprints; + std::vector DtlsTransport::srtpCryptoSuites = + { + { RTC::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, + { RTC::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, + { RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, + { RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } + }; +// clang-format on + +/* Class methods. */ + +void DtlsTransport::ClassInit() { + MS_TRACE(); + +#if 0 + // Generate a X509 certificate and private key (unless PEM files are provided). + if (Settings::configuration.dtlsCertificateFile.empty() || + Settings::configuration.dtlsPrivateKeyFile.empty()) { + GenerateCertificateAndPrivateKey(); + } else { + ReadCertificateAndPrivateKeyFromFiles(); + } +#else + GenerateCertificateAndPrivateKey(); +#endif + + // Create a global SSL_CTX. + CreateSslCtx(); + + // Generate certificate fingerprints. + GenerateFingerprints(); +} + +void DtlsTransport::ClassDestroy() { + MS_TRACE(); + + if (DtlsTransport::privateKey) EVP_PKEY_free(DtlsTransport::privateKey); + if (DtlsTransport::certificate) X509_free(DtlsTransport::certificate); + if (DtlsTransport::sslCtx) SSL_CTX_free(DtlsTransport::sslCtx); +} + +void DtlsTransport::GenerateCertificateAndPrivateKey() { + MS_TRACE(); + + int ret{0}; + EC_KEY* ecKey{nullptr}; + X509_NAME* certName{nullptr}; + std::string subject = std::string("mediasoup") + std::to_string(rand() % 999999 + 100000); + // std::string("mediasoup") + std::to_string(Utils::Crypto::GetRandomUInt(100000, 999999)); + + // Create key with curve. + ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + + if (!ecKey) { + LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); + + goto error; + } + + EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); + + // NOTE: This can take some time. + ret = EC_KEY_generate_key(ecKey); + + if (ret == 0) { + LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); + + goto error; + } + + // Create a private key object. + DtlsTransport::privateKey = EVP_PKEY_new(); + + if (!DtlsTransport::privateKey) { + LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); + + goto error; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + ret = EVP_PKEY_assign_EC_KEY(DtlsTransport::privateKey, ecKey); + + if (ret == 0) { + LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); + + goto error; + } + + // The EC key now belongs to the private key, so don't clean it up separately. + ecKey = nullptr; + + // Create the X509 certificate. + DtlsTransport::certificate = X509_new(); + + if (!DtlsTransport::certificate) { + LOG_OPENSSL_ERROR("X509_new() failed"); + + goto error; + } + + // Set version 3 (note that 0 means version 1). + X509_set_version(DtlsTransport::certificate, 2); + + // Set serial number (avoid default 0). + // ASN1_INTEGER_set(X509_get_serialNumber(DtlsTransport::certificate), + // static_cast(Utils::Crypto::GetRandomUInt(1000000, 9999999))); + ASN1_INTEGER_set(X509_get_serialNumber(DtlsTransport::certificate), + static_cast(rand() % 999999 + 100000)); + + // Set valid period. + X509_gmtime_adj(X509_get_notBefore(DtlsTransport::certificate), -315360000); // -10 years. + X509_gmtime_adj(X509_get_notAfter(DtlsTransport::certificate), 315360000); // 10 years. + + // Set the public key for the certificate using the key. + ret = X509_set_pubkey(DtlsTransport::certificate, DtlsTransport::privateKey); + + if (ret == 0) { + LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); + + goto error; + } + + // Set certificate fields. + certName = X509_get_subject_name(DtlsTransport::certificate); + + if (!certName) { + LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); + + goto error; + } + + X509_NAME_add_entry_by_txt(certName, "O", MBSTRING_ASC, + reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt(certName, "CN", MBSTRING_ASC, + reinterpret_cast(subject.c_str()), -1, -1, 0); + + // It is self-signed so set the issuer name to be the same as the subject. + ret = X509_set_issuer_name(DtlsTransport::certificate, certName); + + if (ret == 0) { + LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); + + goto error; + } + + // Sign the certificate with its own private key. + ret = X509_sign(DtlsTransport::certificate, DtlsTransport::privateKey, EVP_sha1()); + + if (ret == 0) { + LOG_OPENSSL_ERROR("X509_sign() failed"); + + goto error; + } + + return; + +error: + + if (ecKey) EC_KEY_free(ecKey); + + if (DtlsTransport::privateKey) + EVP_PKEY_free(DtlsTransport::privateKey); // NOTE: This also frees the EC key. + + if (DtlsTransport::certificate) X509_free(DtlsTransport::certificate); + + MS_THROW_ERROR("DTLS certificate and private key generation failed"); +} + +void DtlsTransport::ReadCertificateAndPrivateKeyFromFiles() { +#if 0 + MS_TRACE(); + + FILE* file{nullptr}; + + file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); + + if (!file) { + MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); + + goto error; + } + + DtlsTransport::certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); + + if (!DtlsTransport::certificate) { + LOG_OPENSSL_ERROR("PEM_read_X509() failed"); + + goto error; + } + + fclose(file); + + file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); + + if (!file) { + MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); + + goto error; + } + + DtlsTransport::privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); + + if (!DtlsTransport::privateKey) { + LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); + + goto error; + } + + fclose(file); + + return; + + error: + + MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); +#endif +} + +void DtlsTransport::CreateSslCtx() { + MS_TRACE(); + + std::string dtlsSrtpCryptoSuites; + int ret; + + /* Set the global DTLS context. */ + + // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). + DtlsTransport::sslCtx = SSL_CTX_new(DTLS_method()); + + if (!DtlsTransport::sslCtx) { + LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); + + goto error; + } + + ret = SSL_CTX_use_certificate(DtlsTransport::sslCtx, DtlsTransport::certificate); + + if (ret == 0) { + LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); + + goto error; + } + + ret = SSL_CTX_use_PrivateKey(DtlsTransport::sslCtx, DtlsTransport::privateKey); + + if (ret == 0) { + LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); + + goto error; + } + + ret = SSL_CTX_check_private_key(DtlsTransport::sslCtx); + + if (ret == 0) { + LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); + + goto error; + } + + // Set options. + SSL_CTX_set_options(DtlsTransport::sslCtx, SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | + SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_QUERY_MTU); + + // Don't use sessions cache. + SSL_CTX_set_session_cache_mode(DtlsTransport::sslCtx, SSL_SESS_CACHE_OFF); + + // Read always as much into the buffer as possible. + // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL + // versions makes this call required. + SSL_CTX_set_read_ahead(DtlsTransport::sslCtx, 1); + + SSL_CTX_set_verify_depth(DtlsTransport::sslCtx, 4); + + // Require certificate from peer. + SSL_CTX_set_verify(DtlsTransport::sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + onSslCertificateVerify); + + // Set SSL info callback. + SSL_CTX_set_info_callback(DtlsTransport::sslCtx, onSslInfo); + + // Set ciphers. + ret = SSL_CTX_set_cipher_list(DtlsTransport::sslCtx, + "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); + + if (ret == 0) { + LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); + + goto error; + } + + // Enable ECDH ciphers. + // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters + // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 + // NOTE: https://bugs.ruby-lang.org/issues/12324 + + // For OpenSSL >= 1.0.2. + SSL_CTX_set_ecdh_auto(DtlsTransport::sslCtx, 1); + + // Set the "use_srtp" DTLS extension. + for (auto it = DtlsTransport::srtpCryptoSuites.begin(); + it != DtlsTransport::srtpCryptoSuites.end(); ++it) { + if (it != DtlsTransport::srtpCryptoSuites.begin()) dtlsSrtpCryptoSuites += ":"; + + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); + dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; + } + + MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", + dtlsSrtpCryptoSuites.c_str()); + + // NOTE: This function returns 0 on success. + ret = SSL_CTX_set_tlsext_use_srtp(DtlsTransport::sslCtx, dtlsSrtpCryptoSuites.c_str()); + + if (ret != 0) { + MS_ERROR("SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", + dtlsSrtpCryptoSuites.c_str()); + LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); + + goto error; + } + + return; + +error: + + if (DtlsTransport::sslCtx) { + SSL_CTX_free(DtlsTransport::sslCtx); + DtlsTransport::sslCtx = nullptr; + } + + MS_THROW_ERROR("SSL context creation failed"); +} + +void DtlsTransport::GenerateFingerprints() { + MS_TRACE(); + + for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) { + const std::string& algorithmString = kv.first; + FingerprintAlgorithm algorithm = kv.second; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{0}; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; + + switch (algorithm) { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; + + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; + + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; + + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; + + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; + + default: + MS_THROW_ERROR("unknown algorithm"); + } + + ret = X509_digest(DtlsTransport::certificate, hashFunction, binaryFingerprint, &size); + + if (ret == 0) { + MS_ERROR("X509_digest() failed"); + MS_THROW_ERROR("Fingerprints generation failed"); + } + + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{0}; i < size; ++i) { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; + + MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); + // Store it in the vector. + DtlsTransport::Fingerprint fingerprint; + + fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); + fingerprint.value = hexFingerprint; + + DtlsTransport::localFingerprints.push_back(fingerprint); + } +} + +/* Instance methods. */ + +DtlsTransport::DtlsTransport(Listener* listener) : listener(listener) { + MS_TRACE(); + + /* Set SSL. */ + + this->ssl = SSL_new(DtlsTransport::sslCtx); + + if (!this->ssl) { + LOG_OPENSSL_ERROR("SSL_new() failed"); + + goto error; + } + + // Set this as custom data. + SSL_set_ex_data(this->ssl, 0, static_cast(this)); + + this->sslBioFromNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioFromNetwork) { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + SSL_free(this->ssl); + + goto error; + } + + this->sslBioToNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioToNetwork) { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + BIO_free(this->sslBioFromNetwork); + SSL_free(this->ssl); + + goto error; + } + + SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); + + // Set the MTU so that we don't send packets that are too large with no fragmentation. + SSL_set_mtu(this->ssl, DtlsMtu); + DTLS_set_link_mtu(this->ssl, DtlsMtu); + + // Set callback handler for setting DTLS timer interval. + DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); + + // Set the DTLS timer. + // this->timer = new Timer(this); + + return; + +error: + + // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as + // well. + if (this->sslBioFromNetwork) BIO_free(this->sslBioFromNetwork); + + if (this->sslBioToNetwork) BIO_free(this->sslBioToNetwork); + + if (this->ssl) SSL_free(this->ssl); + + // NOTE: If this is not catched by the caller the program will abort, but + // this should never happen. + MS_THROW_ERROR("DtlsTransport instance creation failed"); +} + +DtlsTransport::~DtlsTransport() { + MS_TRACE(); + + if (IsRunning()) { + // Send close alert to the peer. + SSL_shutdown(this->ssl); + SendPendingOutgoingDtlsData(); + } + + if (this->ssl) { + SSL_free(this->ssl); + + this->ssl = nullptr; + this->sslBioFromNetwork = nullptr; + this->sslBioToNetwork = nullptr; + } + + // Close the DTLS timer. + // delete this->timer; +} + +void DtlsTransport::Dump() const { + MS_TRACE(); + + std::string state{"new"}; + std::string role{"none "}; + + switch (this->state) { + case DtlsState::CONNECTING: + state = "connecting"; + break; + case DtlsState::CONNECTED: + state = "connected"; + break; + case DtlsState::FAILED: + state = "failed"; + break; + case DtlsState::CLOSED: + state = "closed"; + break; + default:; + } + + switch (this->localRole) { + case Role::AUTO: + role = "auto"; + break; + case Role::SERVER: + role = "server"; + break; + case Role::CLIENT: + role = "client"; + break; + default:; + } + + MS_DUMP(""); + MS_DUMP(" state : %s", state.c_str()); + MS_DUMP(" role : %s", role.c_str()); + MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); + MS_DUMP(""); +} + +void DtlsTransport::Run(Role localRole) { + MS_TRACE(); + + MS_ASSERT(localRole == Role::CLIENT || localRole == Role::SERVER, + "local DTLS role must be 'client' or 'server'"); + + Role previousLocalRole = this->localRole; + + if (localRole == previousLocalRole) { + MS_ERROR("same local DTLS role provided, doing nothing"); + + return; + } + + // If the previous local DTLS role was 'client' or 'server' do reset. + if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) { + MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); + + Reset(); + } + + // Update local role. + this->localRole = localRole; + + // Set state and notify the listener. + this->state = DtlsState::CONNECTING; + this->listener->OnDtlsTransportConnecting(this); + + switch (this->localRole) { + case Role::CLIENT: { + MS_DEBUG_TAG(dtls, "running [role:client]"); + + SSL_set_connect_state(this->ssl); + SSL_do_handshake(this->ssl); + SendPendingOutgoingDtlsData(); + SetTimeout(); + + break; + } + + case Role::SERVER: { + MS_DEBUG_TAG(dtls, "running [role:server]"); + + SSL_set_accept_state(this->ssl); + SSL_do_handshake(this->ssl); + + break; + } + + default: { + MS_ABORT("invalid local DTLS role"); + } + } +} + +bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) { + MS_TRACE(); + + MS_ASSERT(fingerprint.algorithm != FingerprintAlgorithm::NONE, + "no fingerprint algorithm provided"); + + this->remoteFingerprint = fingerprint; + + // The remote fingerpring may have been set after DTLS handshake was done, + // so we may need to process it now. + if (this->handshakeDone && this->state != DtlsState::CONNECTED) { + MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); + + return ProcessHandshake(); + } + + return true; +} + +void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) { + MS_TRACE(); + + int written; + int read; + + if (!IsRunning()) { + MS_ERROR("cannot process data while not running"); + + return; + } + + // Write the received DTLS data into the sslBioFromNetwork. + written = + BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); + + if (written != static_cast(len)) { + MS_WARN_TAG(dtls, "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", + static_cast(written), len); + } + + // Must call SSL_read() to process received DTLS data. + read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); + + // Send data if it's ready. + SendPendingOutgoingDtlsData(); + + // Check SSL status and return if it is bad/closed. + if (!CheckStatus(read)) return; + + // Set/update the DTLS timeout. + if (!SetTimeout()) return; + + // Application data received. Notify to the listener. + if (read > 0) { + // It is allowed to receive DTLS data even before validating remote fingerprint. + if (!this->handshakeDone) { + MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); + + return; + } + + // Notify the listener. + this->listener->OnDtlsTransportApplicationDataReceived( + this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); + } +} + +void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) { + MS_TRACE(); + + // We cannot send data to the peer if its remote fingerprint is not validated. + if (this->state != DtlsState::CONNECTED) { + MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); + + return; + } + + if (len == 0) { + MS_WARN_TAG(dtls, "ignoring 0 length data"); + + return; + } + + int written; + + written = SSL_write(this->ssl, static_cast(data), static_cast(len)); + + if (written < 0) { + LOG_OPENSSL_ERROR("SSL_write() failed"); + + if (!CheckStatus(written)) return; + } else if (written != static_cast(len)) { + MS_WARN_TAG(dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", + written, len); + } + + // Send data. + SendPendingOutgoingDtlsData(); +} + +void DtlsTransport::Reset() { + MS_TRACE(); + + int ret; + + if (!IsRunning()) return; + + MS_WARN_TAG(dtls, "resetting DTLS transport"); + + // Stop the DTLS timer. + // this->timer->Stop(); + + // We need to reset the SSL instance so we need to "shutdown" it, but we + // don't want to send a Close Alert to the peer, so just don't call + // SendPendingOutgoingDTLSData(). + SSL_shutdown(this->ssl); + + this->localRole = Role::NONE; + this->state = DtlsState::NEW; + this->handshakeDone = false; + this->handshakeDoneNow = false; + + // Reset SSL status. + // NOTE: For this to properly work, SSL_shutdown() must be called before. + // NOTE: This may fail if not enough DTLS handshake data has been received, + // but we don't care so just clear the error queue. + ret = SSL_clear(this->ssl); + + if (ret == 0) ERR_clear_error(); +} + +inline bool DtlsTransport::CheckStatus(int returnCode) { + MS_TRACE(); + + int err; + bool wasHandshakeDone = this->handshakeDone; + + err = SSL_get_error(this->ssl, returnCode); + + switch (err) { + case SSL_ERROR_NONE: + break; + + case SSL_ERROR_SSL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); + break; + + case SSL_ERROR_WANT_READ: + break; + + case SSL_ERROR_WANT_WRITE: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); + break; + + case SSL_ERROR_WANT_X509_LOOKUP: + MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); + break; + + case SSL_ERROR_SYSCALL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); + break; + + case SSL_ERROR_ZERO_RETURN: + break; + + case SSL_ERROR_WANT_CONNECT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); + break; + + case SSL_ERROR_WANT_ACCEPT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); + break; + + default: + MS_WARN_TAG(dtls, "SSL status: unknown error"); + } + + // Check if the handshake (or re-handshake) has been done right now. + if (this->handshakeDoneNow) { + this->handshakeDoneNow = false; + this->handshakeDone = true; + + // Stop the timer. + // this->timer->Stop(); + + // Process the handshake just once (ignore if DTLS renegotiation). + // if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) + // return ProcessHandshake(); + if (!wasHandshakeDone) { + return ProcessHandshake(); + } + + return true; + } + // Check if the peer sent close alert or a fatal error happened. + else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || + err == SSL_ERROR_SYSCALL) { + if (this->state == DtlsState::CONNECTED) { + MS_DEBUG_TAG(dtls, "disconnected"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::CLOSED; + this->listener->OnDtlsTransportClosed(this); + } else { + MS_WARN_TAG(dtls, "connection failed"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + } + + return false; + } else { + return true; + } +} + +inline void DtlsTransport::SendPendingOutgoingDtlsData() { + MS_TRACE(); + + if (BIO_eof(this->sslBioToNetwork)) return; + + int64_t read; + char* data{nullptr}; + + read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT + + if (read <= 0) return; + + MS_DEBUG_DEV("%ld bytes of DTLS data ready to sent to the peer", read); + + // Notify the listener. + this->listener->OnDtlsTransportSendData(this, reinterpret_cast(data), + static_cast(read)); + + // Clear the BIO buffer. + // NOTE: the (void) avoids the -Wunused-value warning. + (void)BIO_reset(this->sslBioToNetwork); +} + +inline bool DtlsTransport::SetTimeout() { + MS_TRACE(); + + MS_ASSERT(this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, + "invalid DTLS state"); + + int64_t ret; + uv_timeval_t dtlsTimeout{0, 0}; + uint64_t timeoutMs; + + // NOTE: If ret == 0 then ignore the value in dtlsTimeout. + // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. + ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT + + if (ret == 0) return true; + + timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); + + if (timeoutMs == 0) { + return true; + } else if (timeoutMs < 30000) { + MS_DEBUG_DEV("DTLS timer set in %lu ms", timeoutMs); + + // this->timer->Start(timeoutMs); + + return true; + } + // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. + else { + MS_WARN_TAG(dtls, "DTLS timeout too high (%lu ms), resetting DLTS", timeoutMs); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } +} + +inline bool DtlsTransport::ProcessHandshake() { + MS_TRACE(); + + MS_ASSERT(this->handshakeDone, "handshake not done yet"); +// MS_ASSERT(this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, +// "remote fingerprint not set"); + + // Validate the remote fingerprint. + // if (!CheckRemoteFingerprint()) { + // Reset(); + + // // Set state and notify the listener. + // this->state = DtlsState::FAILED; + // this->listener->OnDtlsTransportFailed(this); + + // return false; + // } + + // Get the negotiated SRTP crypto suite. + RTC::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); + + if (srtpCryptoSuite != RTC::CryptoSuite::NONE) { + // Extract the SRTP keys (will notify the listener with them). + ExtractSrtpKeys(srtpCryptoSuite); + + return true; + } + + // NOTE: We assume that "use_srtp" DTLS extension is required even if + // there is no audio/video. + MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; +} + +inline bool DtlsTransport::CheckRemoteFingerprint() { + MS_TRACE(); + + MS_ASSERT(this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, + "remote fingerprint not set"); + + X509* certificate; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{0}; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; + + certificate = SSL_get_peer_certificate(this->ssl); + + if (!certificate) { + MS_WARN_TAG(dtls, "no certificate was provided by the peer"); + + return false; + } + + switch (this->remoteFingerprint.algorithm) { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; + + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; + + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; + + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; + + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; + + default: + MS_ABORT("unknown algorithm"); + } + + // Compare the remote fingerprint with the value given via signaling. + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + + if (ret == 0) { + MS_ERROR("X509_digest() failed"); + + X509_free(certificate); + + return false; + } + + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{0}; i < size; ++i) { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; + + if (this->remoteFingerprint.value != hexFingerprint) { + MS_WARN_TAG(dtls, + "fingerprint in the remote certificate (%s) does not match the announced one (%s)", + hexFingerprint, this->remoteFingerprint.value.c_str()); + + X509_free(certificate); + + return false; + } + + MS_DEBUG_TAG(dtls, "valid remote fingerprint"); + + // Get the remote certificate in PEM format. + + BIO* bio = BIO_new(BIO_s_mem()); + + // Ensure the underlying BUF_MEM structure is also freed. + // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since + // BIO_set_close() always returns 1. + (void)BIO_set_close(bio, BIO_CLOSE); + + ret = PEM_write_bio_X509(bio, certificate); + + if (ret != 1) { + LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); + + X509_free(certificate); + BIO_free(bio); + + return false; + } + + BUF_MEM* mem; + + BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] + + if (!mem || !mem->data || mem->length == 0u) { + LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); + + X509_free(certificate); + BIO_free(bio); + + return false; + } + + this->remoteCert = std::string(mem->data, mem->length); + + X509_free(certificate); + BIO_free(bio); + + return true; +} + +inline void DtlsTransport::ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite) { + MS_TRACE(); + + size_t srtpKeyLength{0}; + size_t srtpSaltLength{0}; + size_t srtpMasterLength{0}; + + switch (srtpCryptoSuite) { + case RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_80: + case RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_32: { + srtpKeyLength = SrtpMasterKeyLength; + srtpSaltLength = SrtpMasterSaltLength; + srtpMasterLength = SrtpMasterLength; + + break; + } + + case RTC::CryptoSuite::AEAD_AES_256_GCM: { + srtpKeyLength = SrtpAesGcm256MasterKeyLength; + srtpSaltLength = SrtpAesGcm256MasterSaltLength; + srtpMasterLength = SrtpAesGcm256MasterLength; + + break; + } + + case RTC::CryptoSuite::AEAD_AES_128_GCM: { + srtpKeyLength = SrtpAesGcm128MasterKeyLength; + srtpSaltLength = SrtpAesGcm128MasterSaltLength; + srtpMasterLength = SrtpAesGcm128MasterLength; + + break; + } + + default: { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; + uint8_t* srtpLocalKey{nullptr}; + uint8_t* srtpLocalSalt{nullptr}; + uint8_t* srtpRemoteKey{nullptr}; + uint8_t* srtpRemoteSalt{nullptr}; + auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; + auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; + int ret; + + ret = SSL_export_keying_material(this->ssl, srtpMaterial, srtpMasterLength * 2, + "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); + + MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); + + switch (this->localRole) { + case Role::SERVER: { + srtpRemoteKey = srtpMaterial; + srtpLocalKey = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; + + break; + } + + case Role::CLIENT: { + srtpLocalKey = srtpMaterial; + srtpRemoteKey = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; + + break; + } + + default: { + MS_ABORT("no DTLS role set"); + } + } + + // Create the SRTP local master key. + std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); + std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); + // Create the SRTP remote master key. + std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); + std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); + + // Set state and notify the listener. + this->state = DtlsState::CONNECTED; + this->listener->OnDtlsTransportConnected(this, srtpCryptoSuite, srtpLocalMasterKey, + srtpMasterLength, srtpRemoteMasterKey, srtpMasterLength, + this->remoteCert); + + delete[] srtpMaterial; + delete[] srtpLocalMasterKey; + delete[] srtpRemoteMasterKey; +} + +inline RTC::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() { + MS_TRACE(); + + RTC::CryptoSuite negotiatedSrtpCryptoSuite = RTC::CryptoSuite::NONE; + + // Ensure that the SRTP crypto suite has been negotiated. + // NOTE: This is a OpenSSL type. + SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); + + if (!sslSrtpCryptoSuite) return negotiatedSrtpCryptoSuite; + + // Get the negotiated SRTP crypto suite. + for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) { + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); + + if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) { + MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); + + negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; + } + } + + MS_ASSERT(negotiatedSrtpCryptoSuite != RTC::CryptoSuite::NONE, + "chosen SRTP crypto suite is not an available one"); + + return negotiatedSrtpCryptoSuite; +} + +inline void DtlsTransport::OnSslInfo(int where, int ret) { + MS_TRACE(); + + int w = where & -SSL_ST_MASK; + const char* role; + + if ((w & SSL_ST_CONNECT) != 0) + role = "client"; + else if ((w & SSL_ST_ACCEPT) != 0) + role = "server"; + else + role = "undefined"; + + if ((where & SSL_CB_LOOP) != 0) { + MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); + } else if ((where & SSL_CB_ALERT) != 0) { + const char* alertType; + + switch (*SSL_alert_type_string(ret)) { + case 'W': + alertType = "warning"; + break; + + case 'F': + alertType = "fatal"; + break; + + default: + alertType = "undefined"; + } + + if ((where & SSL_CB_READ) != 0) { + MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } else if ((where & SSL_CB_WRITE) != 0) { + MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } else { + MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + } else if ((where & SSL_CB_EXIT) != 0) { + if (ret == 0) + MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); + else if (ret < 0) + MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); + } else if ((where & SSL_CB_HANDSHAKE_START) != 0) { + MS_DEBUG_TAG(dtls, "DTLS handshake start"); + } else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) { + MS_DEBUG_TAG(dtls, "DTLS handshake done"); + + this->handshakeDoneNow = true; + } + + // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon + // receipt of a close alert does not work (the flag is set after this callback). +} + +inline void DtlsTransport::OnTimer() { + MS_TRACE(); + + // Workaround for https://github.com/openssl/openssl/issues/7998. + if (this->handshakeDone) { + MS_DEBUG_DEV("handshake is done so return"); + + return; + } + + DTLSv1_handle_timeout(this->ssl); + + // If required, send DTLS data. + SendPendingOutgoingDtlsData(); + + // Set the DTLS timer again. + SetTimeout(); +} +} // namespace RTC diff --git a/webrtc/rtc_dtls_transport.h b/webrtc/rtc_dtls_transport.h new file mode 100644 index 00000000..628f966b --- /dev/null +++ b/webrtc/rtc_dtls_transport.h @@ -0,0 +1,187 @@ +#ifndef MS_RTC_DTLS_TRANSPORT_HPP +#define MS_RTC_DTLS_TRANSPORT_HPP + +#include +#include +#include + +#include +#include +#include + +namespace RTC { +enum class CryptoSuite { + NONE = 0, + AES_CM_128_HMAC_SHA1_80 = 1, + AES_CM_128_HMAC_SHA1_32, + AEAD_AES_256_GCM, + AEAD_AES_128_GCM +}; +class DtlsTransport { + public: + enum class DtlsState { NEW = 1, CONNECTING, CONNECTED, FAILED, CLOSED }; + + public: + enum class Role { NONE = 0, AUTO = 1, CLIENT, SERVER }; + + public: + enum class FingerprintAlgorithm { NONE = 0, SHA1 = 1, SHA224, SHA256, SHA384, SHA512 }; + + public: + struct Fingerprint { + FingerprintAlgorithm algorithm{FingerprintAlgorithm::NONE}; + std::string value; + }; + + private: + struct SrtpCryptoSuiteMapEntry { + RTC::CryptoSuite cryptoSuite; + const char* name; + }; + + public: + class Listener { + public: + // DTLS is in the process of negotiating a secure connection. Incoming + // media can flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; + // DTLS has completed negotiation of a secure connection (including DTLS-SRTP + // and remote fingerprint verification). Outgoing media can now flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnected(const RTC::DtlsTransport* dtlsTransport, + RTC::CryptoSuite srtpCryptoSuite, uint8_t* srtpLocalKey, + size_t srtpLocalKeyLen, uint8_t* srtpRemoteKey, + size_t srtpRemoteKeyLen, std::string& remoteCert) = 0; + // The DTLS connection has been closed as the result of an error (such as a + // DTLS alert or a failure to validate the remote fingerprint). + virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; + // The DTLS connection has been closed due to receipt of a close_notify alert. + virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; + // Need to send DTLS data to the peer. + virtual void OnDtlsTransportSendData(const RTC::DtlsTransport* dtlsTransport, + const uint8_t* data, size_t len) = 0; + // DTLS application data received. + virtual void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport* dtlsTransport, + const uint8_t* data, size_t len) = 0; + }; + + public: + static void ClassInit(); + static void ClassDestroy(); + static Role StringToRole(const std::string& role) { + auto it = DtlsTransport::string2Role.find(role); + + if (it != DtlsTransport::string2Role.end()) + return it->second; + else + return DtlsTransport::Role::NONE; + } + static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) { + auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); + + if (it != DtlsTransport::string2FingerprintAlgorithm.end()) + return it->second; + else + return DtlsTransport::FingerprintAlgorithm::NONE; + } + static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) { + auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); + + return it->second; + } + static bool IsDtls(const uint8_t* data, size_t len) { + // clang-format off + return ( + // Minimum DTLS record length is 13 bytes. + (len >= 13) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] > 19 && data[0] < 64) + ); + // clang-format on + } + + private: + static void GenerateCertificateAndPrivateKey(); + static void ReadCertificateAndPrivateKeyFromFiles(); + static void CreateSslCtx(); + static void GenerateFingerprints(); + + private: + static X509* certificate; + static EVP_PKEY* privateKey; + static SSL_CTX* sslCtx; + static uint8_t sslReadBuffer[]; + static std::map string2Role; + static std::map string2FingerprintAlgorithm; + static std::map fingerprintAlgorithm2String; + static std::vector localFingerprints; + static std::vector srtpCryptoSuites; + + public: + explicit DtlsTransport(Listener* listener); + ~DtlsTransport(); + + public: + void Dump() const; + void Run(Role localRole); + std::vector& GetLocalFingerprints() const { + return DtlsTransport::localFingerprints; + } + bool SetRemoteFingerprint(Fingerprint fingerprint); + void ProcessDtlsData(const uint8_t* data, size_t len); + DtlsState GetState() const { return this->state; } + Role GetLocalRole() const { return this->localRole; } + void SendApplicationData(const uint8_t* data, size_t len); + + private: + bool IsRunning() const { + switch (this->state) { + case DtlsState::NEW: + return false; + case DtlsState::CONNECTING: + case DtlsState::CONNECTED: + return true; + case DtlsState::FAILED: + case DtlsState::CLOSED: + return false; + } + + // Make GCC 4.9 happy. + return false; + } + void Reset(); + bool CheckStatus(int returnCode); + void SendPendingOutgoingDtlsData(); + bool SetTimeout(); + bool ProcessHandshake(); + bool CheckRemoteFingerprint(); + void ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite); + RTC::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + + /* Callbacks fired by OpenSSL events. */ + public: + void OnSslInfo(int where, int ret); + + /* Pure virtual methods inherited from Timer::Listener. */ + public: + void OnTimer(); + + private: + // Passed by argument. + Listener* listener{nullptr}; + // Allocated by this. + SSL* ssl{nullptr}; + BIO* sslBioFromNetwork{nullptr}; // The BIO from which ssl reads. + BIO* sslBioToNetwork{nullptr}; // The BIO in which ssl writes. + // Others. + DtlsState state{DtlsState::NEW}; + Role localRole{Role::NONE}; + Fingerprint remoteFingerprint; + bool handshakeDone{false}; + bool handshakeDoneNow{false}; + std::string remoteCert; +}; +} // namespace RTC + +#endif diff --git a/webrtc/srtp_session.cc b/webrtc/srtp_session.cc new file mode 100644 index 00000000..b21a83cd --- /dev/null +++ b/webrtc/srtp_session.cc @@ -0,0 +1,269 @@ +#define MS_CLASS "RTC::SrtpSession" +// #define MS_LOG_DEV_LEVEL 3 + +#include "srtp_session.h" + +#include // std::memset(), std::memcpy() +#include + +#include "logger.h" + +namespace RTC { +/* Static. */ + +static constexpr size_t EncryptBufferSize{65536}; +static uint8_t EncryptBuffer[EncryptBufferSize]; + +/* Class methods. */ + +std::vector DepLibSRTP::errors = { + // From 0 (srtp_err_status_ok) to 24 (srtp_err_status_pfkey_err). + "success (srtp_err_status_ok)", + "unspecified failure (srtp_err_status_fail)", + "unsupported parameter (srtp_err_status_bad_param)", + "couldn't allocate memory (srtp_err_status_alloc_fail)", + "couldn't deallocate memory (srtp_err_status_dealloc_fail)", + "couldn't initialize (srtp_err_status_init_fail)", + "can’t process as much data as requested (srtp_err_status_terminus)", + "authentication failure (srtp_err_status_auth_fail)", + "cipher failure (srtp_err_status_cipher_fail)", + "replay check failed (bad index) (srtp_err_status_replay_fail)", + "replay check failed (index too old) (srtp_err_status_replay_old)", + "algorithm failed test routine (srtp_err_status_algo_fail)", + "unsupported operation (srtp_err_status_no_such_op)", + "no appropriate context found (srtp_err_status_no_ctx)", + "unable to perform desired validation (srtp_err_status_cant_check)", + "can’t use key any more (srtp_err_status_key_expired)", + "error in use of socket (srtp_err_status_socket_err)", + "error in use POSIX signals (srtp_err_status_signal_err)", + "nonce check failed (srtp_err_status_nonce_bad)", + "couldn’t read data (srtp_err_status_read_fail)", + "couldn’t write data (srtp_err_status_write_fail)", + "error parsing data (srtp_err_status_parse_err)", + "error encoding data (srtp_err_status_encode_err)", + "error while using semaphores (srtp_err_status_semaphore_err)", + "error while using pfkey (srtp_err_status_pfkey_err)"}; +// clang-format on + +/* Static methods. */ + +void DepLibSRTP::ClassInit() { + MS_TRACE(); + + MS_DEBUG_TAG(info, "libsrtp version: \"%s\"", srtp_get_version_string()); + + srtp_err_status_t err = srtp_init(); + + if (DepLibSRTP::IsError(err)) + MS_THROW_ERROR("srtp_init() failed: %s", DepLibSRTP::GetErrorString(err)); +} + +void DepLibSRTP::ClassDestroy() { + MS_TRACE(); + + srtp_shutdown(); +} + +void SrtpSession::ClassInit() { + // Set libsrtp event handler. + srtp_err_status_t err = + srtp_install_event_handler(static_cast(OnSrtpEvent)); + if (DepLibSRTP::IsError(err)) { + MS_THROW_ERROR("srtp_install_event_handler() failed: %s", DepLibSRTP::GetErrorString(err)); + std::cout << "srtp_install_event_handler() failed :" << DepLibSRTP::GetErrorString(err); + } +} + +void SrtpSession::OnSrtpEvent(srtp_event_data_t *data) { + MS_TRACE(); + + switch (data->event) { + case event_ssrc_collision: + MS_WARN_TAG(srtp, "SSRC collision occurred"); + break; + + case event_key_soft_limit: + MS_WARN_TAG(srtp, "stream reached the soft key usage limit and will expire soon"); + break; + + case event_key_hard_limit: + MS_WARN_TAG(srtp, "stream reached the hard key usage limit and has expired"); + break; + + case event_packet_index_limit: + MS_WARN_TAG(srtp, "stream reached the hard packet limit (2^48 packets)"); + break; + } +} + +/* Instance methods. */ + +SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen) { + MS_TRACE(); + + srtp_policy_t policy;// NOLINT(cppcoreguidelines-pro-type-member-init) + + // Set all policy fields to 0. + std::memset(&policy, 0, sizeof(srtp_policy_t)); + + switch (cryptoSuite) { + case CryptoSuite::AES_CM_128_HMAC_SHA1_80: { + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); + + break; + } + + case CryptoSuite::AES_CM_128_HMAC_SHA1_32: { + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); + // NOTE: Must be 80 for RTCP. + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); + + break; + } + + case CryptoSuite::AEAD_AES_256_GCM: { + srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); + srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); + + break; + } + + case CryptoSuite::AEAD_AES_128_GCM: { + srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); + srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); + + break; + } + + default: { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + MS_ASSERT((int) keyLen == policy.rtp.cipher_key_len, + "given keyLen does not match policy.rtp.cipher_keyLen"); + + switch (type) { + case Type::INBOUND: + policy.ssrc.type = ssrc_any_inbound; + break; + + case Type::OUTBOUND: + policy.ssrc.type = ssrc_any_outbound; + break; + } + + policy.ssrc.value = 0; + policy.key = key; + // Required for sending RTP retransmission without RTX. + policy.allow_repeat_tx = 1; + policy.window_size = 1024; + policy.next = nullptr; + + // Set the SRTP session. + srtp_err_status_t err = srtp_create(&this->session, &policy); + if (DepLibSRTP::IsError(err)) { + is_init = false; + MS_THROW_ERROR("srtp_create() failed: %s", DepLibSRTP::GetErrorString(err)); + } else { + is_init = true; + } +} + +SrtpSession::~SrtpSession() { + MS_TRACE(); + + if (this->session != nullptr) { + srtp_err_status_t err = srtp_dealloc(this->session); + + if (DepLibSRTP::IsError(err)) + MS_ABORT("srtp_dealloc() failed: %s", DepLibSRTP::GetErrorString(err)); + } +} + +bool SrtpSession::EncryptRtp(const uint8_t **data, size_t *len) { + MS_TRACE(); + if (!is_init) { + return false; + } + // Ensure that the resulting SRTP packet fits into the encrypt buffer. + if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) { + MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len); + + return false; + } + std::memcpy(EncryptBuffer, *data, *len); + + srtp_err_status_t err = + srtp_protect(this->session, static_cast(EncryptBuffer), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) { + MS_WARN_TAG(srtp, "srtp_protect() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + // Update the given data pointer. + *data = (const uint8_t *) EncryptBuffer; + + return true; +} + +bool SrtpSession::DecryptSrtp(uint8_t *data, size_t *len) { + MS_TRACE(); + + srtp_err_status_t err = + srtp_unprotect(this->session, static_cast(data), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) { + MS_DEBUG_TAG(srtp, "srtp_unprotect() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + return true; +} + +bool SrtpSession::EncryptRtcp(const uint8_t **data, size_t *len) { + MS_TRACE(); + + // Ensure that the resulting SRTCP packet fits into the encrypt buffer. + if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) { + MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len); + + return false; + } + + std::memcpy(EncryptBuffer, *data, *len); + + srtp_err_status_t err = srtp_protect_rtcp(this->session, static_cast(EncryptBuffer), + reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) { + MS_WARN_TAG(srtp, "srtp_protect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + // Update the given data pointer. + *data = (const uint8_t *) EncryptBuffer; + + return true; +} + +bool SrtpSession::DecryptSrtcp(uint8_t *data, size_t *len) { + MS_TRACE(); + + srtp_err_status_t err = + srtp_unprotect_rtcp(this->session, static_cast(data), reinterpret_cast(len)); + + if (DepLibSRTP::IsError(err)) { + MS_DEBUG_TAG(srtp, "srtp_unprotect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); + + return false; + } + + return true; +} +}// namespace RTC diff --git a/webrtc/srtp_session.h b/webrtc/srtp_session.h new file mode 100644 index 00000000..95da5fbf --- /dev/null +++ b/webrtc/srtp_session.h @@ -0,0 +1,54 @@ +#ifndef MS_RTC_SRTP_SESSION_HPP +#define MS_RTC_SRTP_SESSION_HPP + +#include "rtc_dtls_transport.h" +#include "utils.h" +#include +#include + +namespace RTC { + +class DepLibSRTP { +public: + static void ClassInit(); + static void ClassDestroy(); + static bool IsError(srtp_err_status_t code) { return (code != srtp_err_status_ok); } + static const char *GetErrorString(srtp_err_status_t code) { + // This throws out_of_range if the given index is not in the vector. + return DepLibSRTP::errors.at(code); + } + +private: + static std::vector errors; +}; + +class SrtpSession { +public: +public: + enum class Type { INBOUND = 1, OUTBOUND }; + +public: + static void ClassInit(); + +private: + static void OnSrtpEvent(srtp_event_data_t *data); + +public: + SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen); + ~SrtpSession(); + +public: + bool EncryptRtp(const uint8_t **data, size_t *len); + bool DecryptSrtp(uint8_t *data, size_t *len); + bool EncryptRtcp(const uint8_t **data, size_t *len); + bool DecryptSrtcp(uint8_t *data, size_t *len); + void RemoveStream(uint32_t ssrc) { srtp_remove_stream(this->session, uint32_t{htonl(ssrc)}); } + +private: + bool is_init = false; + // Allocated by this. + srtp_t session{nullptr}; +}; +}// namespace RTC + +#endif diff --git a/webrtc/stun_packet.cc b/webrtc/stun_packet.cc new file mode 100644 index 00000000..926f863b --- /dev/null +++ b/webrtc/stun_packet.cc @@ -0,0 +1,710 @@ +#define MS_CLASS "RTC::StunPacket" +// #define MS_LOG_DEV + +#include "stun_packet.h" + +#include // std::snprintf() +#include // std::memcmp(), std::memcpy() + +#include "utils.h" + +namespace RTC { + +/* Class variables. */ + +const uint8_t StunPacket::kMagicCookie[] = {0x21, 0x12, 0xA4, 0x42}; + +/* Class methods. */ + +StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) { + if (!StunPacket::IsStun(data, len)) return nullptr; + + /* + The message type field is decomposed further into the following + structure: + + 0 1 + 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + + Figure 3: Format of STUN Message Type Field + + Here the bits in the message type field are shown as most significant + (M11) through least significant (M0). M11 through M0 represent a 12- + bit encoding of the method. C1 and C0 represent a 2-bit encoding of + the class. + */ + + // Get type field. + uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); + + // Get length field. + uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); + + // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. + if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) { + ELOG_DEBUG( + "length field + 20 does not match total size (or it is not multiple of 4 bytes), " + "packet discarded"); + + return nullptr; + } + + // Get STUN method. + uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); + + // Get STUN class. + uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); + + // Create a new StunPacket (data + 8 points to the received TransactionID field). + auto packet = new StunPacket(static_cast(msgClass), static_cast(msgMethod), + data + 8, data, len); + + /* + STUN Attributes + + After the STUN header are zero or more attributes. Each attribute + MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. + Each STUN attribute MUST end on a 32-bit boundary. As mentioned + above, all fields in an attribute are transmitted most significant + bit first. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type | Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Value (variable) .... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + // Start looking for attributes after STUN header (Byte #20). + size_t pos{20}; + // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. + bool hasMessageIntegrity{false}; + bool hasFingerprint{false}; + size_t fingerprintAttrPos; // Will point to the beginning of the attribute. + uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. + + // Ensure there are at least 4 remaining bytes (attribute with 0 length). + while (pos + 4 <= len) { + // Get the attribute type. + auto attrType = static_cast(Utils::Byte::Get2Bytes(data, pos)); + + // Get the attribute length. + uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); + + // Ensure the attribute length is not greater than the remaining size. + if ((pos + 4 + attrLength) > len) { + ELOG_DEBUG("the attribute length exceeds the remaining size, packet discarded"); + + delete packet; + return nullptr; + } + + // FINGERPRINT must be the last attribute. + if (hasFingerprint) { + ELOG_DEBUG("attribute after FINGERPRINT is not allowed, packet discarded"); + + delete packet; + return nullptr; + } + + // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. + if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) { + ELOG_DEBUG( + "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " + "packet discarded"); + + delete packet; + return nullptr; + } + + const uint8_t* attrValuePos = data + pos + 4; + + switch (attrType) { + case Attribute::USERNAME: { + packet->SetUsername(reinterpret_cast(attrValuePos), + static_cast(attrLength)); + + break; + } + + case Attribute::PRIORITY: { + // Ensure attribute length is 4 bytes. + if (attrLength != 4) { + ELOG_DEBUG("attribute PRIORITY must be 4 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::ICE_CONTROLLING: { + // Ensure attribute length is 8 bytes. + if (attrLength != 8) { + ELOG_DEBUG("attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::ICE_CONTROLLED: { + // Ensure attribute length is 8 bytes. + if (attrLength != 8) { + ELOG_DEBUG("attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::USE_CANDIDATE: { + // Ensure attribute length is 0 bytes. + if (attrLength != 0) { + ELOG_DEBUG("attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetUseCandidate(); + + break; + } + + case Attribute::MESSAGE_INTEGRITY: { + // Ensure attribute length is 20 bytes. + if (attrLength != 20) { + ELOG_DEBUG("attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + hasMessageIntegrity = true; + packet->SetMessageIntegrity(attrValuePos); + + break; + } + + case Attribute::FINGERPRINT: { + // Ensure attribute length is 4 bytes. + if (attrLength != 4) { + ELOG_DEBUG("attribute FINGERPRINT must be 4 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + hasFingerprint = true; + fingerprintAttrPos = pos; + fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); + packet->SetFingerprint(); + + break; + } + + case Attribute::ERROR_CODE: { + // Ensure attribute length >= 4bytes. + if (attrLength < 4) { + ELOG_DEBUG("attribute ERROR-CODE must be >= 4bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); + uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); + auto errorCode = static_cast(errorClass * 100 + errorNumber); + + packet->SetErrorCode(errorCode); + + break; + } + + default:; + } + + // Set next attribute position. + pos = static_cast(Utils::Byte::PadTo4Bytes(static_cast(pos + 4 + attrLength))); + } + + // Ensure current position matches the total length. + if (pos != len) { + ELOG_DEBUG("computed packet size does not match total size, packet discarded"); + + delete packet; + return nullptr; + } + + // If it has FINGERPRINT attribute then verify it. + if (hasFingerprint) { + // Compute the CRC32 of the received packet up to (but excluding) the + // FINGERPRINT attribute and XOR it with 0x5354554e. + uint32_t computedFingerprint = Utils::Crypto::GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; + + // Compare with the FINGERPRINT value in the packet. + if (fingerprint != computedFingerprint) { + ELOG_DEBUG( + "computed FINGERPRINT value does not match the value in the packet, " + "packet discarded"); + + delete packet; + return nullptr; + } + } + + return packet; +} + +/* Instance methods. */ + +StunPacket::StunPacket(Class klass, Method method, const uint8_t* transactionId, + const uint8_t* data, size_t size) + : klass(klass), + method(method), + transactionId(transactionId), + data(const_cast(data)), + size(size) { + // MS_TRACE(); +} + +StunPacket::~StunPacket() { + // MS_TRACE(); +} + +void StunPacket::Dump() const { + // MS_TRACE(); + + // MS_DUMP(""); + + std::string klass; + switch (this->klass) { + case Class::REQUEST: + klass = "Request"; + break; + case Class::INDICATION: + klass = "Indication"; + break; + case Class::SUCCESS_RESPONSE: + klass = "SuccessResponse"; + break; + case Class::ERROR_RESPONSE: + klass = "ErrorResponse"; + break; + } + if (this->method == Method::BINDING) { + // MS_DUMP(" Binding %s", klass.c_str()); + } else { + // This prints the unknown method number. Example: TURN Allocate => 0x003. + // MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), + // static_cast(this->method)); + } + // MS_DUMP(" size: %zu bytes", this->size); + + static char transactionId[25]; + + for (int i{0}; i < 12; ++i) { + // NOTE: n must be 3 because snprintf adds a \0 after printed chars. + std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); + } + // MS_DUMP(" transactionId: %s", transactionId); + if (this->errorCode != 0u) + // MS_DUMP(" errorCode: %" PRIu16, this->errorCode); + if (!this->username.empty()) + // MS_DUMP(" username: %s", this->username.c_str()); + if (this->priority != 0u) + // MS_DUMP(" priority: %" PRIu32, this->priority); + if (this->iceControlling != 0u) + // MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); + if (this->iceControlled != 0u) + // MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); + if (this->hasUseCandidate) + // MS_DUMP(" useCandidate"); + if (this->xorMappedAddress != nullptr) { + int family; + uint16_t port; + std::string ip; + + Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); + + // MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); + } + if (this->messageIntegrity != nullptr) { + static char messageIntegrity[41]; + + for (int i{0}; i < 20; ++i) { + std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); + } + + // MS_DUMP(" messageIntegrity: %s", messageIntegrity); + } + if (this->hasFingerprint) { + } + // MS_DUMP(" has fingerprint"); + + // MS_DUMP(""); +} + +StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& localUsername, + const std::string& localPassword) { + // MS_TRACE(); + + switch (this->klass) { + case Class::REQUEST: + case Class::INDICATION: { + // Both USERNAME and MESSAGE-INTEGRITY must be present. + if (this->messageIntegrity == nullptr || this->username.empty()) + return Authentication::BAD_REQUEST; + + // Check that USERNAME attribute begins with our local username plus ":". + size_t localUsernameLen = localUsername.length(); + + if (this->username.length() <= localUsernameLen || + this->username.at(localUsernameLen) != ':' || + (this->username.compare(0, localUsernameLen, localUsername) != 0)) { + return Authentication::UNAUTHORIZED; + } + + break; + } + // This method cannot check authentication in received responses (as we + // are ICE-Lite and don't generate requests). + case Class::SUCCESS_RESPONSE: + case Class::ERROR_RESPONSE: { + // MS_ERROR("cannot check authentication for a STUN response"); + + return Authentication::BAD_REQUEST; + } + } + + // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, + // so the header length field must be modified (and later restored). + if (this->hasFingerprint) + // Set the header length field: full size - header length (20) - FINGERPRINT length (8). + Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); + + // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. + const uint8_t* computedMessageIntegrity = Utils::Crypto::GetHmacShA1( + localPassword, this->data, (this->messageIntegrity - 4) - this->data); + + Authentication result; + + // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. + if (std::memcmp(this->messageIntegrity, computedMessageIntegrity, 20) == 0) + result = Authentication::OK; + else + result = Authentication::UNAUTHORIZED; + + // Restore the header length field. + if (this->hasFingerprint) + Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20)); + + return result; +} + +StunPacket* StunPacket::CreateSuccessResponse() { + // MS_TRACE(); + + // MS_ASSERT( + // this->klass == Class::REQUEST, + // "attempt to create a success response for a non Request STUN packet"); + + return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); +} + +StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) { + // MS_TRACE(); + + // MS_ASSERT( + // this->klass == Class::REQUEST, + // "attempt to create an error response for a non Request STUN packet"); + + auto response = + new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); + + response->SetErrorCode(errorCode); + + return response; +} + +void StunPacket::Authenticate(const std::string& password) { + // Just for Request, Indication and SuccessResponse messages. + if (this->klass == Class::ERROR_RESPONSE) { + // MS_ERROR("cannot set password for ErrorResponse messages"); + + return; + } + + this->password = password; +} + +void StunPacket::Serialize(uint8_t* buffer) { + // MS_TRACE(); + + // Some useful variables. + uint16_t usernamePaddedLen{0}; + uint16_t xorMappedAddressPaddedLen{0}; + bool addXorMappedAddress = + ((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && + this->klass == Class::SUCCESS_RESPONSE); + bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); + bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); + bool addFingerprint{true}; // Do always. + + // Update data pointer. + this->data = buffer; + + // First calculate the total required size for the entire packet. + this->size = 20; // Header. + + if (!this->username.empty()) { + usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast(this->username.length())); + this->size += 4 + usernamePaddedLen; + } + + if (this->priority != 0u) this->size += 4 + 4; + + if (this->iceControlling != 0u) this->size += 4 + 8; + + if (this->iceControlled != 0u) this->size += 4 + 8; + + if (this->hasUseCandidate) this->size += 4; + + if (addXorMappedAddress) { + switch (this->xorMappedAddress->sa_family) { + case AF_INET: { + xorMappedAddressPaddedLen = 8; + this->size += 4 + 8; + + break; + } + + case AF_INET6: { + xorMappedAddressPaddedLen = 20; + this->size += 4 + 20; + + break; + } + + default: { + // MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); + + addXorMappedAddress = false; + } + } + } + + if (addErrorCode) this->size += 4 + 4; + + if (addMessageIntegrity) this->size += 4 + 20; + + if (addFingerprint) this->size += 4 + 4; + + // Merge class and method fields into type. + uint16_t typeField = (static_cast(this->method) & 0x0f80) << 2; + + typeField |= (static_cast(this->method) & 0x0070) << 1; + typeField |= (static_cast(this->method) & 0x000f); + typeField |= (static_cast(this->klass) & 0x02) << 7; + typeField |= (static_cast(this->klass) & 0x01) << 4; + + // Set type field. + Utils::Byte::Set2Bytes(buffer, 0, typeField); + // Set length field. + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size) - 20); + // Set magic cookie. + std::memcpy(buffer + 4, StunPacket::kMagicCookie, 4); + // Set TransactionId field. + std::memcpy(buffer + 8, this->transactionId, 12); + // Update the transaction ID pointer. + this->transactionId = buffer + 8; + // Add atributes. + size_t pos{20}; + + // Add USERNAME. + if (usernamePaddedLen != 0u) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USERNAME)); + Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast(this->username.length())); + std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); + pos += 4 + usernamePaddedLen; + } + + // Add PRIORITY. + if (this->priority != 0u) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::PRIORITY)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); + pos += 4 + 4; + } + + // Add ICE-CONTROLLING. + if (this->iceControlling != 0u) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLING)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 8); + Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); + pos += 4 + 8; + } + + // Add ICE-CONTROLLED. + if (this->iceControlled != 0u) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLED)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 8); + Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled); + pos += 4 + 8; + } + + // Add USE-CANDIDATE. + if (this->hasUseCandidate) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USE_CANDIDATE)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 0); + pos += 4; + } + + // Add XOR-MAPPED-ADDRESS + if (addXorMappedAddress) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::XOR_MAPPED_ADDRESS)); + Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen); + + uint8_t* attrValue = buffer + pos + 4; + + switch (this->xorMappedAddress->sa_family) { + case AF_INET: { + // Set first byte to 0. + attrValue[0] = 0; + // Set inet family. + attrValue[1] = 0x01; + // Set port and XOR it. + std::memcpy(attrValue + 2, + &(reinterpret_cast(this->xorMappedAddress))->sin_port, 2); + attrValue[2] ^= StunPacket::kMagicCookie[0]; + attrValue[3] ^= StunPacket::kMagicCookie[1]; + // Set address and XOR it. + std::memcpy( + attrValue + 4, + &(reinterpret_cast(this->xorMappedAddress))->sin_addr.s_addr, 4); + attrValue[4] ^= StunPacket::kMagicCookie[0]; + attrValue[5] ^= StunPacket::kMagicCookie[1]; + attrValue[6] ^= StunPacket::kMagicCookie[2]; + attrValue[7] ^= StunPacket::kMagicCookie[3]; + + pos += 4 + 8; + + break; + } + + case AF_INET6: { + // Set first byte to 0. + attrValue[0] = 0; + // Set inet family. + attrValue[1] = 0x02; + // Set port and XOR it. + std::memcpy(attrValue + 2, + &(reinterpret_cast(this->xorMappedAddress))->sin6_port, 2); + attrValue[2] ^= StunPacket::kMagicCookie[0]; + attrValue[3] ^= StunPacket::kMagicCookie[1]; + // Set address and XOR it. + std::memcpy( + attrValue + 4, + &(reinterpret_cast(this->xorMappedAddress))->sin6_addr.s6_addr, + 16); + attrValue[4] ^= StunPacket::kMagicCookie[0]; + attrValue[5] ^= StunPacket::kMagicCookie[1]; + attrValue[6] ^= StunPacket::kMagicCookie[2]; + attrValue[7] ^= StunPacket::kMagicCookie[3]; + attrValue[8] ^= this->transactionId[0]; + attrValue[9] ^= this->transactionId[1]; + attrValue[10] ^= this->transactionId[2]; + attrValue[11] ^= this->transactionId[3]; + attrValue[12] ^= this->transactionId[4]; + attrValue[13] ^= this->transactionId[5]; + attrValue[14] ^= this->transactionId[6]; + attrValue[15] ^= this->transactionId[7]; + attrValue[16] ^= this->transactionId[8]; + attrValue[17] ^= this->transactionId[9]; + attrValue[18] ^= this->transactionId[10]; + attrValue[19] ^= this->transactionId[11]; + + pos += 4 + 20; + + break; + } + } + } + + // Add ERROR-CODE. + if (addErrorCode) { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ERROR_CODE)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + + auto codeClass = static_cast(this->errorCode / 100); + uint8_t codeNumber = static_cast(this->errorCode) - (codeClass * 100); + + Utils::Byte::Set2Bytes(buffer, pos + 4, 0); + Utils::Byte::Set1Byte(buffer, pos + 6, codeClass); + Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber); + pos += 4 + 4; + } + + // Add MESSAGE-INTEGRITY. + if (addMessageIntegrity) { + // Ignore FINGERPRINT. + if (addFingerprint) + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); + + // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. + const uint8_t* computedMessageIntegrity = + Utils::Crypto::GetHmacShA1(this->password, buffer, pos); + + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::MESSAGE_INTEGRITY)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 20); + std::memcpy(buffer + pos + 4, computedMessageIntegrity, 20); + + // Update the pointer. + this->messageIntegrity = buffer + pos + 4; + pos += 4 + 20; + + // Restore length field. + if (addFingerprint) Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20)); + } else { + // Unset the pointer (if it was set). + this->messageIntegrity = nullptr; + } + + // Add FINGERPRINT. + if (addFingerprint) { + // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT + // attribute and XOR it with 0x5354554e. + uint32_t computedFingerprint = Utils::Crypto::GetCRC32(buffer, pos) ^ 0x5354554e; + + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); + pos += 4 + 4; + + // Set flag. + this->hasFingerprint = true; + } else { + this->hasFingerprint = false; + } + + // MS_ASSERT(pos == this->size, "pos != this->size"); +} +} // namespace RTC diff --git a/webrtc/stun_packet.h b/webrtc/stun_packet.h new file mode 100644 index 00000000..861d11cd --- /dev/null +++ b/webrtc/stun_packet.h @@ -0,0 +1,179 @@ +#ifndef MS_RTC_STUN_PACKET_HPP +#define MS_RTC_STUN_PACKET_HPP + + +#include "logger.h" +#include "utils.h" +#include + +namespace RTC { +class StunPacket { +public: + // STUN message class. + enum class Class : uint16_t { + REQUEST = 0, + INDICATION = 1, + SUCCESS_RESPONSE = 2, + ERROR_RESPONSE = 3 + }; + + // STUN message method. + enum class Method : uint16_t { BINDING = 1 }; + + // Attribute type. + enum class Attribute : uint16_t { + MAPPED_ADDRESS = 0x0001, + USERNAME = 0x0006, + MESSAGE_INTEGRITY = 0x0008, + ERROR_CODE = 0x0009, + UNKNOWN_ATTRIBUTES = 0x000A, + REALM = 0x0014, + NONCE = 0x0015, + XOR_MAPPED_ADDRESS = 0x0020, + PRIORITY = 0x0024, + USE_CANDIDATE = 0x0025, + SOFTWARE = 0x8022, + ALTERNATE_SERVER = 0x8023, + FINGERPRINT = 0x8028, + ICE_CONTROLLED = 0x8029, + ICE_CONTROLLING = 0x802A + }; + + // Authentication result. + enum class Authentication { OK = 0, UNAUTHORIZED = 1, BAD_REQUEST = 2 }; + +public: + static bool IsStun(const uint8_t *data, size_t len); + static StunPacket *Parse(const uint8_t *data, size_t len); + +private: + static const uint8_t kMagicCookie[]; + +public: + StunPacket(Class klass, Method method, const uint8_t *transactionId, const uint8_t *data, + size_t size); + ~StunPacket(); + + void Dump() const; + Class GetClass() const; + Method GetMethod() const; + const uint8_t *GetData() const; + size_t GetSize() const; + void SetUsername(const char *username, size_t len); + void SetPriority(uint32_t priority); + void SetIceControlling(uint64_t iceControlling); + void SetIceControlled(uint64_t iceControlled); + void SetUseCandidate(); + void SetXorMappedAddress(const struct sockaddr *xorMappedAddress); + void SetErrorCode(uint16_t errorCode); + void SetMessageIntegrity(const uint8_t *messageIntegrity); + void SetFingerprint(); + const std::string &GetUsername() const; + uint32_t GetPriority() const; + uint64_t GetIceControlling() const; + uint64_t GetIceControlled() const; + bool HasUseCandidate() const; + uint16_t GetErrorCode() const; + bool HasMessageIntegrity() const; + bool HasFingerprint() const; + Authentication CheckAuthentication(const std::string &localUsername, + const std::string &localPassword); + StunPacket *CreateSuccessResponse(); + StunPacket *CreateErrorResponse(uint16_t errorCode); + void Authenticate(const std::string &password); + void Serialize(uint8_t *buffer); + +private: + // Passed by argument. + Class klass; // 2 bytes. + Method method; // 2 bytes. + const uint8_t *transactionId{nullptr};// 12 bytes. + uint8_t *data{nullptr}; // Pointer to binary data. + size_t size{0}; // The full message size (including header). + // STUN attributes. + std::string username; // Less than 513 bytes. + uint32_t priority{0}; // 4 bytes unsigned integer. + uint64_t iceControlling{0}; // 8 bytes unsigned integer. + uint64_t iceControlled{0}; // 8 bytes unsigned integer. + bool hasUseCandidate{false}; // 0 bytes. + const uint8_t *messageIntegrity{nullptr}; // 20 bytes. + bool hasFingerprint{false}; // 4 bytes. + const struct sockaddr *xorMappedAddress{nullptr};// 8 or 20 bytes. + uint16_t errorCode{0}; // 4 bytes (no reason phrase). + std::string password; +}; + +/* Inline class methods. */ + +inline bool StunPacket::IsStun(const uint8_t *data, size_t len) { + // clang-format off + return ( + // STUN headers are 20 bytes. + (len >= 20) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] < 3) && + // Magic cookie must match. + (data[4] == StunPacket::kMagicCookie[0]) && (data[5] == StunPacket::kMagicCookie[1]) && + (data[6] == StunPacket::kMagicCookie[2]) && (data[7] == StunPacket::kMagicCookie[3]) + ); + // clang-format on +} + +/* Inline instance methods. */ + +inline StunPacket::Class StunPacket::GetClass() const { return this->klass; } + +inline StunPacket::Method StunPacket::GetMethod() const { return this->method; } + +inline const uint8_t *StunPacket::GetData() const { return this->data; } + +inline size_t StunPacket::GetSize() const { return this->size; } + +inline void StunPacket::SetUsername(const char *username, size_t len) { + this->username.assign(username, len); +} + +inline void StunPacket::SetPriority(const uint32_t priority) { this->priority = priority; } + +inline void StunPacket::SetIceControlling(const uint64_t iceControlling) { + this->iceControlling = iceControlling; +} + +inline void StunPacket::SetIceControlled(const uint64_t iceControlled) { + this->iceControlled = iceControlled; +} + +inline void StunPacket::SetUseCandidate() { this->hasUseCandidate = true; } + +inline void StunPacket::SetXorMappedAddress(const struct sockaddr *xorMappedAddress) { + this->xorMappedAddress = xorMappedAddress; +} + +inline void StunPacket::SetErrorCode(uint16_t errorCode) { this->errorCode = errorCode; } + +inline void StunPacket::SetMessageIntegrity(const uint8_t *messageIntegrity) { + this->messageIntegrity = messageIntegrity; +} + +inline void StunPacket::SetFingerprint() { this->hasFingerprint = true; } + +inline const std::string &StunPacket::GetUsername() const { return this->username; } + +inline uint32_t StunPacket::GetPriority() const { return this->priority; } + +inline uint64_t StunPacket::GetIceControlling() const { return this->iceControlling; } + +inline uint64_t StunPacket::GetIceControlled() const { return this->iceControlled; } + +inline bool StunPacket::HasUseCandidate() const { return this->hasUseCandidate; } + +inline uint16_t StunPacket::GetErrorCode() const { return this->errorCode; } + +inline bool StunPacket::HasMessageIntegrity() const { + return (this->messageIntegrity ? true : false); +} + +inline bool StunPacket::HasFingerprint() const { return this->hasFingerprint; } +}// namespace RTC + +#endif diff --git a/webrtc/utils.cc b/webrtc/utils.cc new file mode 100644 index 00000000..ab9f1fda --- /dev/null +++ b/webrtc/utils.cc @@ -0,0 +1,139 @@ +#define MS_CLASS "Utils::Crypto" +// #define MS_LOG_DEV + +#include "utils.h" + +#include "openssl/sha.h" + +namespace Utils { +/* Static variables. */ + +uint32_t Crypto::seed; +HMAC_CTX *Crypto::hmacSha1Ctx{nullptr}; +uint8_t Crypto::hmacSha1Buffer[20];// SHA-1 result is 20 bytes long. +// clang-format off +const uint32_t Crypto::crc32Table[] = +{ + 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, + 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, + 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, + 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, + 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, + 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, + 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, + 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, + 0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, + 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01, + 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, + 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, + 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, + 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, + 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, + 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, + 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, + 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, + 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, + 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, + 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, + 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, + 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, + 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, + 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713, + 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, + 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, + 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, + 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, + 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, + 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, + 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d +}; +// clang-format on + +/* Static methods. */ + +void Crypto::ClassInit() { + // MS_TRACE(); + + // Init the vrypto seed with a random number taken from the address + // of the seed variable itself (which is random). + Crypto::seed = static_cast(reinterpret_cast(std::addressof(Crypto::seed))); + + // Create an OpenSSL HMAC_CTX context for HMAC SHA1 calculation. + // Crypto::hmacSha1Ctx = HMAC_CTX_new(); + if (Crypto::hmacSha1Ctx == nullptr) { + Crypto::hmacSha1Ctx = HMAC_CTX_new(); + } +} + +void Crypto::ClassDestroy() { + // MS_TRACE(); + + if (Crypto::hmacSha1Ctx != nullptr) { + HMAC_CTX_free(Crypto::hmacSha1Ctx); + } +} + +const uint8_t *Crypto::GetHmacShA1(const std::string &key, const uint8_t *data, size_t len) { + // MS_TRACE(); + + size_t ret; + + ret = HMAC_Init_ex(Crypto::hmacSha1Ctx, key.c_str(), key.length(), EVP_sha1(), nullptr); + + // MS_ASSERT(ret == 1, "OpenSSL HMAC_Init_ex() failed with key '%s'", key.c_str()); + + ret = HMAC_Update(Crypto::hmacSha1Ctx, data, static_cast(len)); + /* + MS_ASSERT( + ret == 1, + "OpenSSL HMAC_Update() failed with key '%s' and data length %zu bytes", + key.c_str(), + len); + */ + uint32_t resultLen; + + ret = HMAC_Final(Crypto::hmacSha1Ctx, (uint8_t *) Crypto::hmacSha1Buffer, &resultLen); + + /* + MS_ASSERT( + ret == 1, "OpenSSL HMAC_Final() failed with key '%s' and data length %zu bytes", key.c_str(), + len); MS_ASSERT(resultLen == 20, "OpenSSL HMAC_Final() resultLen is %u instead of 20", resultLen); + */ + return Crypto::hmacSha1Buffer; +} +}// namespace Utils + +namespace Utils { + +static std::string inet_ntoa(struct in_addr in) { + char buf[20]; + unsigned char *p = (unsigned char *) &(in); + snprintf(buf, sizeof(buf), "%u.%u.%u.%u", p[0], p[1], p[2], p[3]); + return buf; +} + +void IP::GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip, uint16_t &port) { + char ipBuffer[INET6_ADDRSTRLEN + 1]; + + switch (addr->sa_family) { + case AF_INET: { + ip = Utils::inet_ntoa(reinterpret_cast(addr)->sin_addr); + port = static_cast(ntohs(reinterpret_cast(addr)->sin_port)); + break; + } + + case AF_INET6: { + port = static_cast(ntohs(reinterpret_cast(addr)->sin6_port)); + break; + } + + default: { + // MS_ABORT("unknown network family: %d", static_cast(addr->sa_family)); + } + } + + family = addr->sa_family; + ip.assign(ipBuffer); +} + +}// namespace Utils \ No newline at end of file diff --git a/webrtc/utils.h b/webrtc/utils.h new file mode 100644 index 00000000..1cd6ba9d --- /dev/null +++ b/webrtc/utils.h @@ -0,0 +1,318 @@ +#ifndef MS_UTILS_HPP +#define MS_UTILS_HPP + +#if defined(_WIN32) +#include +#include +#include +#pragma comment (lib, "Ws2_32.lib") +#pragma comment(lib,"Iphlpapi.lib") +#else +#include +#include +#include +#include +#include +#include +#include +#endif // defined(_WIN32) + +#include // std::transform(), std::find(), std::min(), std::max() +#include // PRIu64, etc +#include +#include // size_t +#include // uint8_t, etc +#include // std::memcmp(), std::memcpy() +#include +#include +#include +#include +#include + +namespace Utils { +class IP { +public: + static int GetFamily(const char *ip, size_t ipLen); + static int GetFamily(const std::string &ip); + static void GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip, + uint16_t &port); + static bool CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2); + static struct sockaddr_storage CopyAddress(const struct sockaddr *addr); + static void NormalizeIp(std::string &ip); +}; + +/* Inline static methods. */ + +inline int IP::GetFamily(const std::string &ip) { return GetFamily(ip.c_str(), ip.size()); } + +inline bool IP::CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2) { + // Compare family. + if (addr1->sa_family != addr2->sa_family || + (addr1->sa_family != AF_INET && addr1->sa_family != AF_INET6)) { + return false; + } + + // Compare port. + if (reinterpret_cast(addr1)->sin_port != + reinterpret_cast(addr2)->sin_port) { + return false; + } + + // Compare IP. + switch (addr1->sa_family) { + case AF_INET: { + return (reinterpret_cast(addr1)->sin_addr.s_addr == + reinterpret_cast(addr2)->sin_addr.s_addr); + } + + case AF_INET6: { + return (std::memcmp( + std::addressof(reinterpret_cast(addr1)->sin6_addr), + std::addressof(reinterpret_cast(addr2)->sin6_addr), + 16) == 0 + ? true + : false); + } + + default: { + return false; + } + } +} + +inline struct sockaddr_storage IP::CopyAddress(const struct sockaddr *addr) { + struct sockaddr_storage copiedAddr; + + switch (addr->sa_family) { + case AF_INET: + std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in)); + break; + + case AF_INET6: + std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in6)); + break; + } + + return copiedAddr; +} + +class File { +public: + static void CheckFile(const char *file); +}; + +class Byte { +public: + /** + * Getters below get value in Host Byte Order. + * Setters below set value in Network Byte Order. + */ + static uint8_t Get1Byte(const uint8_t *data, size_t i); + static uint16_t Get2Bytes(const uint8_t *data, size_t i); + static uint32_t Get3Bytes(const uint8_t *data, size_t i); + static uint32_t Get4Bytes(const uint8_t *data, size_t i); + static uint64_t Get8Bytes(const uint8_t *data, size_t i); + static void Set1Byte(uint8_t *data, size_t i, uint8_t value); + static void Set2Bytes(uint8_t *data, size_t i, uint16_t value); + static void Set3Bytes(uint8_t *data, size_t i, uint32_t value); + static void Set4Bytes(uint8_t *data, size_t i, uint32_t value); + static void Set8Bytes(uint8_t *data, size_t i, uint64_t value); + static uint16_t PadTo4Bytes(uint16_t size); + static uint32_t PadTo4Bytes(uint32_t size); +}; + +/* Inline static methods. */ + +inline uint8_t Byte::Get1Byte(const uint8_t *data, size_t i) { return data[i]; } + +inline uint16_t Byte::Get2Bytes(const uint8_t *data, size_t i) { + return uint16_t{data[i + 1]} | uint16_t{data[i]} << 8; +} + +inline uint32_t Byte::Get3Bytes(const uint8_t *data, size_t i) { + return uint32_t{data[i + 2]} | uint32_t{data[i + 1]} << 8 | uint32_t{data[i]} << 16; +} + +inline uint32_t Byte::Get4Bytes(const uint8_t *data, size_t i) { + return uint32_t{data[i + 3]} | uint32_t{data[i + 2]} << 8 | uint32_t{data[i + 1]} << 16 | + uint32_t{data[i]} << 24; +} + +inline uint64_t Byte::Get8Bytes(const uint8_t *data, size_t i) { + return uint64_t{Byte::Get4Bytes(data, i)} << 32 | Byte::Get4Bytes(data, i + 4); +} + +inline void Byte::Set1Byte(uint8_t *data, size_t i, uint8_t value) { data[i] = value; } + +inline void Byte::Set2Bytes(uint8_t *data, size_t i, uint16_t value) { + data[i + 1] = static_cast(value); + data[i] = static_cast(value >> 8); +} + +inline void Byte::Set3Bytes(uint8_t *data, size_t i, uint32_t value) { + data[i + 2] = static_cast(value); + data[i + 1] = static_cast(value >> 8); + data[i] = static_cast(value >> 16); +} + +inline void Byte::Set4Bytes(uint8_t *data, size_t i, uint32_t value) { + data[i + 3] = static_cast(value); + data[i + 2] = static_cast(value >> 8); + data[i + 1] = static_cast(value >> 16); + data[i] = static_cast(value >> 24); +} + +inline void Byte::Set8Bytes(uint8_t *data, size_t i, uint64_t value) { + data[i + 7] = static_cast(value); + data[i + 6] = static_cast(value >> 8); + data[i + 5] = static_cast(value >> 16); + data[i + 4] = static_cast(value >> 24); + data[i + 3] = static_cast(value >> 32); + data[i + 2] = static_cast(value >> 40); + data[i + 1] = static_cast(value >> 48); + data[i] = static_cast(value >> 56); +} + +inline uint16_t Byte::PadTo4Bytes(uint16_t size) { + // If size is not multiple of 32 bits then pad it. + if (size & 0x03) + return (size & 0xFFFC) + 4; + else + return size; +} + +inline uint32_t Byte::PadTo4Bytes(uint32_t size) { + // If size is not multiple of 32 bits then pad it. + if (size & 0x03) + return (size & 0xFFFFFFFC) + 4; + else + return size; +} + +class Bits { +public: + static size_t CountSetBits(const uint16_t mask); +}; + +/* Inline static methods. */ + +class Crypto { +public: + static void ClassInit(); + static void ClassDestroy(); + static uint32_t GetRandomUInt(uint32_t min, uint32_t max); + static const std::string GetRandomString(size_t len); + static uint32_t GetCRC32(const uint8_t *data, size_t size); + static const uint8_t *GetHmacShA1(const std::string &key, const uint8_t *data, size_t len); + +private: + static uint32_t seed; + static HMAC_CTX *hmacSha1Ctx; + static uint8_t hmacSha1Buffer[]; + static const uint32_t crc32Table[256]; +}; + +/* Inline static methods. */ + +inline uint32_t Crypto::GetRandomUInt(uint32_t min, uint32_t max) { + // NOTE: This is the original, but produces very small values. + // Crypto::seed = (214013 * Crypto::seed) + 2531011; + // return (((Crypto::seed>>16)&0x7FFF) % (max - min + 1)) + min; + + // This seems to produce better results. + Crypto::seed = uint32_t{((214013 * Crypto::seed) + 2531011)}; + + return (((Crypto::seed >> 4) & 0x7FFF7FFF) % (max - min + 1)) + min; +} + +inline const std::string Crypto::GetRandomString(size_t len) { + static char buffer[64]; + static const char chars[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', + 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'}; + + if (len > 64) len = 64; + + for (size_t i{0}; i < len; ++i) { + buffer[i] = chars[GetRandomUInt(0, sizeof(chars) - 1)]; + } + + return std::string(buffer, len); +} + +inline uint32_t Crypto::GetCRC32(const uint8_t *data, size_t size) { + uint32_t crc{0xFFFFFFFF}; + const uint8_t *p = data; + + while (size--) { + crc = Crypto::crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8); + } + + return crc ^ ~0U; +} + +class String { +public: + static void ToLowerCase(std::string &str); +}; + +inline void String::ToLowerCase(std::string &str) { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); +} + +class Time { + // Seconds from Jan 1, 1900 to Jan 1, 1970. + static constexpr uint32_t UnixNtpOffset{0x83AA7E80}; + // NTP fractional unit. + static constexpr uint64_t NtpFractionalUnit{1LL << 32}; + +public: + struct Ntp { + uint32_t seconds; + uint32_t fractions; + }; + + static Time::Ntp TimeMs2Ntp(uint64_t ms); + static uint64_t Ntp2TimeMs(Time::Ntp ntp); + static bool IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp); + static uint32_t LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2); +}; + +inline Time::Ntp Time::TimeMs2Ntp(uint64_t ms) { + Time::Ntp ntp;// NOLINT(cppcoreguidelines-pro-type-member-init) + + ntp.seconds = uint32_t(ms / 1000); + ntp.fractions = + static_cast((static_cast(ms % 1000) / 1000) * NtpFractionalUnit); + + return ntp; +} + +inline uint64_t Time::Ntp2TimeMs(Time::Ntp ntp) { + // clang-format off + return ( + static_cast(ntp.seconds) * 1000 + + static_cast(std::round((static_cast(ntp.fractions) * 1000) / NtpFractionalUnit)) + ); + // clang-format on +} + +inline bool Time::IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp) { + // Distinguish between elements that are exactly 0x80000000 apart. + // If t1>t2 and |t1-t2| = 0x80000000: IsNewer(t1,t2)=true, + // IsNewer(t2,t1)=false + // rather than having IsNewer(t1,t2) = IsNewer(t2,t1) = false. + if (static_cast(timestamp - prevTimestamp) == 0x80000000) + return timestamp > prevTimestamp; + + return timestamp != prevTimestamp && + static_cast(timestamp - prevTimestamp) < 0x80000000; +} + +inline uint32_t Time::LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2) { + return IsNewerTimestamp(timestamp1, timestamp2) ? timestamp1 : timestamp2; +} + +}// namespace Utils + +#endif diff --git a/webrtc/webrtc_transport.cc b/webrtc/webrtc_transport.cc new file mode 100644 index 00000000..a475f2ac --- /dev/null +++ b/webrtc/webrtc_transport.cc @@ -0,0 +1,215 @@ +#include "webrtc_transport.h" +#include +#include "Rtcp/Rtcp.h" + +WebRtcTransport::WebRtcTransport() { + static onceToken token([](){ + Utils::Crypto::ClassInit(); + RTC::DtlsTransport::ClassInit(); + RTC::DepLibSRTP::ClassInit(); + RTC::SrtpSession::ClassInit(); + }); + + ice_server_ = std::make_shared(Utils::Crypto::GetRandomString(4), Utils::Crypto::GetRandomString(24)); + ice_server_->SetIceServerCompletedCB([this]() { + this->OnIceServerCompleted(); + }); + ice_server_->SetSendCB([this](char *buf, size_t len, struct sockaddr_in *remote_address) { + this->WritePacket(buf, len, remote_address); + }); + + // todo dtls服务器或客户端模式 + dtls_transport_ = std::make_shared(true); + dtls_transport_->SetHandshakeCompletedCB([this](std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) { + this->OnDtlsCompleted(client_key, server_key, srtp_crypto_suite); + }); + dtls_transport_->SetOutPutCB([this](char *buf, size_t len) { this->WritePacket(buf, len); }); +} + +WebRtcTransport::~WebRtcTransport() {} + +std::string WebRtcTransport::GetLocalSdp() { + char sdp[1024 * 10] = {0}; + auto ssrc = getSSRC(); + auto ip = getIP(); + auto pt = getPayloadType(); + auto port = getPort(); + sprintf(sdp, + "v=0\r\n" + "o=- 1495799811084970 1495799811084970 IN IP4 %s\r\n" + "s=Streaming Test\r\n" + "t=0 0\r\n" + "a=group:BUNDLE video\r\n" + "a=msid-semantic: WMS janus\r\n" + "m=video %u RTP/SAVPF %u\r\n" + "c=IN IP4 %s\r\n" + "a=mid:video\r\n" + "a=sendonly\r\n" + "a=rtcp-mux\r\n" + "a=ice-ufrag:%s\r\n" + "a=ice-pwd:%s\r\n" + "a=ice-options:trickle\r\n" + "a=fingerprint:sha-256 %s\r\n" + "a=setup:actpass\r\n" + "a=connection:new\r\n" + "a=rtpmap:%u H264/90000\r\n" + "a=ssrc:%u cname:janusvideo\r\n" + "a=ssrc:%u msid:janus janusv0\r\n" + "a=ssrc:%u mslabel:janus\r\n" + "a=ssrc:%u label:janusv0\r\n" + "a=candidate:%s 1 udp %u %s %u typ %s\r\n", + ip.c_str(), port, pt, ip.c_str(), + ice_server_->GetUsernameFragment().c_str(),ice_server_->GetPassword().c_str(), + dtls_transport_->GetMyFingerprint().c_str(), pt, ssrc, ssrc, ssrc, ssrc, "4", ssrc, ip.c_str(), port, "host"); + return sdp; +} + +void WebRtcTransport::OnIceServerCompleted() { + InfoL; + dtls_transport_->Start(); + onIceConnected(); +} + +void WebRtcTransport::OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) { + InfoL << client_key << " " << server_key << " " << (int)srtp_crypto_suite; + srtp_session_ = std::make_shared(RTC::SrtpSession::Type::OUTBOUND, srtp_crypto_suite, (uint8_t *) client_key.c_str(), client_key.size()); + onDtlsCompleted(); +} + +bool is_dtls(char *buf) { + return ((*buf > 19) && (*buf < 64)); +} + +bool is_rtp(char *buf) { + RtpHeader *header = (RtpHeader *) buf; + return ((header->pt < 64) || (header->pt >= 96)); +} + +bool is_rtcp(char *buf) { + RtpHeader *header = (RtpHeader *) buf; + return ((header->pt >= 64) && (header->pt < 96)); +} + +void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address) { + if (RTC::StunPacket::IsStun((const uint8_t *) buf, len)) { + InfoL << "stun:" << hexdump(buf, len); + RTC::StunPacket *packet = RTC::StunPacket::Parse((const uint8_t *) buf, len); + if (packet == nullptr) { + WarnL << "parse stun error" << std::endl; + return; + } + ice_server_->ProcessStunPacket(packet, remote_address); + return; + } + if (DtlsTransport::IsDtlsPacket(buf, len)) { + InfoL << "dtls:" << hexdump(buf, len); + dtls_transport_->InputData(buf, len); + return; + } + if (is_rtp(buf)) { + RtpHeader *header = (RtpHeader *) buf; + InfoL << "rtp:" << header->dumpString(len); + return; + } + if (is_rtcp(buf)) { + RtcpHeader *header = (RtcpHeader *) buf; +// InfoL << "rtcp:" << header->dumpString(); + return; + } +} + +void WebRtcTransport::WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address) { + onWrite(buf, len, remote_address ? remote_address : (ice_server_ ? ice_server_->GetSelectAddr() : nullptr)); +} + +void WebRtcTransport::WritRtpPacket(char *buf, size_t len) { + const uint8_t *p = (uint8_t *) buf; + bool ret = false; + if (srtp_session_) { + ret = srtp_session_->EncryptRtp(&p, &len); + } + if (ret) { + onWrite((char *) p, len, ice_server_->GetSelectAddr()); + } +} + +/////////////////////////////////////////////////////////////////////////////////// + +WebRtcTransportImp::WebRtcTransportImp(const EventPoller::Ptr &poller) { + _socket = Socket::createSocket(poller, false); + //随机端口,绑定全部网卡 + _socket->bindUdpSock(0); + _socket->setOnRead([this](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len){ + OnInputDataPacket(buf->data(), buf->size(), (struct sockaddr_in*)addr); + }); +} + +void WebRtcTransportImp::attach(const RtspMediaSource::Ptr &src) { + assert(src); + _src = src; +} + +void WebRtcTransportImp::onDtlsCompleted() { + _reader = _src->getRing()->attach(_socket->getPoller(), true); + weak_ptr weak_self = shared_from_this(); + _reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt){ + auto strongSelf = weak_self.lock(); + if (!strongSelf) { + return; + } + pkt->for_each([&](const RtpPacket::Ptr &rtp) { + if(rtp->type == TrackVideo) { + //目前只支持视频 + strongSelf->WritRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, + rtp->size() - RtpPacket::kRtpTcpHeaderSize); + } + }); + }); +} + +void WebRtcTransportImp::onIceConnected(){ + +} + +void WebRtcTransportImp::onWrite(const char *buf, size_t len, struct sockaddr_in *dst) { + auto ptr = BufferRaw::create(); + ptr->assign(buf, len); +// InfoL << len << " " << SockUtil::inet_ntoa(dst->sin_addr) << " " << ntohs(dst->sin_port); + _socket->send(ptr, (struct sockaddr *)(dst), sizeof(struct sockaddr)); +} + +uint32_t WebRtcTransportImp::getSSRC() const { + return _src->getSsrc(TrackVideo); +} + +int WebRtcTransportImp::getPayloadType() const{ + auto sdp = SdpParser(_src->getSdp()); + auto track = sdp.getTrack(TrackVideo); + assert(track); + return track ? track->_pt : 0; +} + +uint16_t WebRtcTransportImp::getPort() const { + //todo udp端口号应该与外网映射端口相同 + return _socket->get_local_port(); +} + +std::string WebRtcTransportImp::getIP() const { + //todo 替换为外网ip + return SockUtil::get_local_ip(); +} + +/////////////////////////////////////////////////////////////////// + +INSTANCE_IMP(WebRtcManager) + +WebRtcManager::WebRtcManager() { + +} + +WebRtcManager::~WebRtcManager() { + +} + + + diff --git a/webrtc/webrtc_transport.h b/webrtc/webrtc_transport.h new file mode 100644 index 00000000..2289e5f1 --- /dev/null +++ b/webrtc/webrtc_transport.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +#include "dtls_transport.h" +#include "ice_server.h" +#include "srtp_session.h" +#include "stun_packet.h" + +class WebRtcTransport { +public: + using Ptr = std::shared_ptr; + WebRtcTransport(); + virtual ~WebRtcTransport(); + + /// 获取本地sdp + /// \return + std::string GetLocalSdp(); + + /// 收到udp数据 + /// \param buf + /// \param len + /// \param remote_address + void OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address); + + /// 发送rtp + /// \param buf + /// \param len + void WritRtpPacket(char *buf, size_t len); + +protected: + /// 输出udp数据 + /// \param buf + /// \param len + /// \param dst + virtual void onWrite(const char *buf, size_t len, struct sockaddr_in *dst) = 0; + virtual uint32_t getSSRC() const = 0; + virtual uint16_t getPort() const = 0; + virtual std::string getIP() const = 0; + virtual int getPayloadType() const = 0; + virtual void onIceConnected() = 0; + virtual void onDtlsCompleted() = 0; + +private: + void OnIceServerCompleted(); + void OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite); + void WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address = nullptr); + +private: + IceServer::Ptr ice_server_; + DtlsTransport::Ptr dtls_transport_; + std::shared_ptr srtp_session_; +}; + +#include "Poller/EventPoller.h" +#include "Network/Socket.h" +#include "Rtsp/RtspMediaSource.h" +using namespace toolkit; +using namespace mediakit; + +class WebRtcTransportImp : public WebRtcTransport, public std::enable_shared_from_this{ +public: + using Ptr = std::shared_ptr; + + WebRtcTransportImp(const EventPoller::Ptr &poller); + ~WebRtcTransportImp() override = default; + + void attach(const RtspMediaSource::Ptr &src); + +protected: + void onWrite(const char *buf, size_t len, struct sockaddr_in *dst) override; + int getPayloadType() const ; + uint32_t getSSRC() const override; + uint16_t getPort() const override; + std::string getIP() const override; + void onIceConnected() override; + void onDtlsCompleted() override; + +private: + Socket::Ptr _socket; + RtspMediaSource::Ptr _src; + RtspMediaSource::RingType::RingReader::Ptr _reader; +}; + +class WebRtcManager : public std::enable_shared_from_this { +public: + ~WebRtcManager(); + static WebRtcManager& Instance(); + +private: + WebRtcManager(); + +}; + + + + + + + + + + + + + + + + + +