From 56d6eb0f2858b586424bdeddfd7f345f8e003026 Mon Sep 17 00:00:00 2001 From: ziyue <1213642868@qq.com> Date: Fri, 3 Mar 2023 15:24:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=89=B9=E9=87=8F=E6=9B=BF=E6=8D=A2tab?= =?UTF-8?q?=E4=B8=BA4=E4=B8=AA=E7=A9=BA=E6=A0=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/source/mk_media.cpp | 6 +- webrtc/DtlsTransport.cpp | 2404 ++++++++++++++++++------------------ webrtc/DtlsTransport.hpp | 356 +++--- webrtc/IceServer.cpp | 894 +++++++------- webrtc/IceServer.hpp | 176 +-- webrtc/SctpAssociation.cpp | 1320 ++++++++++---------- webrtc/SctpAssociation.hpp | 170 +-- webrtc/StunPacket.cpp | 1392 ++++++++++----------- webrtc/StunPacket.hpp | 348 +++--- 9 files changed, 3533 insertions(+), 3533 deletions(-) diff --git a/api/source/mk_media.cpp b/api/source/mk_media.cpp index 117d7082..1c6c9d76 100755 --- a/api/source/mk_media.cpp +++ b/api/source/mk_media.cpp @@ -269,9 +269,9 @@ API_EXPORT int API_CALL mk_media_input_aac(mk_media ctx, const void *data, int l } API_EXPORT int API_CALL mk_media_input_pcm(mk_media ctx, void *data , int len, uint64_t pts){ - assert(ctx && data && len > 0); - MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx; - return (*obj)->getChannel()->inputPCM((char*)data, len, pts); + assert(ctx && data && len > 0); + MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx; + return (*obj)->getChannel()->inputPCM((char*)data, len, pts); } API_EXPORT int API_CALL mk_media_input_audio(mk_media ctx, const void* data, int len, uint64_t dts){ diff --git a/webrtc/DtlsTransport.cpp b/webrtc/DtlsTransport.cpp index 0e2f160a..ec2f67b6 100644 --- a/webrtc/DtlsTransport.cpp +++ b/webrtc/DtlsTransport.cpp @@ -33,1453 +33,1453 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. using namespace std; #define LOG_OPENSSL_ERROR(desc) \ - do \ - { \ - if (ERR_peek_error() == 0) \ - MS_ERROR("OpenSSL error [desc:'%s']", desc); \ - else \ - { \ - int64_t err; \ - while ((err = ERR_get_error()) != 0) \ - { \ - MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ - } \ - ERR_clear_error(); \ - } \ - } while (false) + do \ + { \ + if (ERR_peek_error() == 0) \ + MS_ERROR("OpenSSL error [desc:'%s']", desc); \ + else \ + { \ + int64_t err; \ + while ((err = ERR_get_error()) != 0) \ + { \ + MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ + } \ + ERR_clear_error(); \ + } \ + } while (false) /* Static methods for OpenSSL callbacks. */ inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) { - MS_TRACE(); + MS_TRACE(); - // Always valid since DTLS certificates are self-signed. - return 1; + // Always valid since DTLS certificates are self-signed. + return 1; } inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) { - if (timerUs == 0) - return 100000; - else if (timerUs >= 4000000) - return 4000000; - else - return 2 * timerUs; + if (timerUs == 0) + return 100000; + else if (timerUs >= 4000000) + return 4000000; + else + return 2 * timerUs; } namespace RTC { - /* Static. */ + /* Static. */ - // clang-format off - static constexpr int DtlsMtu{ 1350 }; - // AES-HMAC: http://tools.ietf.org/html/rfc3711 - static constexpr size_t SrtpMasterKeyLength{ 16 }; - static constexpr size_t SrtpMasterSaltLength{ 14 }; - static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; - // AES-GCM: http://tools.ietf.org/html/rfc7714 - static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; - static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; - static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; - static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; - static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; - static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; - // clang-format on + // clang-format off + static constexpr int DtlsMtu{ 1350 }; + // AES-HMAC: http://tools.ietf.org/html/rfc3711 + static constexpr size_t SrtpMasterKeyLength{ 16 }; + static constexpr size_t SrtpMasterSaltLength{ 14 }; + static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; + // AES-GCM: http://tools.ietf.org/html/rfc7714 + static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; + static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; + static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; + static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; + // clang-format on - /* Class variables. */ - // clang-format off - std::map DtlsTransport::string2FingerprintAlgorithm = - { - { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, - { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, - { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, - { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, - { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } - }; - std::map DtlsTransport::fingerprintAlgorithm2String = - { - { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, - { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, - { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, - { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, - { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } - }; - std::map DtlsTransport::string2Role = - { - { "auto", DtlsTransport::Role::AUTO }, - { "client", DtlsTransport::Role::CLIENT }, - { "server", DtlsTransport::Role::SERVER } - }; - std::vector DtlsTransport::srtpCryptoSuites = - { - { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, - { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, - { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, - { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } - }; - // clang-format on + /* Class variables. */ + // clang-format off + std::map DtlsTransport::string2FingerprintAlgorithm = + { + { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, + { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, + { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, + { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, + { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } + }; + std::map DtlsTransport::fingerprintAlgorithm2String = + { + { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, + { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, + { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, + { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, + { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } + }; + std::map DtlsTransport::string2Role = + { + { "auto", DtlsTransport::Role::AUTO }, + { "client", DtlsTransport::Role::CLIENT }, + { "server", DtlsTransport::Role::SERVER } + }; + std::vector DtlsTransport::srtpCryptoSuites = + { + { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, + { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } + }; + // clang-format on - INSTANCE_IMP(DtlsTransport::DtlsEnvironment); + INSTANCE_IMP(DtlsTransport::DtlsEnvironment); - /* Class methods. */ + /* Class methods. */ DtlsTransport::DtlsEnvironment::DtlsEnvironment() - { - MS_TRACE(); + { + MS_TRACE(); - // Generate a X509 certificate and private key (unless PEM files are provided). - if (true /* - Settings::configuration.dtlsCertificateFile.empty() || - Settings::configuration.dtlsPrivateKeyFile.empty()*/) - { - GenerateCertificateAndPrivateKey(); - } - else - { - ReadCertificateAndPrivateKeyFromFiles(); - } + // Generate a X509 certificate and private key (unless PEM files are provided). + if (true /* + Settings::configuration.dtlsCertificateFile.empty() || + Settings::configuration.dtlsPrivateKeyFile.empty()*/) + { + GenerateCertificateAndPrivateKey(); + } + else + { + ReadCertificateAndPrivateKeyFromFiles(); + } - // Create a global SSL_CTX. - CreateSslCtx(); + // Create a global SSL_CTX. + CreateSslCtx(); - // Generate certificate fingerprints. - GenerateFingerprints(); - } + // Generate certificate fingerprints. + GenerateFingerprints(); + } DtlsTransport::DtlsEnvironment::~DtlsEnvironment() - { - MS_TRACE(); + { + MS_TRACE(); - if (privateKey) - EVP_PKEY_free(privateKey); - if (certificate) - X509_free(certificate); - if (sslCtx) - SSL_CTX_free(sslCtx); - } + if (privateKey) + EVP_PKEY_free(privateKey); + if (certificate) + X509_free(certificate); + if (sslCtx) + SSL_CTX_free(sslCtx); + } - void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() - { - MS_TRACE(); + void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() + { + MS_TRACE(); - int ret{ 0 }; - EC_KEY* ecKey{ nullptr }; - X509_NAME* certName{ nullptr }; - std::string subject = - std::string("mediasoup") + to_string(rand() % 999999 + 100000); + int ret{ 0 }; + EC_KEY* ecKey{ nullptr }; + X509_NAME* certName{ nullptr }; + std::string subject = + std::string("mediasoup") + to_string(rand() % 999999 + 100000); - // Create key with curve. - ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // Create key with curve. + ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - if (!ecKey) - { - LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); + if (!ecKey) + { + LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); - goto error; - } + goto error; + } - EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); + EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); - // NOTE: This can take some time. - ret = EC_KEY_generate_key(ecKey); + // NOTE: This can take some time. + ret = EC_KEY_generate_key(ecKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); - goto error; - } + goto error; + } - // Create a private key object. - privateKey = EVP_PKEY_new(); + // Create a private key object. + privateKey = EVP_PKEY_new(); - if (!privateKey) - { - LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); + if (!privateKey) + { + LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); - goto error; - } + goto error; + } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) - ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); - goto error; - } + goto error; + } - // The EC key now belongs to the private key, so don't clean it up separately. - ecKey = nullptr; + // The EC key now belongs to the private key, so don't clean it up separately. + ecKey = nullptr; - // Create the X509 certificate. - certificate = X509_new(); + // Create the X509 certificate. + certificate = X509_new(); - if (!certificate) - { - LOG_OPENSSL_ERROR("X509_new() failed"); + if (!certificate) + { + LOG_OPENSSL_ERROR("X509_new() failed"); - goto error; - } + goto error; + } - // Set version 3 (note that 0 means version 1). - X509_set_version(certificate, 2); + // Set version 3 (note that 0 means version 1). + X509_set_version(certificate, 2); - // Set serial number (avoid default 0). - ASN1_INTEGER_set( - X509_get_serialNumber(certificate), - static_cast(rand() % 999999 + 100000)); + // Set serial number (avoid default 0). + ASN1_INTEGER_set( + X509_get_serialNumber(certificate), + static_cast(rand() % 999999 + 100000)); - // Set valid period. - X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. - X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. + // Set valid period. + X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. + X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. - // Set the public key for the certificate using the key. - ret = X509_set_pubkey(certificate, privateKey); + // Set the public key for the certificate using the key. + ret = X509_set_pubkey(certificate, privateKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); - goto error; - } + goto error; + } - // Set certificate fields. - certName = X509_get_subject_name(certificate); + // Set certificate fields. + certName = X509_get_subject_name(certificate); - if (!certName) - { - LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); + if (!certName) + { + LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); - goto error; - } + goto error; + } - X509_NAME_add_entry_by_txt( - certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); - X509_NAME_add_entry_by_txt( - certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt( + certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt( + certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); - // It is self-signed so set the issuer name to be the same as the subject. - ret = X509_set_issuer_name(certificate, certName); + // It is self-signed so set the issuer name to be the same as the subject. + ret = X509_set_issuer_name(certificate, certName); - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); - goto error; - } + goto error; + } - // Sign the certificate with its own private key. - ret = X509_sign(certificate, privateKey, EVP_sha1()); + // Sign the certificate with its own private key. + ret = X509_sign(certificate, privateKey, EVP_sha1()); - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_sign() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_sign() failed"); - goto error; - } + goto error; + } - return; + return; - error: + error: - if (ecKey) - EC_KEY_free(ecKey); + if (ecKey) + EC_KEY_free(ecKey); - if (privateKey) - EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. + if (privateKey) + EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. - if (certificate) - X509_free(certificate); + if (certificate) + X509_free(certificate); - MS_THROW_ERROR("DTLS certificate and private key generation failed"); - } + MS_THROW_ERROR("DTLS certificate and private key generation failed"); + } - void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles() - { + void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles() + { #if 0 - MS_TRACE(); + MS_TRACE(); - FILE* file{ nullptr }; + FILE* file{ nullptr }; - file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); + file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); - if (!file) - { - MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); + if (!file) + { + MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); - goto error; - } + goto error; + } - certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); + certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); - if (!certificate) - { - LOG_OPENSSL_ERROR("PEM_read_X509() failed"); + if (!certificate) + { + LOG_OPENSSL_ERROR("PEM_read_X509() failed"); - goto error; - } + goto error; + } - fclose(file); + fclose(file); - file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); + file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); - if (!file) - { - MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); + if (!file) + { + MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); - goto error; - } + goto error; + } - privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); + privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); - if (!privateKey) - { - LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); + if (!privateKey) + { + LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); - goto error; - } + goto error; + } - fclose(file); + fclose(file); - return; + return; - error: + error: - MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); + MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); #endif - } + } - void DtlsTransport::DtlsEnvironment::CreateSslCtx() - { - MS_TRACE(); + void DtlsTransport::DtlsEnvironment::CreateSslCtx() + { + MS_TRACE(); - std::string dtlsSrtpCryptoSuites; - int ret; + std::string dtlsSrtpCryptoSuites; + int ret; - /* Set the global DTLS context. */ + /* Set the global DTLS context. */ - // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). - sslCtx = SSL_CTX_new(DTLS_method()); + // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). + sslCtx = SSL_CTX_new(DTLS_method()); - if (!sslCtx) - { - LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); + if (!sslCtx) + { + LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); - goto error; - } + goto error; + } - ret = SSL_CTX_use_certificate(sslCtx, certificate); + ret = SSL_CTX_use_certificate(sslCtx, certificate); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); - goto error; - } + goto error; + } - ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); + ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); - goto error; - } + goto error; + } - ret = SSL_CTX_check_private_key(sslCtx); + ret = SSL_CTX_check_private_key(sslCtx); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); - goto error; - } + goto error; + } - // Set options. - SSL_CTX_set_options( - sslCtx, - SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | - SSL_OP_NO_QUERY_MTU); + // Set options. + SSL_CTX_set_options( + sslCtx, + SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | + SSL_OP_NO_QUERY_MTU); - // Don't use sessions cache. - SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); + // Don't use sessions cache. + SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); - // Read always as much into the buffer as possible. - // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL - // versions makes this call required. - SSL_CTX_set_read_ahead(sslCtx, 1); + // Read always as much into the buffer as possible. + // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL + // versions makes this call required. + SSL_CTX_set_read_ahead(sslCtx, 1); - SSL_CTX_set_verify_depth(sslCtx, 4); + SSL_CTX_set_verify_depth(sslCtx, 4); - // Require certificate from peer. - SSL_CTX_set_verify( - sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); + // Require certificate from peer. + SSL_CTX_set_verify( + sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); - // Set SSL info callback. - SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ + // Set SSL info callback. + SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); }); - // Set ciphers. - ret = SSL_CTX_set_cipher_list( - sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); + // Set ciphers. + ret = SSL_CTX_set_cipher_list( + sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); - goto error; - } + goto error; + } - // Enable ECDH ciphers. - // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters - // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 - // NOTE: https://bugs.ruby-lang.org/issues/12324 + // Enable ECDH ciphers. + // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters + // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 + // NOTE: https://bugs.ruby-lang.org/issues/12324 - // For OpenSSL >= 1.0.2. - SSL_CTX_set_ecdh_auto(sslCtx, 1); + // For OpenSSL >= 1.0.2. + SSL_CTX_set_ecdh_auto(sslCtx, 1); - // Set the "use_srtp" DTLS extension. - for (auto it = DtlsTransport::srtpCryptoSuites.begin(); - it != DtlsTransport::srtpCryptoSuites.end(); - ++it) - { - if (it != DtlsTransport::srtpCryptoSuites.begin()) - dtlsSrtpCryptoSuites += ":"; + // Set the "use_srtp" DTLS extension. + for (auto it = DtlsTransport::srtpCryptoSuites.begin(); + it != DtlsTransport::srtpCryptoSuites.end(); + ++it) + { + if (it != DtlsTransport::srtpCryptoSuites.begin()) + dtlsSrtpCryptoSuites += ":"; - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); - dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; - } + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); + dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; + } - MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); + MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); - // NOTE: This function returns 0 on success. - ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); + // NOTE: This function returns 0 on success. + ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); - if (ret != 0) - { - MS_ERROR( - "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); - LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); + if (ret != 0) + { + MS_ERROR( + "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); + LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); - goto error; - } + goto error; + } - return; + return; - error: + error: - if (sslCtx) - { - SSL_CTX_free(sslCtx); - sslCtx = nullptr; - } + if (sslCtx) + { + SSL_CTX_free(sslCtx); + sslCtx = nullptr; + } - MS_THROW_ERROR("SSL context creation failed"); - } + MS_THROW_ERROR("SSL context creation failed"); + } - void DtlsTransport::DtlsEnvironment::GenerateFingerprints() - { - MS_TRACE(); + void DtlsTransport::DtlsEnvironment::GenerateFingerprints() + { + MS_TRACE(); - for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) - { - const std::string& algorithmString = kv.first; - FingerprintAlgorithm algorithm = kv.second; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{ 0 }; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; + for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) + { + const std::string& algorithmString = kv.first; + FingerprintAlgorithm algorithm = kv.second; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; - switch (algorithm) - { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; + switch (algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; - default: - MS_THROW_ERROR("unknown algorithm"); - } + default: + MS_THROW_ERROR("unknown algorithm"); + } - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - if (ret == 0) - { - MS_ERROR("X509_digest() failed"); - MS_THROW_ERROR("Fingerprints generation failed"); - } + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); + MS_THROW_ERROR("Fingerprints generation failed"); + } - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{ 0 }; i < size; ++i) - { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; - MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); + MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); - // Store it in the vector. - DtlsTransport::Fingerprint fingerprint; + // Store it in the vector. + DtlsTransport::Fingerprint fingerprint; - fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); - fingerprint.value = hexFingerprint; + fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); + fingerprint.value = hexFingerprint; - localFingerprints.push_back(fingerprint); - } - } + localFingerprints.push_back(fingerprint); + } + } - /* Instance methods. */ + /* Instance methods. */ - DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) - { - MS_TRACE(); + DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) + { + MS_TRACE(); env = DtlsEnvironment::Instance().shared_from_this(); - /* Set SSL. */ + /* Set SSL. */ - this->ssl = SSL_new(env->sslCtx); + this->ssl = SSL_new(env->sslCtx); - if (!this->ssl) - { - LOG_OPENSSL_ERROR("SSL_new() failed"); + if (!this->ssl) + { + LOG_OPENSSL_ERROR("SSL_new() failed"); - goto error; - } + goto error; + } - // Set this as custom data. - SSL_set_ex_data(this->ssl, 0, static_cast(this)); + // Set this as custom data. + SSL_set_ex_data(this->ssl, 0, static_cast(this)); - this->sslBioFromNetwork = BIO_new(BIO_s_mem()); + this->sslBioFromNetwork = BIO_new(BIO_s_mem()); - if (!this->sslBioFromNetwork) - { - LOG_OPENSSL_ERROR("BIO_new() failed"); + if (!this->sslBioFromNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); - SSL_free(this->ssl); + SSL_free(this->ssl); - goto error; - } - - this->sslBioToNetwork = BIO_new(BIO_s_mem()); - - if (!this->sslBioToNetwork) - { - LOG_OPENSSL_ERROR("BIO_new() failed"); - - BIO_free(this->sslBioFromNetwork); - SSL_free(this->ssl); - - goto error; - } - - SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); - - // Set the MTU so that we don't send packets that are too large with no fragmentation. - SSL_set_mtu(this->ssl, DtlsMtu); - DTLS_set_link_mtu(this->ssl, DtlsMtu); - - // Set callback handler for setting DTLS timer interval. - DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); - - return; - - error: - - // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as - // well. - if (this->sslBioFromNetwork) - BIO_free(this->sslBioFromNetwork); - - if (this->sslBioToNetwork) - BIO_free(this->sslBioToNetwork); - - if (this->ssl) - SSL_free(this->ssl); - - // NOTE: If this is not catched by the caller the program will abort, but - // this should never happen. - MS_THROW_ERROR("DtlsTransport instance creation failed"); - } - - DtlsTransport::~DtlsTransport() - { - MS_TRACE(); - - if (IsRunning()) - { - // Send close alert to the peer. - SSL_shutdown(this->ssl); - SendPendingOutgoingDtlsData(); - } - - if (this->ssl) - { - SSL_free(this->ssl); - - this->ssl = nullptr; - this->sslBioFromNetwork = nullptr; - this->sslBioToNetwork = nullptr; - } - - // Close the DTLS timer. - this->timer = nullptr; - } - - void DtlsTransport::Dump() const - { - MS_TRACE(); - - std::string state{ "new" }; - std::string role{ "none " }; - - switch (this->state) - { - case DtlsState::CONNECTING: - state = "connecting"; - break; - case DtlsState::CONNECTED: - state = "connected"; - break; - case DtlsState::FAILED: - state = "failed"; - break; - case DtlsState::CLOSED: - state = "closed"; - break; - default:; - } - - switch (this->localRole) - { - case Role::AUTO: - role = "auto"; - break; - case Role::SERVER: - role = "server"; - break; - case Role::CLIENT: - role = "client"; - break; - default:; - } - - MS_DUMP(""); - MS_DUMP(" state : %s", state.c_str()); - MS_DUMP(" role : %s", role.c_str()); - MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); - MS_DUMP(""); - } - - void DtlsTransport::Run(Role localRole) - { - MS_TRACE(); - - MS_ASSERT( - localRole == Role::CLIENT || localRole == Role::SERVER, - "local DTLS role must be 'client' or 'server'"); - - Role previousLocalRole = this->localRole; - - if (localRole == previousLocalRole) - { - MS_ERROR("same local DTLS role provided, doing nothing"); - - return; - } - - // If the previous local DTLS role was 'client' or 'server' do reset. - if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) - { - MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); - - Reset(); - } - - // Update local role. - this->localRole = localRole; - - // Set state and notify the listener. - this->state = DtlsState::CONNECTING; - this->listener->OnDtlsTransportConnecting(this); - - switch (this->localRole) - { - case Role::CLIENT: - { - MS_DEBUG_TAG(dtls, "running [role:client]"); - - SSL_set_connect_state(this->ssl); - SSL_do_handshake(this->ssl); - SendPendingOutgoingDtlsData(); - SetTimeout(); - - break; - } - - case Role::SERVER: - { - MS_DEBUG_TAG(dtls, "running [role:server]"); - - SSL_set_accept_state(this->ssl); - SSL_do_handshake(this->ssl); - - break; - } - - default: - { - MS_ABORT("invalid local DTLS role"); - } - } - } - - bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) - { - MS_TRACE(); - - MS_ASSERT( - fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); - - this->remoteFingerprint = fingerprint; - - // The remote fingerpring may have been set after DTLS handshake was done, - // so we may need to process it now. - if (this->handshakeDone && this->state != DtlsState::CONNECTED) - { - MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); - - return ProcessHandshake(); - } - - return true; - } - - void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) - { - MS_TRACE(); - - int written; - int read; + goto error; + } + + this->sslBioToNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioToNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + BIO_free(this->sslBioFromNetwork); + SSL_free(this->ssl); + + goto error; + } + + SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); + + // Set the MTU so that we don't send packets that are too large with no fragmentation. + SSL_set_mtu(this->ssl, DtlsMtu); + DTLS_set_link_mtu(this->ssl, DtlsMtu); + + // Set callback handler for setting DTLS timer interval. + DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); + + return; + + error: + + // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as + // well. + if (this->sslBioFromNetwork) + BIO_free(this->sslBioFromNetwork); + + if (this->sslBioToNetwork) + BIO_free(this->sslBioToNetwork); + + if (this->ssl) + SSL_free(this->ssl); + + // NOTE: If this is not catched by the caller the program will abort, but + // this should never happen. + MS_THROW_ERROR("DtlsTransport instance creation failed"); + } + + DtlsTransport::~DtlsTransport() + { + MS_TRACE(); + + if (IsRunning()) + { + // Send close alert to the peer. + SSL_shutdown(this->ssl); + SendPendingOutgoingDtlsData(); + } + + if (this->ssl) + { + SSL_free(this->ssl); + + this->ssl = nullptr; + this->sslBioFromNetwork = nullptr; + this->sslBioToNetwork = nullptr; + } + + // Close the DTLS timer. + this->timer = nullptr; + } + + void DtlsTransport::Dump() const + { + MS_TRACE(); + + std::string state{ "new" }; + std::string role{ "none " }; + + switch (this->state) + { + case DtlsState::CONNECTING: + state = "connecting"; + break; + case DtlsState::CONNECTED: + state = "connected"; + break; + case DtlsState::FAILED: + state = "failed"; + break; + case DtlsState::CLOSED: + state = "closed"; + break; + default:; + } + + switch (this->localRole) + { + case Role::AUTO: + role = "auto"; + break; + case Role::SERVER: + role = "server"; + break; + case Role::CLIENT: + role = "client"; + break; + default:; + } + + MS_DUMP(""); + MS_DUMP(" state : %s", state.c_str()); + MS_DUMP(" role : %s", role.c_str()); + MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); + MS_DUMP(""); + } + + void DtlsTransport::Run(Role localRole) + { + MS_TRACE(); + + MS_ASSERT( + localRole == Role::CLIENT || localRole == Role::SERVER, + "local DTLS role must be 'client' or 'server'"); + + Role previousLocalRole = this->localRole; + + if (localRole == previousLocalRole) + { + MS_ERROR("same local DTLS role provided, doing nothing"); + + return; + } + + // If the previous local DTLS role was 'client' or 'server' do reset. + if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) + { + MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); + + Reset(); + } + + // Update local role. + this->localRole = localRole; + + // Set state and notify the listener. + this->state = DtlsState::CONNECTING; + this->listener->OnDtlsTransportConnecting(this); + + switch (this->localRole) + { + case Role::CLIENT: + { + MS_DEBUG_TAG(dtls, "running [role:client]"); + + SSL_set_connect_state(this->ssl); + SSL_do_handshake(this->ssl); + SendPendingOutgoingDtlsData(); + SetTimeout(); + + break; + } + + case Role::SERVER: + { + MS_DEBUG_TAG(dtls, "running [role:server]"); + + SSL_set_accept_state(this->ssl); + SSL_do_handshake(this->ssl); + + break; + } + + default: + { + MS_ABORT("invalid local DTLS role"); + } + } + } + + bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) + { + MS_TRACE(); + + MS_ASSERT( + fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); + + this->remoteFingerprint = fingerprint; + + // The remote fingerpring may have been set after DTLS handshake was done, + // so we may need to process it now. + if (this->handshakeDone && this->state != DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); + + return ProcessHandshake(); + } + + return true; + } + + void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) + { + MS_TRACE(); + + int written; + int read; - if (!IsRunning()) - { - MS_ERROR("cannot process data while not running"); + if (!IsRunning()) + { + MS_ERROR("cannot process data while not running"); - return; - } + return; + } - // Write the received DTLS data into the sslBioFromNetwork. - written = - BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); + // Write the received DTLS data into the sslBioFromNetwork. + written = + BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); - if (written != static_cast(len)) - { - MS_WARN_TAG( - dtls, - "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", - static_cast(written), - len); - } - - // Must call SSL_read() to process received DTLS data. - read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); + if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, + "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", + static_cast(written), + len); + } + + // Must call SSL_read() to process received DTLS data. + read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); - // Send data if it's ready. - SendPendingOutgoingDtlsData(); + // Send data if it's ready. + SendPendingOutgoingDtlsData(); - // Check SSL status and return if it is bad/closed. - if (!CheckStatus(read)) - return; + // Check SSL status and return if it is bad/closed. + if (!CheckStatus(read)) + return; - // Set/update the DTLS timeout. - if (!SetTimeout()) - return; + // Set/update the DTLS timeout. + if (!SetTimeout()) + return; - // Application data received. Notify to the listener. - if (read > 0) - { - // It is allowed to receive DTLS data even before validating remote fingerprint. - if (!this->handshakeDone) - { - MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); - - return; - } + // Application data received. Notify to the listener. + if (read > 0) + { + // It is allowed to receive DTLS data even before validating remote fingerprint. + if (!this->handshakeDone) + { + MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); + + return; + } - // Notify the listener. - this->listener->OnDtlsTransportApplicationDataReceived( - this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); - } - } + // Notify the listener. + this->listener->OnDtlsTransportApplicationDataReceived( + this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); + } + } - void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) - { - MS_TRACE(); + void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) + { + MS_TRACE(); - // We cannot send data to the peer if its remote fingerprint is not validated. - if (this->state != DtlsState::CONNECTED) - { - MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); + // We cannot send data to the peer if its remote fingerprint is not validated. + if (this->state != DtlsState::CONNECTED) + { + MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); - return; - } + return; + } - if (len == 0) - { - MS_WARN_TAG(dtls, "ignoring 0 length data"); + if (len == 0) + { + MS_WARN_TAG(dtls, "ignoring 0 length data"); - return; - } + return; + } - int written; + int written; - written = SSL_write(this->ssl, static_cast(data), static_cast(len)); + written = SSL_write(this->ssl, static_cast(data), static_cast(len)); - if (written < 0) - { - LOG_OPENSSL_ERROR("SSL_write() failed"); + if (written < 0) + { + LOG_OPENSSL_ERROR("SSL_write() failed"); - if (!CheckStatus(written)) - return; - } - else if (written != static_cast(len)) - { - MS_WARN_TAG( - dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); - } + if (!CheckStatus(written)) + return; + } + else if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); + } - // Send data. - SendPendingOutgoingDtlsData(); - } - - void DtlsTransport::Reset() - { - MS_TRACE(); - - int ret; + // Send data. + SendPendingOutgoingDtlsData(); + } + + void DtlsTransport::Reset() + { + MS_TRACE(); + + int ret; - if (!IsRunning()) - return; - - MS_WARN_TAG(dtls, "resetting DTLS transport"); - - // Stop the DTLS timer. - this->timer = nullptr; - - // We need to reset the SSL instance so we need to "shutdown" it, but we - // don't want to send a Close Alert to the peer, so just don't call - // SendPendingOutgoingDTLSData(). - SSL_shutdown(this->ssl); - - this->localRole = Role::NONE; - this->state = DtlsState::NEW; - this->handshakeDone = false; - this->handshakeDoneNow = false; - - // Reset SSL status. - // NOTE: For this to properly work, SSL_shutdown() must be called before. - // NOTE: This may fail if not enough DTLS handshake data has been received, - // but we don't care so just clear the error queue. - ret = SSL_clear(this->ssl); - - if (ret == 0) - ERR_clear_error(); - } - - inline bool DtlsTransport::CheckStatus(int returnCode) - { - MS_TRACE(); - - int err; - bool wasHandshakeDone = this->handshakeDone; - - err = SSL_get_error(this->ssl, returnCode); - - switch (err) - { - case SSL_ERROR_NONE: - break; + if (!IsRunning()) + return; + + MS_WARN_TAG(dtls, "resetting DTLS transport"); + + // Stop the DTLS timer. + this->timer = nullptr; + + // We need to reset the SSL instance so we need to "shutdown" it, but we + // don't want to send a Close Alert to the peer, so just don't call + // SendPendingOutgoingDTLSData(). + SSL_shutdown(this->ssl); + + this->localRole = Role::NONE; + this->state = DtlsState::NEW; + this->handshakeDone = false; + this->handshakeDoneNow = false; + + // Reset SSL status. + // NOTE: For this to properly work, SSL_shutdown() must be called before. + // NOTE: This may fail if not enough DTLS handshake data has been received, + // but we don't care so just clear the error queue. + ret = SSL_clear(this->ssl); + + if (ret == 0) + ERR_clear_error(); + } + + inline bool DtlsTransport::CheckStatus(int returnCode) + { + MS_TRACE(); + + int err; + bool wasHandshakeDone = this->handshakeDone; + + err = SSL_get_error(this->ssl, returnCode); + + switch (err) + { + case SSL_ERROR_NONE: + break; - case SSL_ERROR_SSL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); - break; + case SSL_ERROR_SSL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); + break; - case SSL_ERROR_WANT_READ: - break; + case SSL_ERROR_WANT_READ: + break; - case SSL_ERROR_WANT_WRITE: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); - break; + case SSL_ERROR_WANT_WRITE: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); + break; - case SSL_ERROR_WANT_X509_LOOKUP: - MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); - break; + case SSL_ERROR_WANT_X509_LOOKUP: + MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); + break; - case SSL_ERROR_SYSCALL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); - break; + case SSL_ERROR_SYSCALL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); + break; - case SSL_ERROR_ZERO_RETURN: - break; + case SSL_ERROR_ZERO_RETURN: + break; - case SSL_ERROR_WANT_CONNECT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); - break; + case SSL_ERROR_WANT_CONNECT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); + break; - case SSL_ERROR_WANT_ACCEPT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); - break; + case SSL_ERROR_WANT_ACCEPT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); + break; - default: - MS_WARN_TAG(dtls, "SSL status: unknown error"); - } + default: + MS_WARN_TAG(dtls, "SSL status: unknown error"); + } - // Check if the handshake (or re-handshake) has been done right now. - if (this->handshakeDoneNow) - { - this->handshakeDoneNow = false; - this->handshakeDone = true; - - // Stop the timer. - this->timer = nullptr; + // Check if the handshake (or re-handshake) has been done right now. + if (this->handshakeDoneNow) + { + this->handshakeDoneNow = false; + this->handshakeDone = true; + + // Stop the timer. + this->timer = nullptr; - // Process the handshake just once (ignore if DTLS renegotiation). - if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) - return ProcessHandshake(); + // Process the handshake just once (ignore if DTLS renegotiation). + if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) + return ProcessHandshake(); - return true; - } - // Check if the peer sent close alert or a fatal error happened. - else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) - { - if (this->state == DtlsState::CONNECTED) - { - MS_DEBUG_TAG(dtls, "disconnected"); + return true; + } + // Check if the peer sent close alert or a fatal error happened. + else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) + { + if (this->state == DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "disconnected"); - Reset(); + Reset(); - // Set state and notify the listener. - this->state = DtlsState::CLOSED; - this->listener->OnDtlsTransportClosed(this); - } - else - { - MS_WARN_TAG(dtls, "connection failed"); + // Set state and notify the listener. + this->state = DtlsState::CLOSED; + this->listener->OnDtlsTransportClosed(this); + } + else + { + MS_WARN_TAG(dtls, "connection failed"); - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - } - - return false; - } - else - { - return true; - } - } + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + } + + return false; + } + else + { + return true; + } + } - inline void DtlsTransport::SendPendingOutgoingDtlsData() - { - MS_TRACE(); + inline void DtlsTransport::SendPendingOutgoingDtlsData() + { + MS_TRACE(); - if (BIO_eof(this->sslBioToNetwork)) - return; + if (BIO_eof(this->sslBioToNetwork)) + return; - int64_t read; - char* data{ nullptr }; + int64_t read; + char* data{ nullptr }; - read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT + read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT - if (read <= 0) - return; - - MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); + if (read <= 0) + return; + + MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); - // Notify the listener. - this->listener->OnDtlsTransportSendData( - this, reinterpret_cast(data), static_cast(read)); + // Notify the listener. + this->listener->OnDtlsTransportSendData( + this, reinterpret_cast(data), static_cast(read)); - // Clear the BIO buffer. - // NOTE: the (void) avoids the -Wunused-value warning. - (void)BIO_reset(this->sslBioToNetwork); - } + // Clear the BIO buffer. + // NOTE: the (void) avoids the -Wunused-value warning. + (void)BIO_reset(this->sslBioToNetwork); + } - inline bool DtlsTransport::SetTimeout() - { - MS_TRACE(); + inline bool DtlsTransport::SetTimeout() + { + MS_TRACE(); - MS_ASSERT( - this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, - "invalid DTLS state"); + MS_ASSERT( + this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, + "invalid DTLS state"); - int64_t ret; + int64_t ret; struct timeval dtlsTimeout{ 0, 0 }; - uint64_t timeoutMs; + uint64_t timeoutMs; - // NOTE: If ret == 0 then ignore the value in dtlsTimeout. - // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. - ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT + // NOTE: If ret == 0 then ignore the value in dtlsTimeout. + // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. + ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT - if (ret == 0) - return true; + if (ret == 0) + return true; - timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); + timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); - if (timeoutMs == 0) - { - return true; - } - else if (timeoutMs < 30000) - { - MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); + if (timeoutMs == 0) + { + return true; + } + else if (timeoutMs < 30000) + { + MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); - weak_ptr weak_self = shared_from_this(); - this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ - auto strong_self = weak_self.lock(); - if(strong_self){ + weak_ptr weak_self = shared_from_this(); + this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ + auto strong_self = weak_self.lock(); + if(strong_self){ strong_self->OnTimer(); - } + } return true; - }, this->poller); + }, this->poller); - return true; - } - // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. - else - { - MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); + return true; + } + // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. + else + { + MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); - Reset(); + Reset(); - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); - return false; - } - } + return false; + } + } - inline bool DtlsTransport::ProcessHandshake() - { - MS_TRACE(); + inline bool DtlsTransport::ProcessHandshake() + { + MS_TRACE(); - MS_ASSERT(this->handshakeDone, "handshake not done yet"); - MS_ASSERT( - this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + MS_ASSERT(this->handshakeDone, "handshake not done yet"); + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); - // Validate the remote fingerprint. - if (!CheckRemoteFingerprint()) - { - Reset(); + // Validate the remote fingerprint. + if (!CheckRemoteFingerprint()) + { + Reset(); - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); - return false; - } + return false; + } - // Get the negotiated SRTP crypto suite. - RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); + // Get the negotiated SRTP crypto suite. + RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); - if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) - { - // Extract the SRTP keys (will notify the listener with them). - ExtractSrtpKeys(srtpCryptoSuite); + if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) + { + // Extract the SRTP keys (will notify the listener with them). + ExtractSrtpKeys(srtpCryptoSuite); - return true; - } + return true; + } - // NOTE: We assume that "use_srtp" DTLS extension is required even if - // there is no audio/video. - MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); + // NOTE: We assume that "use_srtp" DTLS extension is required even if + // there is no audio/video. + MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); - Reset(); + Reset(); - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); - return false; - } + return false; + } - inline bool DtlsTransport::CheckRemoteFingerprint() - { - MS_TRACE(); + inline bool DtlsTransport::CheckRemoteFingerprint() + { + MS_TRACE(); - MS_ASSERT( - this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); - X509* certificate; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{ 0 }; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; + X509* certificate; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; - certificate = SSL_get_peer_certificate(this->ssl); + certificate = SSL_get_peer_certificate(this->ssl); - if (!certificate) - { - MS_WARN_TAG(dtls, "no certificate was provided by the peer"); + if (!certificate) + { + MS_WARN_TAG(dtls, "no certificate was provided by the peer"); - return false; - } + return false; + } - switch (this->remoteFingerprint.algorithm) - { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; + switch (this->remoteFingerprint.algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; - default: - MS_ABORT("unknown algorithm"); - } + default: + MS_ABORT("unknown algorithm"); + } - // Compare the remote fingerprint with the value given via signaling. - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + // Compare the remote fingerprint with the value given via signaling. + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - if (ret == 0) - { - MS_ERROR("X509_digest() failed"); + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); - X509_free(certificate); + X509_free(certificate); - return false; - } + return false; + } - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{ 0 }; i < size; ++i) - { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; - if (this->remoteFingerprint.value != hexFingerprint) - { - MS_WARN_TAG( - dtls, - "fingerprint in the remote certificate (%s) does not match the announced one (%s)", - hexFingerprint, - this->remoteFingerprint.value.c_str()); - X509_free(certificate); - return false; - } + if (this->remoteFingerprint.value != hexFingerprint) + { + MS_WARN_TAG( + dtls, + "fingerprint in the remote certificate (%s) does not match the announced one (%s)", + hexFingerprint, + this->remoteFingerprint.value.c_str()); + X509_free(certificate); + return false; + } - MS_DEBUG_TAG(dtls, "valid remote fingerprint"); + MS_DEBUG_TAG(dtls, "valid remote fingerprint"); - // Get the remote certificate in PEM format. + // Get the remote certificate in PEM format. - BIO* bio = BIO_new(BIO_s_mem()); + BIO* bio = BIO_new(BIO_s_mem()); - // Ensure the underlying BUF_MEM structure is also freed. - // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since - // BIO_set_close() always returns 1. - (void)BIO_set_close(bio, BIO_CLOSE); + // Ensure the underlying BUF_MEM structure is also freed. + // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since + // BIO_set_close() always returns 1. + (void)BIO_set_close(bio, BIO_CLOSE); - ret = PEM_write_bio_X509(bio, certificate); + ret = PEM_write_bio_X509(bio, certificate); - if (ret != 1) - { - LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); + if (ret != 1) + { + LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); - X509_free(certificate); - BIO_free(bio); - - return false; - } - - BUF_MEM* mem; + X509_free(certificate); + BIO_free(bio); + + return false; + } + + BUF_MEM* mem; - BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] - - if (!mem || !mem->data || mem->length == 0u) - { - LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); - - X509_free(certificate); - BIO_free(bio); - - return false; - } - - this->remoteCert = std::string(mem->data, mem->length); - - X509_free(certificate); - BIO_free(bio); - - return true; - } - - inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) - { - MS_TRACE(); - - size_t srtpKeyLength{ 0 }; - size_t srtpSaltLength{ 0 }; - size_t srtpMasterLength{ 0 }; - - switch (srtpCryptoSuite) - { - case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: - case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: - { - srtpKeyLength = SrtpMasterKeyLength; - srtpSaltLength = SrtpMasterSaltLength; - srtpMasterLength = SrtpMasterLength; - - break; - } - - case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: - { - srtpKeyLength = SrtpAesGcm256MasterKeyLength; - srtpSaltLength = SrtpAesGcm256MasterSaltLength; - srtpMasterLength = SrtpAesGcm256MasterLength; - - break; - } - - case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: - { - srtpKeyLength = SrtpAesGcm128MasterKeyLength; - srtpSaltLength = SrtpAesGcm128MasterSaltLength; - srtpMasterLength = SrtpAesGcm128MasterLength; - - break; - } - - default: - { - MS_ABORT("unknown SRTP crypto suite"); - } - } - - auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; - uint8_t* srtpLocalKey{ nullptr }; - uint8_t* srtpLocalSalt{ nullptr }; - uint8_t* srtpRemoteKey{ nullptr }; - uint8_t* srtpRemoteSalt{ nullptr }; - auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; - auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; - int ret; - - ret = SSL_export_keying_material( - this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); - - MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); - - switch (this->localRole) - { - case Role::SERVER: - { - srtpRemoteKey = srtpMaterial; - srtpLocalKey = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; - - break; - } - - case Role::CLIENT: - { - srtpLocalKey = srtpMaterial; - srtpRemoteKey = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; - - break; - } - - default: - { - MS_ABORT("no DTLS role set"); - } - } - - // Create the SRTP local master key. - std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); - std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); - // Create the SRTP remote master key. - std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); - std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); - - // Set state and notify the listener. - this->state = DtlsState::CONNECTED; - this->listener->OnDtlsTransportConnected( - this, - srtpCryptoSuite, - srtpLocalMasterKey, - srtpMasterLength, - srtpRemoteMasterKey, - srtpMasterLength, - this->remoteCert); - - delete[] srtpMaterial; - delete[] srtpLocalMasterKey; - delete[] srtpRemoteMasterKey; - } - - inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() - { - MS_TRACE(); - - RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; - - // Ensure that the SRTP crypto suite has been negotiated. - // NOTE: This is a OpenSSL type. - SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); - - if (!sslSrtpCryptoSuite) - return negotiatedSrtpCryptoSuite; - - // Get the negotiated SRTP crypto suite. - for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) - { - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); - - if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) - { - MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); - - negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; - } - } - - MS_ASSERT( - negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, - "chosen SRTP crypto suite is not an available one"); - - return negotiatedSrtpCryptoSuite; - } - - inline void DtlsTransport::OnSslInfo(int where, int ret) - { - MS_TRACE(); - - int w = where & -SSL_ST_MASK; - const char* role; - - if ((w & SSL_ST_CONNECT) != 0) - role = "client"; - else if ((w & SSL_ST_ACCEPT) != 0) - role = "server"; - else - role = "undefined"; - - if ((where & SSL_CB_LOOP) != 0) - { - MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); - } - else if ((where & SSL_CB_ALERT) != 0) - { - const char* alertType; - - switch (*SSL_alert_type_string(ret)) - { - case 'W': - alertType = "warning"; - break; - - case 'F': - alertType = "fatal"; - break; - - default: - alertType = "undefined"; - } - - if ((where & SSL_CB_READ) != 0) - { - MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - else if ((where & SSL_CB_WRITE) != 0) - { - MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - else - { - MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - } - else if ((where & SSL_CB_EXIT) != 0) - { - if (ret == 0) - MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); - else if (ret < 0) - MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); - } - else if ((where & SSL_CB_HANDSHAKE_START) != 0) - { - MS_DEBUG_TAG(dtls, "DTLS handshake start"); - } - else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) - { - MS_DEBUG_TAG(dtls, "DTLS handshake done"); - - this->handshakeDoneNow = true; - } - - // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon - // receipt of a close alert does not work (the flag is set after this callback). - } - - inline void DtlsTransport::OnTimer() - { - MS_TRACE(); - - // Workaround for https://github.com/openssl/openssl/issues/7998. - if (this->handshakeDone) - { - MS_DEBUG_DEV("handshake is done so return"); - - return; - } - - DTLSv1_handle_timeout(this->ssl); - - // If required, send DTLS data. - SendPendingOutgoingDtlsData(); - - // Set the DTLS timer again. - SetTimeout(); - } + BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] + + if (!mem || !mem->data || mem->length == 0u) + { + LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); + + X509_free(certificate); + BIO_free(bio); + + return false; + } + + this->remoteCert = std::string(mem->data, mem->length); + + X509_free(certificate); + BIO_free(bio); + + return true; + } + + inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) + { + MS_TRACE(); + + size_t srtpKeyLength{ 0 }; + size_t srtpSaltLength{ 0 }; + size_t srtpMasterLength{ 0 }; + + switch (srtpCryptoSuite) + { + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: + { + srtpKeyLength = SrtpMasterKeyLength; + srtpSaltLength = SrtpMasterSaltLength; + srtpMasterLength = SrtpMasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: + { + srtpKeyLength = SrtpAesGcm256MasterKeyLength; + srtpSaltLength = SrtpAesGcm256MasterSaltLength; + srtpMasterLength = SrtpAesGcm256MasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: + { + srtpKeyLength = SrtpAesGcm128MasterKeyLength; + srtpSaltLength = SrtpAesGcm128MasterSaltLength; + srtpMasterLength = SrtpAesGcm128MasterLength; + + break; + } + + default: + { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; + uint8_t* srtpLocalKey{ nullptr }; + uint8_t* srtpLocalSalt{ nullptr }; + uint8_t* srtpRemoteKey{ nullptr }; + uint8_t* srtpRemoteSalt{ nullptr }; + auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; + auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; + int ret; + + ret = SSL_export_keying_material( + this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); + + MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); + + switch (this->localRole) + { + case Role::SERVER: + { + srtpRemoteKey = srtpMaterial; + srtpLocalKey = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; + + break; + } + + case Role::CLIENT: + { + srtpLocalKey = srtpMaterial; + srtpRemoteKey = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; + + break; + } + + default: + { + MS_ABORT("no DTLS role set"); + } + } + + // Create the SRTP local master key. + std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); + std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); + // Create the SRTP remote master key. + std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); + std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); + + // Set state and notify the listener. + this->state = DtlsState::CONNECTED; + this->listener->OnDtlsTransportConnected( + this, + srtpCryptoSuite, + srtpLocalMasterKey, + srtpMasterLength, + srtpRemoteMasterKey, + srtpMasterLength, + this->remoteCert); + + delete[] srtpMaterial; + delete[] srtpLocalMasterKey; + delete[] srtpRemoteMasterKey; + } + + inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() + { + MS_TRACE(); + + RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; + + // Ensure that the SRTP crypto suite has been negotiated. + // NOTE: This is a OpenSSL type. + SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); + + if (!sslSrtpCryptoSuite) + return negotiatedSrtpCryptoSuite; + + // Get the negotiated SRTP crypto suite. + for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) + { + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); + + if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) + { + MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); + + negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; + } + } + + MS_ASSERT( + negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, + "chosen SRTP crypto suite is not an available one"); + + return negotiatedSrtpCryptoSuite; + } + + inline void DtlsTransport::OnSslInfo(int where, int ret) + { + MS_TRACE(); + + int w = where & -SSL_ST_MASK; + const char* role; + + if ((w & SSL_ST_CONNECT) != 0) + role = "client"; + else if ((w & SSL_ST_ACCEPT) != 0) + role = "server"; + else + role = "undefined"; + + if ((where & SSL_CB_LOOP) != 0) + { + MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_ALERT) != 0) + { + const char* alertType; + + switch (*SSL_alert_type_string(ret)) + { + case 'W': + alertType = "warning"; + break; + + case 'F': + alertType = "fatal"; + break; + + default: + alertType = "undefined"; + } + + if ((where & SSL_CB_READ) != 0) + { + MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else if ((where & SSL_CB_WRITE) != 0) + { + MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else + { + MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + } + else if ((where & SSL_CB_EXIT) != 0) + { + if (ret == 0) + MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); + else if (ret < 0) + MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_HANDSHAKE_START) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake start"); + } + else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake done"); + + this->handshakeDoneNow = true; + } + + // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon + // receipt of a close alert does not work (the flag is set after this callback). + } + + inline void DtlsTransport::OnTimer() + { + MS_TRACE(); + + // Workaround for https://github.com/openssl/openssl/issues/7998. + if (this->handshakeDone) + { + MS_DEBUG_DEV("handshake is done so return"); + + return; + } + + DTLSv1_handle_timeout(this->ssl); + + // If required, send DTLS data. + SendPendingOutgoingDtlsData(); + + // Set the DTLS timer again. + SetTimeout(); + } } // namespace RTC diff --git a/webrtc/DtlsTransport.hpp b/webrtc/DtlsTransport.hpp index fb28a6a4..bf57d01d 100644 --- a/webrtc/DtlsTransport.hpp +++ b/webrtc/DtlsTransport.hpp @@ -33,50 +33,50 @@ using namespace toolkit; namespace RTC { class DtlsTransport : public std::enable_shared_from_this - { - public: - enum class DtlsState - { - NEW = 1, - CONNECTING, - CONNECTED, - FAILED, - CLOSED - }; + { + public: + enum class DtlsState + { + NEW = 1, + CONNECTING, + CONNECTED, + FAILED, + CLOSED + }; - public: - enum class Role - { - NONE = 0, - AUTO = 1, - CLIENT, - SERVER - }; + public: + enum class Role + { + NONE = 0, + AUTO = 1, + CLIENT, + SERVER + }; - public: - enum class FingerprintAlgorithm - { - NONE = 0, - SHA1 = 1, - SHA224, - SHA256, - SHA384, - SHA512 - }; + public: + enum class FingerprintAlgorithm + { + NONE = 0, + SHA1 = 1, + SHA224, + SHA256, + SHA384, + SHA512 + }; - public: - struct Fingerprint - { - FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; - std::string value; - }; + public: + struct Fingerprint + { + FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; + std::string value; + }; - private: - struct SrtpCryptoSuiteMapEntry - { - RTC::SrtpSession::CryptoSuite cryptoSuite; - const char* name; - }; + private: + struct SrtpCryptoSuiteMapEntry + { + RTC::SrtpSession::CryptoSuite cryptoSuite; + const char* name; + }; class DtlsEnvironment : public std::enable_shared_from_this { @@ -99,154 +99,154 @@ namespace RTC std::vector localFingerprints; }; - public: - class Listener - { - public: - // DTLS is in the process of negotiating a secure connection. Incoming - // media can flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; - // DTLS has completed negotiation of a secure connection (including DTLS-SRTP - // and remote fingerprint verification). Outgoing media can now flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnected( - const RTC::DtlsTransport* dtlsTransport, - RTC::SrtpSession::CryptoSuite srtpCryptoSuite, - uint8_t* srtpLocalKey, - size_t srtpLocalKeyLen, - uint8_t* srtpRemoteKey, - size_t srtpRemoteKeyLen, - std::string& remoteCert) = 0; - // The DTLS connection has been closed as the result of an error (such as a - // DTLS alert or a failure to validate the remote fingerprint). - virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; - // The DTLS connection has been closed due to receipt of a close_notify alert. - virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; - // Need to send DTLS data to the peer. - virtual void OnDtlsTransportSendData( - const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; - // DTLS application data received. - virtual void OnDtlsTransportApplicationDataReceived( - const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; - }; + public: + class Listener + { + public: + // DTLS is in the process of negotiating a secure connection. Incoming + // media can flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; + // DTLS has completed negotiation of a secure connection (including DTLS-SRTP + // and remote fingerprint verification). Outgoing media can now flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnected( + const RTC::DtlsTransport* dtlsTransport, + RTC::SrtpSession::CryptoSuite srtpCryptoSuite, + uint8_t* srtpLocalKey, + size_t srtpLocalKeyLen, + uint8_t* srtpRemoteKey, + size_t srtpRemoteKeyLen, + std::string& remoteCert) = 0; + // The DTLS connection has been closed as the result of an error (such as a + // DTLS alert or a failure to validate the remote fingerprint). + virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; + // The DTLS connection has been closed due to receipt of a close_notify alert. + virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; + // Need to send DTLS data to the peer. + virtual void OnDtlsTransportSendData( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + // DTLS application data received. + virtual void OnDtlsTransportApplicationDataReceived( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + }; - public: - static Role StringToRole(const std::string& role) - { - auto it = DtlsTransport::string2Role.find(role); + public: + static Role StringToRole(const std::string& role) + { + auto it = DtlsTransport::string2Role.find(role); - if (it != DtlsTransport::string2Role.end()) - return it->second; - else - return DtlsTransport::Role::NONE; - } - static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) - { - auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); + if (it != DtlsTransport::string2Role.end()) + return it->second; + else + return DtlsTransport::Role::NONE; + } + static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) + { + auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); - if (it != DtlsTransport::string2FingerprintAlgorithm.end()) - return it->second; - else - return DtlsTransport::FingerprintAlgorithm::NONE; - } - static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) - { - auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); + if (it != DtlsTransport::string2FingerprintAlgorithm.end()) + return it->second; + else + return DtlsTransport::FingerprintAlgorithm::NONE; + } + static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) + { + auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); - return it->second; - } - static bool IsDtls(const uint8_t* data, size_t len) - { - // clang-format off - return ( - // Minimum DTLS record length is 13 bytes. - (len >= 13) && - // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes - (data[0] > 19 && data[0] < 64) - ); - // clang-format on - } - - private: - static std::map string2Role; - static std::map string2FingerprintAlgorithm; - static std::map fingerprintAlgorithm2String; - static std::vector srtpCryptoSuites; - - public: - DtlsTransport(EventPoller::Ptr poller, Listener* listener); - ~DtlsTransport(); - - public: - void Dump() const; - void Run(Role localRole); - std::vector& GetLocalFingerprints() const - { - return env->localFingerprints; - } - bool SetRemoteFingerprint(Fingerprint fingerprint); - void ProcessDtlsData(const uint8_t* data, size_t len); - DtlsState GetState() const - { - return this->state; - } - Role GetLocalRole() const - { - return this->localRole; - } - void SendApplicationData(const uint8_t* data, size_t len); - - private: - bool IsRunning() const - { - switch (this->state) - { - case DtlsState::NEW: - return false; - case DtlsState::CONNECTING: - case DtlsState::CONNECTED: - return true; - case DtlsState::FAILED: - case DtlsState::CLOSED: - return false; - } - - // Make GCC 4.9 happy. - return false; - } - void Reset(); - bool CheckStatus(int returnCode); - void SendPendingOutgoingDtlsData(); - bool SetTimeout(); - bool ProcessHandshake(); - bool CheckRemoteFingerprint(); - void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); - RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + return it->second; + } + static bool IsDtls(const uint8_t* data, size_t len) + { + // clang-format off + return ( + // Minimum DTLS record length is 13 bytes. + (len >= 13) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] > 19 && data[0] < 64) + ); + // clang-format on + } private: - void OnSslInfo(int where, int ret); - void OnTimer(); + static std::map string2Role; + static std::map string2FingerprintAlgorithm; + static std::map fingerprintAlgorithm2String; + static std::vector srtpCryptoSuites; - private: + public: + DtlsTransport(EventPoller::Ptr poller, Listener* listener); + ~DtlsTransport(); + + public: + void Dump() const; + void Run(Role localRole); + std::vector& GetLocalFingerprints() const + { + return env->localFingerprints; + } + bool SetRemoteFingerprint(Fingerprint fingerprint); + void ProcessDtlsData(const uint8_t* data, size_t len); + DtlsState GetState() const + { + return this->state; + } + Role GetLocalRole() const + { + return this->localRole; + } + void SendApplicationData(const uint8_t* data, size_t len); + + private: + bool IsRunning() const + { + switch (this->state) + { + case DtlsState::NEW: + return false; + case DtlsState::CONNECTING: + case DtlsState::CONNECTED: + return true; + case DtlsState::FAILED: + case DtlsState::CLOSED: + return false; + } + + // Make GCC 4.9 happy. + return false; + } + void Reset(); + bool CheckStatus(int returnCode); + void SendPendingOutgoingDtlsData(); + bool SetTimeout(); + bool ProcessHandshake(); + bool CheckRemoteFingerprint(); + void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); + RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + + private: + void OnSslInfo(int where, int ret); + void OnTimer(); + + private: DtlsEnvironment::Ptr env; EventPoller::Ptr poller; // Passed by argument. - Listener* listener{ nullptr }; - // Allocated by this. - SSL* ssl{ nullptr }; - BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. - BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. - Timer::Ptr timer; - // Others. - DtlsState state{ DtlsState::NEW }; - Role localRole{ Role::NONE }; - Fingerprint remoteFingerprint; - bool handshakeDone{ false }; - bool handshakeDoneNow{ false }; - std::string remoteCert; - //最大不超过mtu - static constexpr int SslReadBufferSize{ 2000 }; + Listener* listener{ nullptr }; + // Allocated by this. + SSL* ssl{ nullptr }; + BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. + BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. + Timer::Ptr timer; + // Others. + DtlsState state{ DtlsState::NEW }; + Role localRole{ Role::NONE }; + Fingerprint remoteFingerprint; + bool handshakeDone{ false }; + bool handshakeDoneNow{ false }; + std::string remoteCert; + //最大不超过mtu + static constexpr int SslReadBufferSize{ 2000 }; uint8_t sslReadBuffer[SslReadBufferSize]; }; } // namespace RTC diff --git a/webrtc/IceServer.cpp b/webrtc/IceServer.cpp index 7dfac0b7..f0f79358 100644 --- a/webrtc/IceServer.cpp +++ b/webrtc/IceServer.cpp @@ -24,505 +24,505 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. namespace RTC { - /* Static. */ - /* Instance methods. */ + /* Static. */ + /* Instance methods. */ - IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) - : listener(listener), usernameFragment(usernameFragment), password(password) - { - MS_TRACE(); - } + IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) + : listener(listener), usernameFragment(usernameFragment), password(password) + { + MS_TRACE(); + } - void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) - { - MS_TRACE(); + void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) + { + MS_TRACE(); - // Must be a Binding method. - if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) - { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) - { - MS_WARN_TAG( - ice, - "unknown method %#.3x in STUN Request => 400", - static_cast(packet->GetMethod())); + // Must be a Binding method. + if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) + { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) + { + MS_WARN_TAG( + ice, + "unknown method %#.3x in STUN Request => 400", + static_cast(packet->GetMethod())); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; - } - else - { - MS_WARN_TAG( - ice, - "ignoring STUN Indication or Response with unknown method %#.3x", - static_cast(packet->GetMethod())); - } + delete response; + } + else + { + MS_WARN_TAG( + ice, + "ignoring STUN Indication or Response with unknown method %#.3x", + static_cast(packet->GetMethod())); + } - return; - } + return; + } - // Must use FINGERPRINT (optional for ICE STUN indications). - if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) - { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) - { - MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); + // Must use FINGERPRINT (optional for ICE STUN indications). + if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) + { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) + { + MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; - } - else - { - MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); - } + delete response; + } + else + { + MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); + } - return; - } + return; + } - switch (packet->GetClass()) - { - case RTC::StunPacket::Class::REQUEST: - { - // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. - if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) - { - MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); + switch (packet->GetClass()) + { + case RTC::StunPacket::Class::REQUEST: + { + // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. + if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) + { + MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } + return; + } - // Check authentication. - switch (packet->CheckAuthentication(this->usernameFragment, this->password)) - { - case RTC::StunPacket::Authentication::OK: - { - if (!this->oldPassword.empty()) - { - MS_DEBUG_TAG(ice, "new ICE credentials applied"); + // Check authentication. + switch (packet->CheckAuthentication(this->usernameFragment, this->password)) + { + case RTC::StunPacket::Authentication::OK: + { + if (!this->oldPassword.empty()) + { + MS_DEBUG_TAG(ice, "new ICE credentials applied"); - this->oldUsernameFragment.clear(); - this->oldPassword.clear(); - } + this->oldUsernameFragment.clear(); + this->oldPassword.clear(); + } - break; - } + break; + } - case RTC::StunPacket::Authentication::UNAUTHORIZED: - { - // We may have changed our usernameFragment and password, so check - // the old ones. - // clang-format off - if ( - !this->oldUsernameFragment.empty() && - !this->oldPassword.empty() && - packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK - ) - // clang-format on - { - MS_DEBUG_TAG(ice, "using old ICE credentials"); + case RTC::StunPacket::Authentication::UNAUTHORIZED: + { + // We may have changed our usernameFragment and password, so check + // the old ones. + // clang-format off + if ( + !this->oldUsernameFragment.empty() && + !this->oldPassword.empty() && + packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK + ) + // clang-format on + { + MS_DEBUG_TAG(ice, "using old ICE credentials"); - break; - } + break; + } - MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); + MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); - // Reply 401. - RTC::StunPacket* response = packet->CreateErrorResponse(401); + // Reply 401. + RTC::StunPacket* response = packet->CreateErrorResponse(401); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } + return; + } - case RTC::StunPacket::Authentication::BAD_REQUEST: - { - MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); + case RTC::StunPacket::Authentication::BAD_REQUEST: + { + MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } - } + return; + } + } #if 0 - // The remote peer must be ICE controlling. - if (packet->GetIceControlled()) - { - MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); + // The remote peer must be ICE controlling. + if (packet->GetIceControlled()) + { + MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); - // Reply 487 (Role Conflict). - RTC::StunPacket* response = packet->CreateErrorResponse(487); + // Reply 487 (Role Conflict). + RTC::StunPacket* response = packet->CreateErrorResponse(487); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } + return; + } #endif - //MS_DEBUG_DEV( - // "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", - // static_cast(packet->GetPriority()), - // packet->HasUseCandidate() ? "true" : "false"); + //MS_DEBUG_DEV( + // "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", + // static_cast(packet->GetPriority()), + // packet->HasUseCandidate() ? "true" : "false"); - // Create a success response. - RTC::StunPacket* response = packet->CreateSuccessResponse(); + // Create a success response. + RTC::StunPacket* response = packet->CreateSuccessResponse(); - sockaddr_storage peerAddr; - socklen_t addr_len = sizeof(peerAddr); - getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len); - - // Add XOR-MAPPED-ADDRESS. - response->SetXorMappedAddress((struct sockaddr *)&peerAddr); + sockaddr_storage peerAddr; + socklen_t addr_len = sizeof(peerAddr); + getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len); + + // Add XOR-MAPPED-ADDRESS. + response->SetXorMappedAddress((struct sockaddr *)&peerAddr); - // Authenticate the response. - if (this->oldPassword.empty()) - response->Authenticate(this->password); - else - response->Authenticate(this->oldPassword); + // Authenticate the response. + if (this->oldPassword.empty()) + response->Authenticate(this->password); + else + response->Authenticate(this->oldPassword); - // Send back. - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + // Send back. + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - // Handle the tuple. - HandleTuple(tuple, packet->HasUseCandidate()); + // Handle the tuple. + HandleTuple(tuple, packet->HasUseCandidate()); - break; - } + break; + } - case RTC::StunPacket::Class::INDICATION: - { - MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); + case RTC::StunPacket::Class::INDICATION: + { + MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); - break; - } - - case RTC::StunPacket::Class::SUCCESS_RESPONSE: - { - MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); - - break; - } + break; + } + + case RTC::StunPacket::Class::SUCCESS_RESPONSE: + { + MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); + + break; + } - case RTC::StunPacket::Class::ERROR_RESPONSE: - { - MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); + case RTC::StunPacket::Class::ERROR_RESPONSE: + { + MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); - break; - } - } - } - - bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const - { - MS_TRACE(); - - return HasTuple(tuple) != nullptr; - } - - void IceServer::RemoveTuple(RTC::TransportTuple* tuple) - { - MS_TRACE(); - - RTC::TransportTuple* removedTuple{ nullptr }; - - // Find the removed tuple. - auto it = this->tuples.begin(); - - for (; it != this->tuples.end(); ++it) - { - RTC::TransportTuple* storedTuple = *it; - - if (storedTuple == tuple) - { - removedTuple = storedTuple; - - break; - } - } - - // If not found, ignore. - if (!removedTuple) - return; - - // Remove from the list of tuples. - this->tuples.erase(it); - - // If this is not the selected tuple, stop here. - if (removedTuple != this->selectedTuple) - return; - - // Otherwise this was the selected tuple. - this->selectedTuple = nullptr; - - // Mark the first tuple as selected tuple (if any). - if (!this->tuples.empty()) - { - SetSelectedTuple(this->tuples.front()); - } - // Or just emit 'disconnected'. - else - { - // Update state. - this->state = IceState::DISCONNECTED; - // Notify the listener. - this->listener->OnIceServerDisconnected(this); - } - } - - void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) - { - MS_TRACE(); - - MS_ASSERT( - this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); - - auto* storedTuple = HasTuple(tuple); - - MS_ASSERT( - storedTuple, - "cannot force the selected tuple if the given tuple was not already a valid tuple"); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - } - - void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) - { - MS_TRACE(); - - switch (this->state) - { - case IceState::NEW: - { - // There should be no tuples. - MS_ASSERT( - this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); - - // There shouldn't be a selected tuple. - MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); - - if (!hasUseCandidate) - { - MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::CONNECTED; - // Notify the listener. - this->listener->OnIceServerConnected(this); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::DISCONNECTED: - { - // There should be no tuples. - MS_ASSERT( - this->tuples.empty(), - "state is 'disconnected' but there are %zu tuples", - this->tuples.size()); - - // There shouldn't be a selected tuple. - MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); - - if (!hasUseCandidate) - { - MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::CONNECTED; - // Notify the listener. - this->listener->OnIceServerConnected(this); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::CONNECTED: - { - // There should be some tuples. - MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); - - // There should be a selected tuple. - MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); - - if (!hasUseCandidate) - { - // If a new tuple store it. - if (!HasTuple(tuple)) - AddTuple(tuple); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); - - auto* storedTuple = HasTuple(tuple); - - // If a new tuple store it. - if (!storedTuple) - storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::COMPLETED: - { - // There should be some tuples. - MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); - - // There should be a selected tuple. - MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); - - if (!hasUseCandidate) - { - // If a new tuple store it. - if (!HasTuple(tuple)) - AddTuple(tuple); - } - else - { - auto* storedTuple = HasTuple(tuple); - - // If a new tuple store it. - if (!storedTuple) - storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - } - - break; - } - } - } - - inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) - { - MS_TRACE(); - - // Add the new tuple at the beginning of the list. - this->tuples.push_front(tuple); - - // Return the address of the inserted tuple. - return tuple; - } - - inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const - { - MS_TRACE(); - - // If there is no selected tuple yet then we know that the tuples list - // is empty. - if (!this->selectedTuple) - return nullptr; - - // Check the current selected tuple. - if (selectedTuple == tuple) - return this->selectedTuple; - - // Otherwise check other stored tuples. - for (const auto& it : this->tuples) - { - auto& storedTuple = it; - if (storedTuple == tuple) - return storedTuple; - } - - return nullptr; - } - - inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) - { - MS_TRACE(); - - // If already the selected tuple do nothing. - if (storedTuple == this->selectedTuple) - return; - - this->selectedTuple = storedTuple; + break; + } + } + } + + bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const + { + MS_TRACE(); + + return HasTuple(tuple) != nullptr; + } + + void IceServer::RemoveTuple(RTC::TransportTuple* tuple) + { + MS_TRACE(); + + RTC::TransportTuple* removedTuple{ nullptr }; + + // Find the removed tuple. + auto it = this->tuples.begin(); + + for (; it != this->tuples.end(); ++it) + { + RTC::TransportTuple* storedTuple = *it; + + if (storedTuple == tuple) + { + removedTuple = storedTuple; + + break; + } + } + + // If not found, ignore. + if (!removedTuple) + return; + + // Remove from the list of tuples. + this->tuples.erase(it); + + // If this is not the selected tuple, stop here. + if (removedTuple != this->selectedTuple) + return; + + // Otherwise this was the selected tuple. + this->selectedTuple = nullptr; + + // Mark the first tuple as selected tuple (if any). + if (!this->tuples.empty()) + { + SetSelectedTuple(this->tuples.front()); + } + // Or just emit 'disconnected'. + else + { + // Update state. + this->state = IceState::DISCONNECTED; + // Notify the listener. + this->listener->OnIceServerDisconnected(this); + } + } + + void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) + { + MS_TRACE(); + + MS_ASSERT( + this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); + + auto* storedTuple = HasTuple(tuple); + + MS_ASSERT( + storedTuple, + "cannot force the selected tuple if the given tuple was not already a valid tuple"); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + } + + void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) + { + MS_TRACE(); + + switch (this->state) + { + case IceState::NEW: + { + // There should be no tuples. + MS_ASSERT( + this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); + + // There shouldn't be a selected tuple. + MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); + + if (!hasUseCandidate) + { + MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::CONNECTED; + // Notify the listener. + this->listener->OnIceServerConnected(this); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::DISCONNECTED: + { + // There should be no tuples. + MS_ASSERT( + this->tuples.empty(), + "state is 'disconnected' but there are %zu tuples", + this->tuples.size()); + + // There shouldn't be a selected tuple. + MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); + + if (!hasUseCandidate) + { + MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::CONNECTED; + // Notify the listener. + this->listener->OnIceServerConnected(this); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::CONNECTED: + { + // There should be some tuples. + MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); + + // There should be a selected tuple. + MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); + + if (!hasUseCandidate) + { + // If a new tuple store it. + if (!HasTuple(tuple)) + AddTuple(tuple); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); + + auto* storedTuple = HasTuple(tuple); + + // If a new tuple store it. + if (!storedTuple) + storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::COMPLETED: + { + // There should be some tuples. + MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); + + // There should be a selected tuple. + MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); + + if (!hasUseCandidate) + { + // If a new tuple store it. + if (!HasTuple(tuple)) + AddTuple(tuple); + } + else + { + auto* storedTuple = HasTuple(tuple); + + // If a new tuple store it. + if (!storedTuple) + storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + } + + break; + } + } + } + + inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) + { + MS_TRACE(); + + // Add the new tuple at the beginning of the list. + this->tuples.push_front(tuple); + + // Return the address of the inserted tuple. + return tuple; + } + + inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const + { + MS_TRACE(); + + // If there is no selected tuple yet then we know that the tuples list + // is empty. + if (!this->selectedTuple) + return nullptr; + + // Check the current selected tuple. + if (selectedTuple == tuple) + return this->selectedTuple; + + // Otherwise check other stored tuples. + for (const auto& it : this->tuples) + { + auto& storedTuple = it; + if (storedTuple == tuple) + return storedTuple; + } + + return nullptr; + } + + inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) + { + MS_TRACE(); + + // If already the selected tuple do nothing. + if (storedTuple == this->selectedTuple) + return; + + this->selectedTuple = storedTuple; this->lastSelectedTuple = storedTuple->shared_from_this(); - // Notify the listener. - this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); - } + // Notify the listener. + this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); + } } // namespace RTC diff --git a/webrtc/IceServer.hpp b/webrtc/IceServer.hpp index 587fb1e9..92a0f31a 100644 --- a/webrtc/IceServer.hpp +++ b/webrtc/IceServer.hpp @@ -30,109 +30,109 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. namespace RTC { - using TransportTuple = toolkit::Session; - class IceServer - { - public: - enum class IceState - { - NEW = 1, - CONNECTED, - COMPLETED, - DISCONNECTED - }; + using TransportTuple = toolkit::Session; + class IceServer + { + public: + enum class IceState + { + NEW = 1, + CONNECTED, + COMPLETED, + DISCONNECTED + }; - public: - class Listener - { - public: - virtual ~Listener() = default; + public: + class Listener + { + public: + virtual ~Listener() = default; - public: - /** - * These callbacks are guaranteed to be called before ProcessStunPacket() - * returns, so the given pointers are still usable. - */ - virtual void OnIceServerSendStunPacket( - const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; - virtual void OnIceServerSelectedTuple( - const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; - virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; - virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; - virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; - }; + public: + /** + * These callbacks are guaranteed to be called before ProcessStunPacket() + * returns, so the given pointers are still usable. + */ + virtual void OnIceServerSendStunPacket( + const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; + virtual void OnIceServerSelectedTuple( + const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; + virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; + virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; + virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; + }; - public: - IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); + public: + IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); - public: - void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); - const std::string& GetUsernameFragment() const - { - return this->usernameFragment; - } - const std::string& GetPassword() const - { - return this->password; - } - IceState GetState() const - { - return this->state; - } - RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const - { + public: + void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); + const std::string& GetUsernameFragment() const + { + return this->usernameFragment; + } + const std::string& GetPassword() const + { + return this->password; + } + IceState GetState() const + { + return this->state; + } + RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const + { return try_last_tuple ? this->lastSelectedTuple.lock().get() : this->selectedTuple; } - void SetUsernameFragment(const std::string& usernameFragment) - { - this->oldUsernameFragment = this->usernameFragment; - this->usernameFragment = usernameFragment; - } - void SetPassword(const std::string& password) - { - this->oldPassword = this->password; - this->password = password; - } - bool IsValidTuple(const RTC::TransportTuple* tuple) const; - void RemoveTuple(RTC::TransportTuple* tuple); - // This should be just called in 'connected' or completed' state - // and the given tuple must be an already valid tuple. - void ForceSelectedTuple(const RTC::TransportTuple* tuple); + void SetUsernameFragment(const std::string& usernameFragment) + { + this->oldUsernameFragment = this->usernameFragment; + this->usernameFragment = usernameFragment; + } + void SetPassword(const std::string& password) + { + this->oldPassword = this->password; + this->password = password; + } + bool IsValidTuple(const RTC::TransportTuple* tuple) const; + void RemoveTuple(RTC::TransportTuple* tuple); + // This should be just called in 'connected' or completed' state + // and the given tuple must be an already valid tuple. + void ForceSelectedTuple(const RTC::TransportTuple* tuple); const std::list& GetTuples() const { return tuples; } private: - void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); - /** - * Store the given tuple and return its stored address. - */ - RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); - /** - * If the given tuple exists return its stored address, nullptr otherwise. - */ - RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; - /** - * Set the given tuple as the selected tuple. - * NOTE: The given tuple MUST be already stored within the list. - */ - void SetSelectedTuple(RTC::TransportTuple* storedTuple); + void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); + /** + * Store the given tuple and return its stored address. + */ + RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); + /** + * If the given tuple exists return its stored address, nullptr otherwise. + */ + RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; + /** + * Set the given tuple as the selected tuple. + * NOTE: The given tuple MUST be already stored within the list. + */ + void SetSelectedTuple(RTC::TransportTuple* storedTuple); - private: - // Passed by argument. - Listener* listener{ nullptr }; - // Others. - std::string usernameFragment; - std::string password; - std::string oldUsernameFragment; - std::string oldPassword; - IceState state{ IceState::NEW }; - std::list tuples; + private: + // Passed by argument. + Listener* listener{ nullptr }; + // Others. + std::string usernameFragment; + std::string password; + std::string oldUsernameFragment; + std::string oldPassword; + IceState state{ IceState::NEW }; + std::list tuples; RTC::TransportTuple *selectedTuple; std::weak_ptr lastSelectedTuple; - //最大不超过mtu + //最大不超过mtu static constexpr size_t StunSerializeBufferSize{ 1600 }; uint8_t StunSerializeBuffer[StunSerializeBufferSize]; - }; + }; } // namespace RTC #endif diff --git a/webrtc/SctpAssociation.cpp b/webrtc/SctpAssociation.cpp index 0aec443b..84a2c04f 100644 --- a/webrtc/SctpAssociation.cpp +++ b/webrtc/SctpAssociation.cpp @@ -23,14 +23,14 @@ static constexpr uint16_t MaxSctpStreams{ 65535 }; /* clang-format off */ static constexpr uint16_t EventTypes[] = { - SCTP_ADAPTATION_INDICATION, - SCTP_ASSOC_CHANGE, - SCTP_ASSOC_RESET_EVENT, - SCTP_REMOTE_ERROR, - SCTP_SHUTDOWN_EVENT, - SCTP_SEND_FAILED_EVENT, - SCTP_STREAM_RESET_EVENT, - SCTP_STREAM_CHANGE_EVENT + SCTP_ADAPTATION_INDICATION, + SCTP_ASSOC_CHANGE, + SCTP_ASSOC_RESET_EVENT, + SCTP_REMOTE_ERROR, + SCTP_SHUTDOWN_EVENT, + SCTP_SEND_FAILED_EVENT, + SCTP_STREAM_RESET_EVENT, + SCTP_STREAM_CHANGE_EVENT }; /* clang-format on */ @@ -44,45 +44,45 @@ inline static int onRecvSctpData( int flags, void* ulpInfo) { - auto* sctpAssociation = static_cast(ulpInfo); + auto* sctpAssociation = static_cast(ulpInfo); - if (sctpAssociation == nullptr) - { - std::free(data); + if (sctpAssociation == nullptr) + { + std::free(data); - return 0; - } + return 0; + } - if (flags & MSG_NOTIFICATION) - { - sctpAssociation->OnUsrSctpReceiveSctpNotification( - static_cast(data), len); - } - else - { - uint16_t streamId = rcv.rcv_sid; - uint32_t ppid = ntohl(rcv.rcv_ppid); - uint16_t ssn = rcv.rcv_ssn; + if (flags & MSG_NOTIFICATION) + { + sctpAssociation->OnUsrSctpReceiveSctpNotification( + static_cast(data), len); + } + else + { + uint16_t streamId = rcv.rcv_sid; + uint32_t ppid = ntohl(rcv.rcv_ppid); + uint16_t ssn = rcv.rcv_ssn; - MS_DEBUG_TAG( - sctp, - "data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32 - ", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]", - len, - rcv.rcv_sid, - rcv.rcv_ssn, - rcv.rcv_tsn, - ntohl(rcv.rcv_ppid), - rcv.rcv_context, - flags); + MS_DEBUG_TAG( + sctp, + "data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32 + ", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]", + len, + rcv.rcv_sid, + rcv.rcv_ssn, + rcv.rcv_tsn, + ntohl(rcv.rcv_ppid), + rcv.rcv_context, + flags); - sctpAssociation->OnUsrSctpReceiveSctpData( - streamId, ssn, ppid, flags, static_cast(data), len); - } + sctpAssociation->OnUsrSctpReceiveSctpData( + streamId, ssn, ppid, flags, static_cast(data), len); + } - std::free(data); + std::free(data); - return 1; + return 1; } /* Static methods for usrsctp global callbacks. */ @@ -136,824 +136,824 @@ namespace RTC //////////////////////////////////////////////////////////////////////////////////// - /* Instance methods. */ + /* Instance methods. */ - SctpAssociation::SctpAssociation( - Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel) - : listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize), - isDataChannel(isDataChannel) - { - MS_TRACE(); + SctpAssociation::SctpAssociation( + Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel) + : listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize), + isDataChannel(isDataChannel) + { + MS_TRACE(); _env = SctpEnv::Instance().shared_from_this(); - // Register ourselves in usrsctp. - usrsctp_register_address(static_cast(this)); + // Register ourselves in usrsctp. + usrsctp_register_address(static_cast(this)); - int ret; + int ret; - this->socket = usrsctp_socket( - AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast(this)); + this->socket = usrsctp_socket( + AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast(this)); - if (this->socket == nullptr) - MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno)); + if (this->socket == nullptr) + MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno)); - usrsctp_set_ulpinfo(this->socket, static_cast(this)); + usrsctp_set_ulpinfo(this->socket, static_cast(this)); - // Make the socket non-blocking. - ret = usrsctp_set_non_blocking(this->socket, 1); + // Make the socket non-blocking. + ret = usrsctp_set_non_blocking(this->socket, 1); - if (ret < 0) - MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno)); - // Set SO_LINGER. - // This ensures that the usrsctp close call deletes the association. This - // prevents usrsctp from calling the global send callback with references to - // this class as the address. - struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Set SO_LINGER. + // This ensures that the usrsctp close call deletes the association. This + // prevents usrsctp from calling the global send callback with references to + // this class as the address. + struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init) - lingerOpt.l_onoff = 1; - lingerOpt.l_linger = 0; + lingerOpt.l_onoff = 1; + lingerOpt.l_linger = 0; - ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt)); + ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno)); - // Set SCTP_ENABLE_STREAM_RESET. - struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Set SCTP_ENABLE_STREAM_RESET. + struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) - av.assoc_value = - SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ; + av.assoc_value = + SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)); - if (ret < 0) - { - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno)); - } + if (ret < 0) + { + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno)); + } - // Set SCTP_NODELAY. - uint32_t noDelay = 1; + // Set SCTP_NODELAY. + uint32_t noDelay = 1; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno)); - // Enable events. - struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Enable events. + struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&event, 0, sizeof(event)); - event.se_on = 1; + std::memset(&event, 0, sizeof(event)); + event.se_on = 1; - for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i) - { - event.se_type = EventTypes[i]; + for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i) + { + event.se_type = EventTypes[i]; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno)); - } + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno)); + } - // Init message. - struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Init message. + struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&initmsg, 0, sizeof(initmsg)); - initmsg.sinit_num_ostreams = this->os; - initmsg.sinit_max_instreams = this->mis; + std::memset(&initmsg, 0, sizeof(initmsg)); + initmsg.sinit_num_ostreams = this->os; + initmsg.sinit_max_instreams = this->mis; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno)); - // Server side. - struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Server side. + struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&sconn, 0, sizeof(sconn)); - sconn.sconn_family = AF_CONN; - sconn.sconn_port = htons(5000); - sconn.sconn_addr = static_cast(this); + std::memset(&sconn, 0, sizeof(sconn)); + sconn.sconn_family = AF_CONN; + sconn.sconn_port = htons(5000); + sconn.sconn_addr = static_cast(this); #ifdef HAVE_SCONN_LEN sconn.sconn_len = sizeof(sconn); #endif - ret = usrsctp_bind(this->socket, reinterpret_cast(&sconn), sizeof(sconn)); + ret = usrsctp_bind(this->socket, reinterpret_cast(&sconn), sizeof(sconn)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno)); - } + if (ret < 0) + MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno)); + } - SctpAssociation::~SctpAssociation() - { - MS_TRACE(); + SctpAssociation::~SctpAssociation() + { + MS_TRACE(); - usrsctp_set_ulpinfo(this->socket, nullptr); - usrsctp_close(this->socket); + usrsctp_set_ulpinfo(this->socket, nullptr); + usrsctp_close(this->socket); - // Deregister ourselves from usrsctp. - usrsctp_deregister_address(static_cast(this)); + // Deregister ourselves from usrsctp. + usrsctp_deregister_address(static_cast(this)); - delete[] this->messageBuffer; - } + delete[] this->messageBuffer; + } - void SctpAssociation::TransportConnected() - { - MS_TRACE(); + void SctpAssociation::TransportConnected() + { + MS_TRACE(); - // Just run the SCTP stack if our state is 'new'. - if (this->state != SctpState::NEW) - return; + // Just run the SCTP stack if our state is 'new'. + if (this->state != SctpState::NEW) + return; - try - { - int ret; - struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init) + try + { + int ret; + struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&rconn, 0, sizeof(rconn)); - rconn.sconn_family = AF_CONN; - rconn.sconn_port = htons(5000); - rconn.sconn_addr = static_cast(this); + std::memset(&rconn, 0, sizeof(rconn)); + rconn.sconn_family = AF_CONN; + rconn.sconn_port = htons(5000); + rconn.sconn_addr = static_cast(this); #ifdef HAVE_SCONN_LEN - rconn.sconn_len = sizeof(rconn); + rconn.sconn_len = sizeof(rconn); #endif - ret = usrsctp_connect(this->socket, reinterpret_cast(&rconn), sizeof(rconn)); + ret = usrsctp_connect(this->socket, reinterpret_cast(&rconn), sizeof(rconn)); - if (ret < 0 && errno != EINPROGRESS) - MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno)); + if (ret < 0 && errno != EINPROGRESS) + MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno)); - // Disable MTU discovery. - sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Disable MTU discovery. + sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&peerAddrParams, 0, sizeof(peerAddrParams)); - std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn)); - peerAddrParams.spp_flags = SPP_PMTUD_DISABLE; + std::memset(&peerAddrParams, 0, sizeof(peerAddrParams)); + std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn)); + peerAddrParams.spp_flags = SPP_PMTUD_DISABLE; - // The MTU value provided specifies the space available for chunks in the - // packet, so let's subtract the SCTP header size. - peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams); + // The MTU value provided specifies the space available for chunks in the + // packet, so let's subtract the SCTP header size. + peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams); - ret = usrsctp_setsockopt( - this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams)); + ret = usrsctp_setsockopt( + this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno)); - // Announce connecting state. - this->state = SctpState::CONNECTING; - this->listener->OnSctpAssociationConnecting(this); - } - catch (... /*error*/) - { - this->state = SctpState::FAILED; - this->listener->OnSctpAssociationFailed(this); + // Announce connecting state. + this->state = SctpState::CONNECTING; + this->listener->OnSctpAssociationConnecting(this); + } + catch (... /*error*/) + { + this->state = SctpState::FAILED; + this->listener->OnSctpAssociationFailed(this); throw; - } - } + } + } - void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) - { - MS_TRACE(); + void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) + { + MS_TRACE(); #if MS_LOG_DEV_LEVEL == 3 - MS_DUMP_DATA(data, len); + MS_DUMP_DATA(data, len); #endif - usrsctp_conninput(static_cast(this), data, len, 0); - } + usrsctp_conninput(static_cast(this), data, len, 0); + } - void SctpAssociation::SendSctpMessage( + void SctpAssociation::SendSctpMessage( const RTC::SctpStreamParameters ¶meters, uint32_t ppid, const uint8_t* msg, size_t len) - { - MS_TRACE(); + { + MS_TRACE(); - // This must be controlled by the DataConsumer. - MS_ASSERT( - len <= this->maxSctpMessageSize, - "given message exceeds max allowed message size [message size:%zu, max message size:%zu]", - len, - this->maxSctpMessageSize); + // This must be controlled by the DataConsumer. + MS_ASSERT( + len <= this->maxSctpMessageSize, + "given message exceeds max allowed message size [message size:%zu, max message size:%zu]", + len, + this->maxSctpMessageSize); - // Fill stcp_sendv_spa. - struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Fill stcp_sendv_spa. + struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&spa, 0, sizeof(spa)); - spa.sendv_flags = SCTP_SEND_SNDINFO_VALID; - spa.sendv_sndinfo.snd_sid = parameters.streamId; - spa.sendv_sndinfo.snd_ppid = htonl(ppid); - spa.sendv_sndinfo.snd_flags = SCTP_EOR; + std::memset(&spa, 0, sizeof(spa)); + spa.sendv_flags = SCTP_SEND_SNDINFO_VALID; + spa.sendv_sndinfo.snd_sid = parameters.streamId; + spa.sendv_sndinfo.snd_ppid = htonl(ppid); + spa.sendv_sndinfo.snd_flags = SCTP_EOR; - // If ordered it must be reliable. - if (parameters.ordered) - { - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; - spa.sendv_prinfo.pr_value = 0; - } - // Configure reliability: https://tools.ietf.org/html/rfc3758 - else - { - spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; - spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; + // If ordered it must be reliable. + if (parameters.ordered) + { + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; + spa.sendv_prinfo.pr_value = 0; + } + // Configure reliability: https://tools.ietf.org/html/rfc3758 + else + { + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; - if (parameters.maxPacketLifeTime != 0) - { - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; - spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime; - } - else if (parameters.maxRetransmits != 0) - { - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; - spa.sendv_prinfo.pr_value = parameters.maxRetransmits; - } - } + if (parameters.maxPacketLifeTime != 0) + { + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; + spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime; + } + else if (parameters.maxRetransmits != 0) + { + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; + spa.sendv_prinfo.pr_value = parameters.maxRetransmits; + } + } - int ret = usrsctp_sendv( - this->socket, msg, len, nullptr, 0, &spa, static_cast(sizeof(spa)), SCTP_SENDV_SPA, 0); + int ret = usrsctp_sendv( + this->socket, msg, len, nullptr, 0, &spa, static_cast(sizeof(spa)), SCTP_SENDV_SPA, 0); - if (ret < 0) - { - MS_WARN_TAG( - sctp, - "error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s", - parameters.streamId, - ppid, - len, - std::strerror(errno)); - } - } + if (ret < 0) + { + MS_WARN_TAG( + sctp, + "error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s", + parameters.streamId, + ppid, + len, + std::strerror(errno)); + } + } - void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters ¶ms) - { - MS_TRACE(); + void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters ¶ms) + { + MS_TRACE(); - auto streamId = params.streamId; + auto streamId = params.streamId; - // We need more OS. - if (streamId > this->os - 1) - AddOutgoingStreams(/*force*/ false); - } + // We need more OS. + if (streamId > this->os - 1) + AddOutgoingStreams(/*force*/ false); + } - void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters ¶ms) - { - MS_TRACE(); + void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters ¶ms) + { + MS_TRACE(); - auto streamId = params.streamId; + auto streamId = params.streamId; - // Send SCTP_RESET_STREAMS to the remote. - // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 - if (this->isDataChannel) - ResetSctpStream(streamId, StreamDirection::OUTGOING); - else - ResetSctpStream(streamId, StreamDirection::INCOMING); - } + // Send SCTP_RESET_STREAMS to the remote. + // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 + if (this->isDataChannel) + ResetSctpStream(streamId, StreamDirection::OUTGOING); + else + ResetSctpStream(streamId, StreamDirection::INCOMING); + } - void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters ¶ms) - { - MS_TRACE(); + void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters ¶ms) + { + MS_TRACE(); - auto streamId = params.streamId; + auto streamId = params.streamId; - // Send SCTP_RESET_STREAMS to the remote. - ResetSctpStream(streamId, StreamDirection::OUTGOING); - } + // Send SCTP_RESET_STREAMS to the remote. + ResetSctpStream(streamId, StreamDirection::OUTGOING); + } - void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction) - { - MS_TRACE(); + void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction) + { + MS_TRACE(); - // Do nothing if an outgoing stream that could not be allocated by us. - if (direction == StreamDirection::OUTGOING && streamId > this->os - 1) - return; + // Do nothing if an outgoing stream that could not be allocated by us. + if (direction == StreamDirection::OUTGOING && streamId > this->os - 1) + return; - int ret; - struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) - socklen_t len = sizeof(av); + int ret; + struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) + socklen_t len = sizeof(av); #ifndef SCTP_RECONFIG_SUPPORTED #define SCTP_RECONFIG_SUPPORTED 0x00000029 #endif - ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len); + ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len); - if (ret == 0) - { - if (av.assoc_value != 1) - { - MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated"); + if (ret == 0) + { + if (av.assoc_value != 1) + { + MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated"); - return; - } - } - else - { - MS_WARN_TAG( - sctp, - "could not retrieve whether stream reconfiguration has been negotiated: %s\n", - std::strerror(errno)); + return; + } + } + else + { + MS_WARN_TAG( + sctp, + "could not retrieve whether stream reconfiguration has been negotiated: %s\n", + std::strerror(errno)); - return; - } + return; + } - // As per spec: https://tools.ietf.org/html/rfc6525#section-4.1 - len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t); + // As per spec: https://tools.ietf.org/html/rfc6525#section-4.1 + len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t); - auto* srs = static_cast(std::malloc(len)); + auto* srs = static_cast(std::malloc(len)); - switch (direction) - { - case StreamDirection::INCOMING: - srs->srs_flags = SCTP_STREAM_RESET_INCOMING; - break; + switch (direction) + { + case StreamDirection::INCOMING: + srs->srs_flags = SCTP_STREAM_RESET_INCOMING; + break; - case StreamDirection::OUTGOING: - srs->srs_flags = SCTP_STREAM_RESET_OUTGOING; - break; - } + case StreamDirection::OUTGOING: + srs->srs_flags = SCTP_STREAM_RESET_OUTGOING; + break; + } - srs->srs_number_streams = 1; - srs->srs_stream_list[0] = streamId; // No need for htonl(). + srs->srs_number_streams = 1; + srs->srs_stream_list[0] = streamId; // No need for htonl(). - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len); - if (ret == 0) - { - MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId); - } - else - { - MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno)); - } + if (ret == 0) + { + MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId); + } + else + { + MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno)); + } - std::free(srs); - } + std::free(srs); + } - void SctpAssociation::AddOutgoingStreams(bool force) - { - MS_TRACE(); + void SctpAssociation::AddOutgoingStreams(bool force) + { + MS_TRACE(); - uint16_t additionalOs{ 0 }; + uint16_t additionalOs{ 0 }; - if (MaxSctpStreams - this->os >= 32) - additionalOs = 32; - else - additionalOs = MaxSctpStreams - this->os; + if (MaxSctpStreams - this->os >= 32) + additionalOs = 32; + else + additionalOs = MaxSctpStreams - this->os; - if (additionalOs == 0) - { - MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os); + if (additionalOs == 0) + { + MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os); - return; - } + return; + } - auto nextDesiredOs = this->os + additionalOs; + auto nextDesiredOs = this->os + additionalOs; - // Already in progress, ignore (unless forced). - if (!force && nextDesiredOs == this->desiredOs) - return; + // Already in progress, ignore (unless forced). + if (!force && nextDesiredOs == this->desiredOs) + return; - // Update desired value. - this->desiredOs = nextDesiredOs; + // Update desired value. + this->desiredOs = nextDesiredOs; - // If not connected, defer it. - if (this->state != SctpState::CONNECTED) - { - MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase"); + // If not connected, defer it. + if (this->state != SctpState::CONNECTED) + { + MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase"); - return; - } + return; + } - struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init) + struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&sas, 0, sizeof(sas)); - sas.sas_instrms = 0; - sas.sas_outstrms = additionalOs; + std::memset(&sas, 0, sizeof(sas)); + sas.sas_instrms = 0; + sas.sas_outstrms = additionalOs; - MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs); + MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs); - int ret = usrsctp_setsockopt( - this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast(sizeof(sas))); + int ret = usrsctp_setsockopt( + this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast(sizeof(sas))); - if (ret < 0) - MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno)); - } + if (ret < 0) + MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno)); + } - void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len) - { - MS_TRACE(); + void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len) + { + MS_TRACE(); - const uint8_t* data = static_cast(buffer); + const uint8_t* data = static_cast(buffer); #if MS_LOG_DEV_LEVEL == 3 - MS_DUMP_DATA(data, len); + MS_DUMP_DATA(data, len); #endif - this->listener->OnSctpAssociationSendData(this, data, len); - } + this->listener->OnSctpAssociationSendData(this, data, len); + } - void SctpAssociation::OnUsrSctpReceiveSctpData( - uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len) - { - // Ignore WebRTC DataChannel Control DATA chunks. - if (ppid == 50) - { - MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)"); + void SctpAssociation::OnUsrSctpReceiveSctpData( + uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len) + { + // Ignore WebRTC DataChannel Control DATA chunks. + if (ppid == 50) + { + MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)"); - return; - } + return; + } - if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived) - { - MS_WARN_TAG( - sctp, - "message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16 - ", last ssn received:%" PRIu16 "]", - ssn, - this->lastSsnReceived); + if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived) + { + MS_WARN_TAG( + sctp, + "message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16 + ", last ssn received:%" PRIu16 "]", + ssn, + this->lastSsnReceived); - this->messageBufferLen = 0; - } + this->messageBufferLen = 0; + } - // Update last SSN received. - this->lastSsnReceived = ssn; + // Update last SSN received. + this->lastSsnReceived = ssn; - auto eor = static_cast(flags & MSG_EOR); + auto eor = static_cast(flags & MSG_EOR); - if (this->messageBufferLen + len > this->maxSctpMessageSize) - { - MS_WARN_TAG( - sctp, - "ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]", - this->messageBufferLen + len, - this->maxSctpMessageSize, - eor ? 1 : 0); + if (this->messageBufferLen + len > this->maxSctpMessageSize) + { + MS_WARN_TAG( + sctp, + "ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]", + this->messageBufferLen + len, + this->maxSctpMessageSize, + eor ? 1 : 0); - this->lastSsnReceived = 0; + this->lastSsnReceived = 0; - return; - } + return; + } - // If end of message and there is no buffered data, notify it directly. - if (eor && this->messageBufferLen == 0) - { - MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]"); + // If end of message and there is no buffered data, notify it directly. + if (eor && this->messageBufferLen == 0) + { + MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]"); - this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len); - } - // If end of message and there is buffered data, append data and notify buffer. - else if (eor && this->messageBufferLen != 0) - { - std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); - this->messageBufferLen += len; + this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len); + } + // If end of message and there is buffered data, append data and notify buffer. + else if (eor && this->messageBufferLen != 0) + { + std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); + this->messageBufferLen += len; - MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen); + MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen); - this->listener->OnSctpAssociationMessageReceived( - this, streamId, ppid, this->messageBuffer, this->messageBufferLen); + this->listener->OnSctpAssociationMessageReceived( + this, streamId, ppid, this->messageBuffer, this->messageBufferLen); - this->messageBufferLen = 0; - } - // If non end of message, append data to the buffer. - else if (!eor) - { - // Allocate the buffer if not already done. - if (!this->messageBuffer) - this->messageBuffer = new uint8_t[this->maxSctpMessageSize]; + this->messageBufferLen = 0; + } + // If non end of message, append data to the buffer. + else if (!eor) + { + // Allocate the buffer if not already done. + if (!this->messageBuffer) + this->messageBuffer = new uint8_t[this->maxSctpMessageSize]; - std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); - this->messageBufferLen += len; + std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); + this->messageBufferLen += len; - MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen); - } - } + MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen); + } + } - void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len) - { - if (notification->sn_header.sn_length != (uint32_t)len) - return; + void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len) + { + if (notification->sn_header.sn_length != (uint32_t)len) + return; - switch (notification->sn_header.sn_type) - { - case SCTP_ADAPTATION_INDICATION: - { - MS_DEBUG_TAG( - sctp, - "SCTP adaptation indication [%x]", - notification->sn_adaptation_event.sai_adaptation_ind); + switch (notification->sn_header.sn_type) + { + case SCTP_ADAPTATION_INDICATION: + { + MS_DEBUG_TAG( + sctp, + "SCTP adaptation indication [%x]", + notification->sn_adaptation_event.sai_adaptation_ind); - break; - } + break; + } - case SCTP_ASSOC_CHANGE: - { - switch (notification->sn_assoc_change.sac_state) - { - case SCTP_COMM_UP: - { - MS_DEBUG_TAG( - sctp, - "SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]", - notification->sn_assoc_change.sac_outbound_streams, - notification->sn_assoc_change.sac_inbound_streams); + case SCTP_ASSOC_CHANGE: + { + switch (notification->sn_assoc_change.sac_state) + { + case SCTP_COMM_UP: + { + MS_DEBUG_TAG( + sctp, + "SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]", + notification->sn_assoc_change.sac_outbound_streams, + notification->sn_assoc_change.sac_inbound_streams); - // Update our OS. - this->os = notification->sn_assoc_change.sac_outbound_streams; + // Update our OS. + this->os = notification->sn_assoc_change.sac_outbound_streams; - // Increase if requested before connected. - if (this->desiredOs > this->os) - AddOutgoingStreams(/*force*/ true); + // Increase if requested before connected. + if (this->desiredOs > this->os) + AddOutgoingStreams(/*force*/ true); - if (this->state != SctpState::CONNECTED) - { - this->state = SctpState::CONNECTED; - this->listener->OnSctpAssociationConnected(this); - } + if (this->state != SctpState::CONNECTED) + { + this->state = SctpState::CONNECTED; + this->listener->OnSctpAssociationConnected(this); + } - break; - } + break; + } - case SCTP_COMM_LOST: - { - if (notification->sn_header.sn_length > 0) - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + case SCTP_COMM_LOST: + { + if (notification->sn_header.sn_length > 0) + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = notification->sn_header.sn_length; + uint32_t len = notification->sn_header.sn_length; - for (uint32_t i{ 0 }; i < len; ++i) - { - std::snprintf( - buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); - } + for (uint32_t i{ 0 }; i < len; ++i) + { + std::snprintf( + buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); + } - MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer); - } - else - { - MS_DEBUG_TAG(sctp, "SCTP communication lost"); - } + MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer); + } + else + { + MS_DEBUG_TAG(sctp, "SCTP communication lost"); + } - if (this->state != SctpState::CLOSED) - { - this->state = SctpState::CLOSED; - this->listener->OnSctpAssociationClosed(this); - } + if (this->state != SctpState::CLOSED) + { + this->state = SctpState::CLOSED; + this->listener->OnSctpAssociationClosed(this); + } - break; - } + break; + } - case SCTP_RESTART: - { - MS_DEBUG_TAG( - sctp, - "SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]", - notification->sn_assoc_change.sac_outbound_streams, - notification->sn_assoc_change.sac_inbound_streams); + case SCTP_RESTART: + { + MS_DEBUG_TAG( + sctp, + "SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]", + notification->sn_assoc_change.sac_outbound_streams, + notification->sn_assoc_change.sac_inbound_streams); - // Update our OS. - this->os = notification->sn_assoc_change.sac_outbound_streams; + // Update our OS. + this->os = notification->sn_assoc_change.sac_outbound_streams; - // Increase if requested before connected. - if (this->desiredOs > this->os) - AddOutgoingStreams(/*force*/ true); + // Increase if requested before connected. + if (this->desiredOs > this->os) + AddOutgoingStreams(/*force*/ true); - if (this->state != SctpState::CONNECTED) - { - this->state = SctpState::CONNECTED; - this->listener->OnSctpAssociationConnected(this); - } + if (this->state != SctpState::CONNECTED) + { + this->state = SctpState::CONNECTED; + this->listener->OnSctpAssociationConnected(this); + } - break; - } + break; + } - case SCTP_SHUTDOWN_COMP: - { - MS_DEBUG_TAG(sctp, "SCTP association gracefully closed"); + case SCTP_SHUTDOWN_COMP: + { + MS_DEBUG_TAG(sctp, "SCTP association gracefully closed"); - if (this->state != SctpState::CLOSED) - { - this->state = SctpState::CLOSED; - this->listener->OnSctpAssociationClosed(this); - } + if (this->state != SctpState::CLOSED) + { + this->state = SctpState::CLOSED; + this->listener->OnSctpAssociationClosed(this); + } - break; - } + break; + } - case SCTP_CANT_STR_ASSOC: - { - if (notification->sn_header.sn_length > 0) - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + case SCTP_CANT_STR_ASSOC: + { + if (notification->sn_header.sn_length > 0) + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = notification->sn_header.sn_length; + uint32_t len = notification->sn_header.sn_length; - for (uint32_t i{ 0 }; i < len; ++i) - { - std::snprintf( - buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); - } + for (uint32_t i{ 0 }; i < len; ++i) + { + std::snprintf( + buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); + } - MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer); - } + MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer); + } - if (this->state != SctpState::FAILED) - { - this->state = SctpState::FAILED; - this->listener->OnSctpAssociationFailed(this); - } + if (this->state != SctpState::FAILED) + { + this->state = SctpState::FAILED; + this->listener->OnSctpAssociationFailed(this); + } - break; - } + break; + } - default:; - } + default:; + } - break; - } + break; + } - // https://tools.ietf.org/html/rfc6525#section-6.1.2. - case SCTP_ASSOC_RESET_EVENT: - { - MS_DEBUG_TAG(sctp, "SCTP association reset event received"); + // https://tools.ietf.org/html/rfc6525#section-6.1.2. + case SCTP_ASSOC_RESET_EVENT: + { + MS_DEBUG_TAG(sctp, "SCTP association reset event received"); - break; - } + break; + } - // An Operation Error is not considered fatal in and of itself, but may be - // used with an ABORT chunk to report a fatal condition. - case SCTP_REMOTE_ERROR: - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + // An Operation Error is not considered fatal in and of itself, but may be + // used with an ABORT chunk to report a fatal condition. + case SCTP_REMOTE_ERROR: + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error); + uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error); - for (uint32_t i{ 0 }; i < len; i++) - { - std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]); - } + for (uint32_t i{ 0 }; i < len; i++) + { + std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]); + } - MS_WARN_TAG( - sctp, - "remote SCTP association error [type:0x%04x, data:%s]", - notification->sn_remote_error.sre_error, - buffer); + MS_WARN_TAG( + sctp, + "remote SCTP association error [type:0x%04x, data:%s]", + notification->sn_remote_error.sre_error, + buffer); - break; - } + break; + } - // When a peer sends a SHUTDOWN, SCTP delivers this notification to - // inform the application that it should cease sending data. - case SCTP_SHUTDOWN_EVENT: - { - MS_DEBUG_TAG(sctp, "remote SCTP association shutdown"); + // When a peer sends a SHUTDOWN, SCTP delivers this notification to + // inform the application that it should cease sending data. + case SCTP_SHUTDOWN_EVENT: + { + MS_DEBUG_TAG(sctp, "remote SCTP association shutdown"); - if (this->state != SctpState::CLOSED) - { - this->state = SctpState::CLOSED; - this->listener->OnSctpAssociationClosed(this); - } + if (this->state != SctpState::CLOSED) + { + this->state = SctpState::CLOSED; + this->listener->OnSctpAssociationClosed(this); + } - break; - } + break; + } - case SCTP_SEND_FAILED_EVENT: - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + case SCTP_SEND_FAILED_EVENT: + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = - notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event); + uint32_t len = + notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event); - for (uint32_t i{ 0 }; i < len; ++i) - { - std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]); - } + for (uint32_t i{ 0 }; i < len; ++i) + { + std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]); + } - MS_WARN_TAG( - sctp, - "SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32 - ", sent:%s, error:0x%08x, info:%s]", - notification->sn_send_failed_event.ssfe_info.snd_sid, - ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid), - (notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no", - notification->sn_send_failed_event.ssfe_error, - buffer); + MS_WARN_TAG( + sctp, + "SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32 + ", sent:%s, error:0x%08x, info:%s]", + notification->sn_send_failed_event.ssfe_info.snd_sid, + ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid), + (notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no", + notification->sn_send_failed_event.ssfe_error, + buffer); - break; - } + break; + } - case SCTP_STREAM_RESET_EVENT: - { - bool incoming{ false }; - bool outgoing{ false }; - uint16_t numStreams = - (notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) / - sizeof(uint16_t); + case SCTP_STREAM_RESET_EVENT: + { + bool incoming{ false }; + bool outgoing{ false }; + uint16_t numStreams = + (notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) / + sizeof(uint16_t); - if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) - incoming = true; + if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) + incoming = true; - if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) - outgoing = true; + if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) + outgoing = true; //todo 打印sctp调试信息 - if (false /*MS_HAS_DEBUG_TAG(sctp)*/) - { - std::string streamIds; + if (false /*MS_HAS_DEBUG_TAG(sctp)*/) + { + std::string streamIds; - for (uint16_t i{ 0 }; i < numStreams; ++i) - { - auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; + for (uint16_t i{ 0 }; i < numStreams; ++i) + { + auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; - // Don't log more than 5 stream ids. - if (i > 4) - { - streamIds.append("..."); + // Don't log more than 5 stream ids. + if (i > 4) + { + streamIds.append("..."); - break; - } + break; + } - if (i > 0) - streamIds.append(","); + if (i > 0) + streamIds.append(","); - streamIds.append(std::to_string(streamId)); - } + streamIds.append(std::to_string(streamId)); + } - MS_DEBUG_TAG( - sctp, - "SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]", - notification->sn_strreset_event.strreset_flags, - incoming ? "true" : "false", - outgoing ? "true" : "false", - numStreams, - streamIds.c_str()); - } + MS_DEBUG_TAG( + sctp, + "SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]", + notification->sn_strreset_event.strreset_flags, + incoming ? "true" : "false", + outgoing ? "true" : "false", + numStreams, + streamIds.c_str()); + } - // Special case for WebRTC DataChannels in which we must also reset our - // outgoing SCTP stream. - if (incoming && !outgoing && this->isDataChannel) - { - for (uint16_t i{ 0 }; i < numStreams; ++i) - { - auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; + // Special case for WebRTC DataChannels in which we must also reset our + // outgoing SCTP stream. + if (incoming && !outgoing && this->isDataChannel) + { + for (uint16_t i{ 0 }; i < numStreams; ++i) + { + auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; - ResetSctpStream(streamId, StreamDirection::OUTGOING); - } - } + ResetSctpStream(streamId, StreamDirection::OUTGOING); + } + } - break; - } + break; + } - case SCTP_STREAM_CHANGE_EVENT: - { - if (notification->sn_strchange_event.strchange_flags == 0) - { - MS_DEBUG_TAG( - sctp, - "SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", - notification->sn_strchange_event.strchange_outstrms, - notification->sn_strchange_event.strchange_instrms, - notification->sn_strchange_event.strchange_flags); - } - else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED) - { - MS_WARN_TAG( - sctp, - "SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", - notification->sn_strchange_event.strchange_outstrms, - notification->sn_strchange_event.strchange_instrms, - notification->sn_strchange_event.strchange_flags); + case SCTP_STREAM_CHANGE_EVENT: + { + if (notification->sn_strchange_event.strchange_flags == 0) + { + MS_DEBUG_TAG( + sctp, + "SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", + notification->sn_strchange_event.strchange_outstrms, + notification->sn_strchange_event.strchange_instrms, + notification->sn_strchange_event.strchange_flags); + } + else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED) + { + MS_WARN_TAG( + sctp, + "SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", + notification->sn_strchange_event.strchange_outstrms, + notification->sn_strchange_event.strchange_instrms, + notification->sn_strchange_event.strchange_flags); - break; - } - else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED) - { - MS_WARN_TAG( - sctp, - "SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", - notification->sn_strchange_event.strchange_outstrms, - notification->sn_strchange_event.strchange_instrms, - notification->sn_strchange_event.strchange_flags); + break; + } + else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED) + { + MS_WARN_TAG( + sctp, + "SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", + notification->sn_strchange_event.strchange_outstrms, + notification->sn_strchange_event.strchange_instrms, + notification->sn_strchange_event.strchange_flags); - break; - } + break; + } - // Update OS. - this->os = notification->sn_strchange_event.strchange_outstrms; + // Update OS. + this->os = notification->sn_strchange_event.strchange_outstrms; - break; - } + break; + } - default: - { - MS_WARN_TAG( - sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type); - } - } - } + default: + { + MS_WARN_TAG( + sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type); + } + } + } //////////////////////////////////////////////////////////////////////////////////////// diff --git a/webrtc/SctpAssociation.hpp b/webrtc/SctpAssociation.hpp index 548221c5..9c46d275 100644 --- a/webrtc/SctpAssociation.hpp +++ b/webrtc/SctpAssociation.hpp @@ -18,104 +18,104 @@ namespace RTC uint16_t maxRetransmits{ 0u }; }; - class SctpAssociation - { - public: - enum class SctpState - { - NEW = 1, - CONNECTING, - CONNECTED, - FAILED, - CLOSED - }; + class SctpAssociation + { + public: + enum class SctpState + { + NEW = 1, + CONNECTING, + CONNECTED, + FAILED, + CLOSED + }; - private: - enum class StreamDirection - { - INCOMING = 1, - OUTGOING - }; + private: + enum class StreamDirection + { + INCOMING = 1, + OUTGOING + }; - public: - class Listener - { - public: - virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationSendData( - RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0; - virtual void OnSctpAssociationMessageReceived( - RTC::SctpAssociation* sctpAssociation, - uint16_t streamId, - uint32_t ppid, - const uint8_t* msg, - size_t len) = 0; - }; + public: + class Listener + { + public: + virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationSendData( + RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0; + virtual void OnSctpAssociationMessageReceived( + RTC::SctpAssociation* sctpAssociation, + uint16_t streamId, + uint32_t ppid, + const uint8_t* msg, + size_t len) = 0; + }; - public: - static bool IsSctp(const uint8_t* data, size_t len) - { - // clang-format off - return ( - (len >= 12) && - // Must have Source Port Number and Destination Port Number set to 5000 (hack). - (Utils::Byte::Get2Bytes(data, 0) == 5000) && - (Utils::Byte::Get2Bytes(data, 2) == 5000) - ); - // clang-format on - } + public: + static bool IsSctp(const uint8_t* data, size_t len) + { + // clang-format off + return ( + (len >= 12) && + // Must have Source Port Number and Destination Port Number set to 5000 (hack). + (Utils::Byte::Get2Bytes(data, 0) == 5000) && + (Utils::Byte::Get2Bytes(data, 2) == 5000) + ); + // clang-format on + } - public: - SctpAssociation( - Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel); - virtual ~SctpAssociation(); + public: + SctpAssociation( + Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel); + virtual ~SctpAssociation(); - public: - void TransportConnected(); - size_t GetMaxSctpMessageSize() const - { - return this->maxSctpMessageSize; - } - SctpState GetState() const - { - return this->state; - } - void ProcessSctpData(const uint8_t* data, size_t len); - void SendSctpMessage(const RTC::SctpStreamParameters ¶ms, uint32_t ppid, const uint8_t* msg, size_t len); - void HandleDataConsumer(const RTC::SctpStreamParameters ¶ms); - void DataProducerClosed(const RTC::SctpStreamParameters ¶ms); - void DataConsumerClosed(const RTC::SctpStreamParameters ¶ms); + public: + void TransportConnected(); + size_t GetMaxSctpMessageSize() const + { + return this->maxSctpMessageSize; + } + SctpState GetState() const + { + return this->state; + } + void ProcessSctpData(const uint8_t* data, size_t len); + void SendSctpMessage(const RTC::SctpStreamParameters ¶ms, uint32_t ppid, const uint8_t* msg, size_t len); + void HandleDataConsumer(const RTC::SctpStreamParameters ¶ms); + void DataProducerClosed(const RTC::SctpStreamParameters ¶ms); + void DataConsumerClosed(const RTC::SctpStreamParameters ¶ms); - private: - void ResetSctpStream(uint16_t streamId, StreamDirection); - void AddOutgoingStreams(bool force = false); + private: + void ResetSctpStream(uint16_t streamId, StreamDirection); + void AddOutgoingStreams(bool force = false); - public: + public: /* Callbacks fired by usrsctp events. */ virtual void OnUsrSctpSendSctpData(void* buffer, size_t len); virtual void OnUsrSctpReceiveSctpData(uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len); virtual void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len); - private: - // Passed by argument. - Listener* listener{ nullptr }; - uint16_t os{ 1024u }; - uint16_t mis{ 1024u }; - size_t maxSctpMessageSize{ 262144u }; - bool isDataChannel{ false }; - // Allocated by this. - uint8_t* messageBuffer{ nullptr }; - // Others. - SctpState state{ SctpState::NEW }; - struct socket* socket{ nullptr }; - uint16_t desiredOs{ 0u }; - size_t messageBufferLen{ 0u }; - uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support. + private: + // Passed by argument. + Listener* listener{ nullptr }; + uint16_t os{ 1024u }; + uint16_t mis{ 1024u }; + size_t maxSctpMessageSize{ 262144u }; + bool isDataChannel{ false }; + // Allocated by this. + uint8_t* messageBuffer{ nullptr }; + // Others. + SctpState state{ SctpState::NEW }; + struct socket* socket{ nullptr }; + uint16_t desiredOs{ 0u }; + size_t messageBufferLen{ 0u }; + uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support. std::shared_ptr _env; - }; + }; //保证线程安全 class SctpAssociationImp : public SctpAssociation, public std::enable_shared_from_this{ diff --git a/webrtc/StunPacket.cpp b/webrtc/StunPacket.cpp index c7a403fb..9f648d55 100644 --- a/webrtc/StunPacket.cpp +++ b/webrtc/StunPacket.cpp @@ -97,785 +97,785 @@ namespace RTC return str; } - /* Class variables. */ - - const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; - - /* Class methods. */ - - StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) - { - MS_TRACE(); - - if (!StunPacket::IsStun(data, len)) - return nullptr; - - /* - The message type field is decomposed further into the following - structure: - - 0 1 - 2 3 4 5 6 7 8 9 0 1 2 3 4 5 - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - |M |M |M|M|M|C|M|M|M|C|M|M|M|M| - |11|10|9|8|7|1|6|5|4|0|3|2|1|0| - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - - Figure 3: Format of STUN Message Type Field - - Here the bits in the message type field are shown as most significant - (M11) through least significant (M0). M11 through M0 represent a 12- - bit encoding of the method. C1 and C0 represent a 2-bit encoding of - the class. - */ - - // Get type field. - uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); - - // Get length field. - uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); - - // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. - if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) - { - MS_WARN_TAG( - ice, - "length field + 20 does not match total size (or it is not multiple of 4 bytes), " - "packet discarded"); - - return nullptr; - } - - // Get STUN method. - uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); - - // Get STUN class. - uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); - - // Create a new StunPacket (data + 8 points to the received TransactionID field). - auto* packet = new StunPacket( - static_cast(msgClass), static_cast(msgMethod), data + 8, data, len); - - /* - STUN Attributes - - After the STUN header are zero or more attributes. Each attribute - MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. - Each STUN attribute MUST end on a 32-bit boundary. As mentioned - above, all fields in an attribute are transmitted most significant - bit first. - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Value (variable) .... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - - // Start looking for attributes after STUN header (Byte #20). - size_t pos{ 20 }; - // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. - bool hasMessageIntegrity{ false }; - bool hasFingerprint{ false }; - size_t fingerprintAttrPos; // Will point to the beginning of the attribute. - uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. - - // Ensure there are at least 4 remaining bytes (attribute with 0 length). - while (pos + 4 <= len) - { - // Get the attribute type. - auto attrType = static_cast(Utils::Byte::Get2Bytes(data, pos)); - - // Get the attribute length. - uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); - - // Ensure the attribute length is not greater than the remaining size. - if ((pos + 4 + attrLength) > len) - { - MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded"); - - delete packet; - return nullptr; - } - - // FINGERPRINT must be the last attribute. - if (hasFingerprint) - { - MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded"); - - delete packet; - return nullptr; - } - - // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. - if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) - { - MS_WARN_TAG( - ice, - "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " - "packet discarded"); - - delete packet; - return nullptr; - } - - const uint8_t* attrValuePos = data + pos + 4; - - switch (attrType) - { - case Attribute::USERNAME: - { - packet->SetUsername( - reinterpret_cast(attrValuePos), static_cast(attrLength)); - - break; - } - - case Attribute::PRIORITY: - { - // Ensure attribute length is 4 bytes. - if (attrLength != 4) - { - MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::ICE_CONTROLLING: - { - // Ensure attribute length is 8 bytes. - if (attrLength != 8) - { - MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::ICE_CONTROLLED: - { - // Ensure attribute length is 8 bytes. - if (attrLength != 8) - { - MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::USE_CANDIDATE: - { - // Ensure attribute length is 0 bytes. - if (attrLength != 0) - { - MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetUseCandidate(); - - break; - } - - case Attribute::MESSAGE_INTEGRITY: - { - // Ensure attribute length is 20 bytes. - if (attrLength != 20) - { - MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - hasMessageIntegrity = true; - packet->SetMessageIntegrity(attrValuePos); - - break; - } - - case Attribute::FINGERPRINT: - { - // Ensure attribute length is 4 bytes. - if (attrLength != 4) - { - MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - hasFingerprint = true; - fingerprintAttrPos = pos; - fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); - packet->SetFingerprint(); - - break; - } - - case Attribute::ERROR_CODE: - { - // Ensure attribute length >= 4bytes. - if (attrLength < 4) - { - MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); - uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); - auto errorCode = static_cast(errorClass * 100 + errorNumber); - - packet->SetErrorCode(errorCode); - - break; - } - - default:; - } - - // Set next attribute position. - pos = - static_cast(Utils::Byte::PadTo4Bytes(static_cast(pos + 4 + attrLength))); - } - - // Ensure current position matches the total length. - if (pos != len) - { - MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded"); - - delete packet; - return nullptr; - } - - // If it has FINGERPRINT attribute then verify it. - if (hasFingerprint) - { - // Compute the CRC32 of the received packet up to (but excluding) the - // FINGERPRINT attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; - - // Compare with the FINGERPRINT value in the packet. - if (fingerprint != computedFingerprint) - { - MS_WARN_TAG( - ice, - "computed FINGERPRINT value does not match the value in the packet, " - "packet discarded"); - - delete packet; - return nullptr; - } - } - - return packet; - } - - /* Instance methods. */ - - StunPacket::StunPacket( - Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size) - : klass(klass), method(method), transactionId(transactionId), data(const_cast(data)), - size(size) - { - MS_TRACE(); - } - - StunPacket::~StunPacket() - { - MS_TRACE(); - } + /* Class variables. */ + + const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; + + /* Class methods. */ + + StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) + { + MS_TRACE(); + + if (!StunPacket::IsStun(data, len)) + return nullptr; + + /* + The message type field is decomposed further into the following + structure: + + 0 1 + 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + + Figure 3: Format of STUN Message Type Field + + Here the bits in the message type field are shown as most significant + (M11) through least significant (M0). M11 through M0 represent a 12- + bit encoding of the method. C1 and C0 represent a 2-bit encoding of + the class. + */ + + // Get type field. + uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); + + // Get length field. + uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); + + // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. + if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) + { + MS_WARN_TAG( + ice, + "length field + 20 does not match total size (or it is not multiple of 4 bytes), " + "packet discarded"); + + return nullptr; + } + + // Get STUN method. + uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); + + // Get STUN class. + uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); + + // Create a new StunPacket (data + 8 points to the received TransactionID field). + auto* packet = new StunPacket( + static_cast(msgClass), static_cast(msgMethod), data + 8, data, len); + + /* + STUN Attributes + + After the STUN header are zero or more attributes. Each attribute + MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. + Each STUN attribute MUST end on a 32-bit boundary. As mentioned + above, all fields in an attribute are transmitted most significant + bit first. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type | Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Value (variable) .... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + // Start looking for attributes after STUN header (Byte #20). + size_t pos{ 20 }; + // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. + bool hasMessageIntegrity{ false }; + bool hasFingerprint{ false }; + size_t fingerprintAttrPos; // Will point to the beginning of the attribute. + uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. + + // Ensure there are at least 4 remaining bytes (attribute with 0 length). + while (pos + 4 <= len) + { + // Get the attribute type. + auto attrType = static_cast(Utils::Byte::Get2Bytes(data, pos)); + + // Get the attribute length. + uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); + + // Ensure the attribute length is not greater than the remaining size. + if ((pos + 4 + attrLength) > len) + { + MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded"); + + delete packet; + return nullptr; + } + + // FINGERPRINT must be the last attribute. + if (hasFingerprint) + { + MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded"); + + delete packet; + return nullptr; + } + + // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. + if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) + { + MS_WARN_TAG( + ice, + "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " + "packet discarded"); + + delete packet; + return nullptr; + } + + const uint8_t* attrValuePos = data + pos + 4; + + switch (attrType) + { + case Attribute::USERNAME: + { + packet->SetUsername( + reinterpret_cast(attrValuePos), static_cast(attrLength)); + + break; + } + + case Attribute::PRIORITY: + { + // Ensure attribute length is 4 bytes. + if (attrLength != 4) + { + MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::ICE_CONTROLLING: + { + // Ensure attribute length is 8 bytes. + if (attrLength != 8) + { + MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::ICE_CONTROLLED: + { + // Ensure attribute length is 8 bytes. + if (attrLength != 8) + { + MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::USE_CANDIDATE: + { + // Ensure attribute length is 0 bytes. + if (attrLength != 0) + { + MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetUseCandidate(); + + break; + } + + case Attribute::MESSAGE_INTEGRITY: + { + // Ensure attribute length is 20 bytes. + if (attrLength != 20) + { + MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + hasMessageIntegrity = true; + packet->SetMessageIntegrity(attrValuePos); + + break; + } + + case Attribute::FINGERPRINT: + { + // Ensure attribute length is 4 bytes. + if (attrLength != 4) + { + MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + hasFingerprint = true; + fingerprintAttrPos = pos; + fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); + packet->SetFingerprint(); + + break; + } + + case Attribute::ERROR_CODE: + { + // Ensure attribute length >= 4bytes. + if (attrLength < 4) + { + MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); + uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); + auto errorCode = static_cast(errorClass * 100 + errorNumber); + + packet->SetErrorCode(errorCode); + + break; + } + + default:; + } + + // Set next attribute position. + pos = + static_cast(Utils::Byte::PadTo4Bytes(static_cast(pos + 4 + attrLength))); + } + + // Ensure current position matches the total length. + if (pos != len) + { + MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded"); + + delete packet; + return nullptr; + } + + // If it has FINGERPRINT attribute then verify it. + if (hasFingerprint) + { + // Compute the CRC32 of the received packet up to (but excluding) the + // FINGERPRINT attribute and XOR it with 0x5354554e. + uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; + + // Compare with the FINGERPRINT value in the packet. + if (fingerprint != computedFingerprint) + { + MS_WARN_TAG( + ice, + "computed FINGERPRINT value does not match the value in the packet, " + "packet discarded"); + + delete packet; + return nullptr; + } + } + + return packet; + } + + /* Instance methods. */ + + StunPacket::StunPacket( + Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size) + : klass(klass), method(method), transactionId(transactionId), data(const_cast(data)), + size(size) + { + MS_TRACE(); + } + + StunPacket::~StunPacket() + { + MS_TRACE(); + } #if 0 - void StunPacket::Dump() const - { - MS_TRACE(); + void StunPacket::Dump() const + { + MS_TRACE(); - MS_DUMP(""); + MS_DUMP(""); - std::string klass; - switch (this->klass) - { - case Class::REQUEST: - klass = "Request"; - break; - case Class::INDICATION: - klass = "Indication"; - break; - case Class::SUCCESS_RESPONSE: - klass = "SuccessResponse"; - break; - case Class::ERROR_RESPONSE: - klass = "ErrorResponse"; - break; - } - if (this->method == Method::BINDING) - { - MS_DUMP(" Binding %s", klass.c_str()); - } - else - { - // This prints the unknown method number. Example: TURN Allocate => 0x003. - MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast(this->method)); - } - MS_DUMP(" size: %zu bytes", this->size); + std::string klass; + switch (this->klass) + { + case Class::REQUEST: + klass = "Request"; + break; + case Class::INDICATION: + klass = "Indication"; + break; + case Class::SUCCESS_RESPONSE: + klass = "SuccessResponse"; + break; + case Class::ERROR_RESPONSE: + klass = "ErrorResponse"; + break; + } + if (this->method == Method::BINDING) + { + MS_DUMP(" Binding %s", klass.c_str()); + } + else + { + // This prints the unknown method number. Example: TURN Allocate => 0x003. + MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast(this->method)); + } + MS_DUMP(" size: %zu bytes", this->size); - static char transactionId[25]; + static char transactionId[25]; - for (int i{ 0 }; i < 12; ++i) - { - // NOTE: n must be 3 because snprintf adds a \0 after printed chars. - std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); - } - MS_DUMP(" transactionId: %s", transactionId); - if (this->errorCode != 0u) - MS_DUMP(" errorCode: %" PRIu16, this->errorCode); - if (!this->username.empty()) - MS_DUMP(" username: %s", this->username.c_str()); - if (this->priority != 0u) - MS_DUMP(" priority: %" PRIu32, this->priority); - if (this->iceControlling != 0u) - MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); - if (this->iceControlled != 0u) - MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); - if (this->hasUseCandidate) - MS_DUMP(" useCandidate"); - if (this->xorMappedAddress != nullptr) - { - int family; - uint16_t port; - std::string ip; + for (int i{ 0 }; i < 12; ++i) + { + // NOTE: n must be 3 because snprintf adds a \0 after printed chars. + std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); + } + MS_DUMP(" transactionId: %s", transactionId); + if (this->errorCode != 0u) + MS_DUMP(" errorCode: %" PRIu16, this->errorCode); + if (!this->username.empty()) + MS_DUMP(" username: %s", this->username.c_str()); + if (this->priority != 0u) + MS_DUMP(" priority: %" PRIu32, this->priority); + if (this->iceControlling != 0u) + MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); + if (this->iceControlled != 0u) + MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); + if (this->hasUseCandidate) + MS_DUMP(" useCandidate"); + if (this->xorMappedAddress != nullptr) + { + int family; + uint16_t port; + std::string ip; - Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); + Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); - MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); - } - if (this->messageIntegrity != nullptr) - { - static char messageIntegrity[41]; + MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); + } + if (this->messageIntegrity != nullptr) + { + static char messageIntegrity[41]; - for (int i{ 0 }; i < 20; ++i) - { - std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); - } + for (int i{ 0 }; i < 20; ++i) + { + std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); + } - MS_DUMP(" messageIntegrity: %s", messageIntegrity); - } - if (this->hasFingerprint) - MS_DUMP(" has fingerprint"); + MS_DUMP(" messageIntegrity: %s", messageIntegrity); + } + if (this->hasFingerprint) + MS_DUMP(" has fingerprint"); - MS_DUMP(""); - } + MS_DUMP(""); + } #endif - StunPacket::Authentication StunPacket::CheckAuthentication( - const std::string& localUsername, const std::string& localPassword) - { - MS_TRACE(); + StunPacket::Authentication StunPacket::CheckAuthentication( + const std::string& localUsername, const std::string& localPassword) + { + MS_TRACE(); - switch (this->klass) - { - case Class::REQUEST: - case Class::INDICATION: - { - // Both USERNAME and MESSAGE-INTEGRITY must be present. - if (!this->messageIntegrity || this->username.empty()) - return Authentication::BAD_REQUEST; + switch (this->klass) + { + case Class::REQUEST: + case Class::INDICATION: + { + // Both USERNAME and MESSAGE-INTEGRITY must be present. + if (!this->messageIntegrity || this->username.empty()) + return Authentication::BAD_REQUEST; - // Check that USERNAME attribute begins with our local username plus ":". - size_t localUsernameLen = localUsername.length(); + // Check that USERNAME attribute begins with our local username plus ":". + size_t localUsernameLen = localUsername.length(); - if ( - this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' || - (this->username.compare(0, localUsernameLen, localUsername) != 0)) - { - return Authentication::UNAUTHORIZED; - } + if ( + this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' || + (this->username.compare(0, localUsernameLen, localUsername) != 0)) + { + return Authentication::UNAUTHORIZED; + } - break; - } - // This method cannot check authentication in received responses (as we - // are ICE-Lite and don't generate requests). - case Class::SUCCESS_RESPONSE: - case Class::ERROR_RESPONSE: - { - MS_ERROR("cannot check authentication for a STUN response"); + break; + } + // This method cannot check authentication in received responses (as we + // are ICE-Lite and don't generate requests). + case Class::SUCCESS_RESPONSE: + case Class::ERROR_RESPONSE: + { + MS_ERROR("cannot check authentication for a STUN response"); - return Authentication::BAD_REQUEST; - } - } + return Authentication::BAD_REQUEST; + } + } - // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, - // so the header length field must be modified (and later restored). - if (this->hasFingerprint) - // Set the header length field: full size - header length (20) - FINGERPRINT length (8). - Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); + // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, + // so the header length field must be modified (and later restored). + if (this->hasFingerprint) + // Set the header length field: full size - header length (20) - FINGERPRINT length (8). + Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); - // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. + // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. auto computedMessageIntegrity = openssl_HMACsha1( localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data); - Authentication result; + Authentication result; - // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. - if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) - result = Authentication::OK; - else - result = Authentication::UNAUTHORIZED; + // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. + if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) + result = Authentication::OK; + else + result = Authentication::UNAUTHORIZED; - // Restore the header length field. - if (this->hasFingerprint) - Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20)); + // Restore the header length field. + if (this->hasFingerprint) + Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20)); - return result; - } + return result; + } - StunPacket* StunPacket::CreateSuccessResponse() - { - MS_TRACE(); + StunPacket* StunPacket::CreateSuccessResponse() + { + MS_TRACE(); - MS_ASSERT( - this->klass == Class::REQUEST, - "attempt to create a success response for a non Request STUN packet"); + MS_ASSERT( + this->klass == Class::REQUEST, + "attempt to create a success response for a non Request STUN packet"); - return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); - } + return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); + } - StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) - { - MS_TRACE(); + StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) + { + MS_TRACE(); - MS_ASSERT( - this->klass == Class::REQUEST, - "attempt to create an error response for a non Request STUN packet"); + MS_ASSERT( + this->klass == Class::REQUEST, + "attempt to create an error response for a non Request STUN packet"); - auto* response = - new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); + auto* response = + new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); - response->SetErrorCode(errorCode); + response->SetErrorCode(errorCode); - return response; - } + return response; + } - void StunPacket::Authenticate(const std::string& password) - { - // Just for Request, Indication and SuccessResponse messages. - if (this->klass == Class::ERROR_RESPONSE) - { - MS_ERROR("cannot set password for ErrorResponse messages"); + void StunPacket::Authenticate(const std::string& password) + { + // Just for Request, Indication and SuccessResponse messages. + if (this->klass == Class::ERROR_RESPONSE) + { + MS_ERROR("cannot set password for ErrorResponse messages"); - return; - } + return; + } - this->password = password; - } + this->password = password; + } - void StunPacket::Serialize(uint8_t* buffer) - { - MS_TRACE(); + void StunPacket::Serialize(uint8_t* buffer) + { + MS_TRACE(); - // Some useful variables. - uint16_t usernamePaddedLen{ 0 }; - uint16_t xorMappedAddressPaddedLen{ 0 }; - bool addXorMappedAddress = - ((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && - this->klass == Class::SUCCESS_RESPONSE); - bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); - bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); - bool addFingerprint{ true }; // Do always. + // Some useful variables. + uint16_t usernamePaddedLen{ 0 }; + uint16_t xorMappedAddressPaddedLen{ 0 }; + bool addXorMappedAddress = + ((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && + this->klass == Class::SUCCESS_RESPONSE); + bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); + bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); + bool addFingerprint{ true }; // Do always. - // Update data pointer. - this->data = buffer; + // Update data pointer. + this->data = buffer; - // First calculate the total required size for the entire packet. - this->size = 20; // Header. + // First calculate the total required size for the entire packet. + this->size = 20; // Header. - if (!this->username.empty()) - { - usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast(this->username.length())); - this->size += 4 + usernamePaddedLen; - } + if (!this->username.empty()) + { + usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast(this->username.length())); + this->size += 4 + usernamePaddedLen; + } - if (this->priority != 0u) - this->size += 4 + 4; + if (this->priority != 0u) + this->size += 4 + 4; - if (this->iceControlling != 0u) - this->size += 4 + 8; + if (this->iceControlling != 0u) + this->size += 4 + 8; - if (this->iceControlled != 0u) - this->size += 4 + 8; + if (this->iceControlled != 0u) + this->size += 4 + 8; - if (this->hasUseCandidate) - this->size += 4; + if (this->hasUseCandidate) + this->size += 4; - if (addXorMappedAddress) - { - switch (this->xorMappedAddress->sa_family) - { - case AF_INET: - { - xorMappedAddressPaddedLen = 8; - this->size += 4 + 8; + if (addXorMappedAddress) + { + switch (this->xorMappedAddress->sa_family) + { + case AF_INET: + { + xorMappedAddressPaddedLen = 8; + this->size += 4 + 8; - break; - } + break; + } - case AF_INET6: - { - xorMappedAddressPaddedLen = 20; - this->size += 4 + 20; + case AF_INET6: + { + xorMappedAddressPaddedLen = 20; + this->size += 4 + 20; - break; - } + break; + } - default: - { - MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); + default: + { + MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); - addXorMappedAddress = false; - } - } - } + addXorMappedAddress = false; + } + } + } - if (addErrorCode) - this->size += 4 + 4; + if (addErrorCode) + this->size += 4 + 4; - if (addMessageIntegrity) - this->size += 4 + 20; + if (addMessageIntegrity) + this->size += 4 + 20; - if (addFingerprint) - this->size += 4 + 4; + if (addFingerprint) + this->size += 4 + 4; - // Merge class and method fields into type. - uint16_t typeField = (static_cast(this->method) & 0x0f80) << 2; + // Merge class and method fields into type. + uint16_t typeField = (static_cast(this->method) & 0x0f80) << 2; - typeField |= (static_cast(this->method) & 0x0070) << 1; - typeField |= (static_cast(this->method) & 0x000f); - typeField |= (static_cast(this->klass) & 0x02) << 7; - typeField |= (static_cast(this->klass) & 0x01) << 4; + typeField |= (static_cast(this->method) & 0x0070) << 1; + typeField |= (static_cast(this->method) & 0x000f); + typeField |= (static_cast(this->klass) & 0x02) << 7; + typeField |= (static_cast(this->klass) & 0x01) << 4; - // Set type field. - Utils::Byte::Set2Bytes(buffer, 0, typeField); - // Set length field. - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size) - 20); - // Set magic cookie. - std::memcpy(buffer + 4, StunPacket::magicCookie, 4); - // Set TransactionId field. - std::memcpy(buffer + 8, this->transactionId, 12); - // Update the transaction ID pointer. - this->transactionId = buffer + 8; - // Add atributes. - size_t pos{ 20 }; + // Set type field. + Utils::Byte::Set2Bytes(buffer, 0, typeField); + // Set length field. + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size) - 20); + // Set magic cookie. + std::memcpy(buffer + 4, StunPacket::magicCookie, 4); + // Set TransactionId field. + std::memcpy(buffer + 8, this->transactionId, 12); + // Update the transaction ID pointer. + this->transactionId = buffer + 8; + // Add atributes. + size_t pos{ 20 }; - // Add USERNAME. - if (usernamePaddedLen != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USERNAME)); - Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast(this->username.length())); - std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); - pos += 4 + usernamePaddedLen; - } + // Add USERNAME. + if (usernamePaddedLen != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USERNAME)); + Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast(this->username.length())); + std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); + pos += 4 + usernamePaddedLen; + } - // Add PRIORITY. - if (this->priority != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::PRIORITY)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); - pos += 4 + 4; - } + // Add PRIORITY. + if (this->priority != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::PRIORITY)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); + pos += 4 + 4; + } - // Add ICE-CONTROLLING. - if (this->iceControlling != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLING)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 8); - Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); - pos += 4 + 8; - } + // Add ICE-CONTROLLING. + if (this->iceControlling != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLING)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 8); + Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); + pos += 4 + 8; + } - // Add ICE-CONTROLLED. - if (this->iceControlled != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLED)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 8); - Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled); - pos += 4 + 8; - } + // Add ICE-CONTROLLED. + if (this->iceControlled != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLED)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 8); + Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled); + pos += 4 + 8; + } - // Add USE-CANDIDATE. - if (this->hasUseCandidate) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USE_CANDIDATE)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 0); - pos += 4; - } + // Add USE-CANDIDATE. + if (this->hasUseCandidate) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USE_CANDIDATE)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 0); + pos += 4; + } - // Add XOR-MAPPED-ADDRESS - if (addXorMappedAddress) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::XOR_MAPPED_ADDRESS)); - Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen); + // Add XOR-MAPPED-ADDRESS + if (addXorMappedAddress) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::XOR_MAPPED_ADDRESS)); + Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen); - uint8_t* attrValue = buffer + pos + 4; + uint8_t* attrValue = buffer + pos + 4; - switch (this->xorMappedAddress->sa_family) - { - case AF_INET: - { - // Set first byte to 0. - attrValue[0] = 0; - // Set inet family. - attrValue[1] = 0x01; - // Set port and XOR it. - std::memcpy( - attrValue + 2, - &(reinterpret_cast(this->xorMappedAddress))->sin_port, - 2); - attrValue[2] ^= StunPacket::magicCookie[0]; - attrValue[3] ^= StunPacket::magicCookie[1]; - // Set address and XOR it. - std::memcpy( - attrValue + 4, - &(reinterpret_cast(this->xorMappedAddress))->sin_addr.s_addr, - 4); - attrValue[4] ^= StunPacket::magicCookie[0]; - attrValue[5] ^= StunPacket::magicCookie[1]; - attrValue[6] ^= StunPacket::magicCookie[2]; - attrValue[7] ^= StunPacket::magicCookie[3]; + switch (this->xorMappedAddress->sa_family) + { + case AF_INET: + { + // Set first byte to 0. + attrValue[0] = 0; + // Set inet family. + attrValue[1] = 0x01; + // Set port and XOR it. + std::memcpy( + attrValue + 2, + &(reinterpret_cast(this->xorMappedAddress))->sin_port, + 2); + attrValue[2] ^= StunPacket::magicCookie[0]; + attrValue[3] ^= StunPacket::magicCookie[1]; + // Set address and XOR it. + std::memcpy( + attrValue + 4, + &(reinterpret_cast(this->xorMappedAddress))->sin_addr.s_addr, + 4); + attrValue[4] ^= StunPacket::magicCookie[0]; + attrValue[5] ^= StunPacket::magicCookie[1]; + attrValue[6] ^= StunPacket::magicCookie[2]; + attrValue[7] ^= StunPacket::magicCookie[3]; - pos += 4 + 8; + pos += 4 + 8; - break; - } + break; + } - case AF_INET6: - { - // Set first byte to 0. - attrValue[0] = 0; - // Set inet family. - attrValue[1] = 0x02; - // Set port and XOR it. - std::memcpy( - attrValue + 2, - &(reinterpret_cast(this->xorMappedAddress))->sin6_port, - 2); - attrValue[2] ^= StunPacket::magicCookie[0]; - attrValue[3] ^= StunPacket::magicCookie[1]; - // Set address and XOR it. - std::memcpy( - attrValue + 4, - &(reinterpret_cast(this->xorMappedAddress))->sin6_addr.s6_addr, - 16); - attrValue[4] ^= StunPacket::magicCookie[0]; - attrValue[5] ^= StunPacket::magicCookie[1]; - attrValue[6] ^= StunPacket::magicCookie[2]; - attrValue[7] ^= StunPacket::magicCookie[3]; - attrValue[8] ^= this->transactionId[0]; - attrValue[9] ^= this->transactionId[1]; - attrValue[10] ^= this->transactionId[2]; - attrValue[11] ^= this->transactionId[3]; - attrValue[12] ^= this->transactionId[4]; - attrValue[13] ^= this->transactionId[5]; - attrValue[14] ^= this->transactionId[6]; - attrValue[15] ^= this->transactionId[7]; - attrValue[16] ^= this->transactionId[8]; - attrValue[17] ^= this->transactionId[9]; - attrValue[18] ^= this->transactionId[10]; - attrValue[19] ^= this->transactionId[11]; + case AF_INET6: + { + // Set first byte to 0. + attrValue[0] = 0; + // Set inet family. + attrValue[1] = 0x02; + // Set port and XOR it. + std::memcpy( + attrValue + 2, + &(reinterpret_cast(this->xorMappedAddress))->sin6_port, + 2); + attrValue[2] ^= StunPacket::magicCookie[0]; + attrValue[3] ^= StunPacket::magicCookie[1]; + // Set address and XOR it. + std::memcpy( + attrValue + 4, + &(reinterpret_cast(this->xorMappedAddress))->sin6_addr.s6_addr, + 16); + attrValue[4] ^= StunPacket::magicCookie[0]; + attrValue[5] ^= StunPacket::magicCookie[1]; + attrValue[6] ^= StunPacket::magicCookie[2]; + attrValue[7] ^= StunPacket::magicCookie[3]; + attrValue[8] ^= this->transactionId[0]; + attrValue[9] ^= this->transactionId[1]; + attrValue[10] ^= this->transactionId[2]; + attrValue[11] ^= this->transactionId[3]; + attrValue[12] ^= this->transactionId[4]; + attrValue[13] ^= this->transactionId[5]; + attrValue[14] ^= this->transactionId[6]; + attrValue[15] ^= this->transactionId[7]; + attrValue[16] ^= this->transactionId[8]; + attrValue[17] ^= this->transactionId[9]; + attrValue[18] ^= this->transactionId[10]; + attrValue[19] ^= this->transactionId[11]; - pos += 4 + 20; + pos += 4 + 20; - break; - } - } - } + break; + } + } + } - // Add ERROR-CODE. - if (addErrorCode) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ERROR_CODE)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + // Add ERROR-CODE. + if (addErrorCode) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ERROR_CODE)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - auto codeClass = static_cast(this->errorCode / 100); - uint8_t codeNumber = static_cast(this->errorCode) - (codeClass * 100); + auto codeClass = static_cast(this->errorCode / 100); + uint8_t codeNumber = static_cast(this->errorCode) - (codeClass * 100); - Utils::Byte::Set2Bytes(buffer, pos + 4, 0); - Utils::Byte::Set1Byte(buffer, pos + 6, codeClass); - Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber); - pos += 4 + 4; - } + Utils::Byte::Set2Bytes(buffer, pos + 4, 0); + Utils::Byte::Set1Byte(buffer, pos + 6, codeClass); + Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber); + pos += 4 + 4; + } - // Add MESSAGE-INTEGRITY. - if (addMessageIntegrity) - { - // Ignore FINGERPRINT. - if (addFingerprint) - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); + // Add MESSAGE-INTEGRITY. + if (addMessageIntegrity) + { + // Ignore FINGERPRINT. + if (addFingerprint) + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); - // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. + // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos); Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::MESSAGE_INTEGRITY)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 20); - std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); + Utils::Byte::Set2Bytes(buffer, pos + 2, 20); + std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); - // Update the pointer. - this->messageIntegrity = buffer + pos + 4; - pos += 4 + 20; + // Update the pointer. + this->messageIntegrity = buffer + pos + 4; + pos += 4 + 20; - // Restore length field. - if (addFingerprint) - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20)); - } - else - { - // Unset the pointer (if it was set). - this->messageIntegrity = nullptr; - } + // Restore length field. + if (addFingerprint) + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20)); + } + else + { + // Unset the pointer (if it was set). + this->messageIntegrity = nullptr; + } - // Add FINGERPRINT. - if (addFingerprint) - { - // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT - // attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; + // Add FINGERPRINT. + if (addFingerprint) + { + // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT + // attribute and XOR it with 0x5354554e. + uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); - pos += 4 + 4; + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); + pos += 4 + 4; - // Set flag. - this->hasFingerprint = true; - } - else - { - this->hasFingerprint = false; - } + // Set flag. + this->hasFingerprint = true; + } + else + { + this->hasFingerprint = false; + } - MS_ASSERT(pos == this->size, "pos != this->size"); - } + MS_ASSERT(pos == this->size, "pos != this->size"); + } } // namespace RTC diff --git a/webrtc/StunPacket.hpp b/webrtc/StunPacket.hpp index a6b2c940..2776a9b6 100644 --- a/webrtc/StunPacket.hpp +++ b/webrtc/StunPacket.hpp @@ -26,188 +26,188 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. namespace RTC { - class StunPacket - { - public: - // STUN message class. - enum class Class : uint16_t - { - REQUEST = 0, - INDICATION = 1, - SUCCESS_RESPONSE = 2, - ERROR_RESPONSE = 3 - }; + class StunPacket + { + public: + // STUN message class. + enum class Class : uint16_t + { + REQUEST = 0, + INDICATION = 1, + SUCCESS_RESPONSE = 2, + ERROR_RESPONSE = 3 + }; - // STUN message method. - enum class Method : uint16_t - { - BINDING = 1 - }; + // STUN message method. + enum class Method : uint16_t + { + BINDING = 1 + }; - // Attribute type. - enum class Attribute : uint16_t - { - MAPPED_ADDRESS = 0x0001, - USERNAME = 0x0006, - MESSAGE_INTEGRITY = 0x0008, - ERROR_CODE = 0x0009, - UNKNOWN_ATTRIBUTES = 0x000A, - REALM = 0x0014, - NONCE = 0x0015, - XOR_MAPPED_ADDRESS = 0x0020, - PRIORITY = 0x0024, - USE_CANDIDATE = 0x0025, - SOFTWARE = 0x8022, - ALTERNATE_SERVER = 0x8023, - FINGERPRINT = 0x8028, - ICE_CONTROLLED = 0x8029, - ICE_CONTROLLING = 0x802A - }; + // Attribute type. + enum class Attribute : uint16_t + { + MAPPED_ADDRESS = 0x0001, + USERNAME = 0x0006, + MESSAGE_INTEGRITY = 0x0008, + ERROR_CODE = 0x0009, + UNKNOWN_ATTRIBUTES = 0x000A, + REALM = 0x0014, + NONCE = 0x0015, + XOR_MAPPED_ADDRESS = 0x0020, + PRIORITY = 0x0024, + USE_CANDIDATE = 0x0025, + SOFTWARE = 0x8022, + ALTERNATE_SERVER = 0x8023, + FINGERPRINT = 0x8028, + ICE_CONTROLLED = 0x8029, + ICE_CONTROLLING = 0x802A + }; - // Authentication result. - enum class Authentication - { - OK = 0, - UNAUTHORIZED = 1, - BAD_REQUEST = 2 - }; + // Authentication result. + enum class Authentication + { + OK = 0, + UNAUTHORIZED = 1, + BAD_REQUEST = 2 + }; - public: - static bool IsStun(const uint8_t* data, size_t len) - { - // clang-format off - return ( - // STUN headers are 20 bytes. - (len >= 20) && - // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes - (data[0] < 3) && - // Magic cookie must match. - (data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) && - (data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3]) - ); - // clang-format on - } - static StunPacket* Parse(const uint8_t* data, size_t len); + public: + static bool IsStun(const uint8_t* data, size_t len) + { + // clang-format off + return ( + // STUN headers are 20 bytes. + (len >= 20) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] < 3) && + // Magic cookie must match. + (data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) && + (data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3]) + ); + // clang-format on + } + static StunPacket* Parse(const uint8_t* data, size_t len); - private: - static const uint8_t magicCookie[]; + private: + static const uint8_t magicCookie[]; - public: - StunPacket( - Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size); - ~StunPacket(); + public: + StunPacket( + Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size); + ~StunPacket(); - void Dump() const; - Class GetClass() const - { - return this->klass; - } - Method GetMethod() const - { - return this->method; - } - const uint8_t* GetData() const - { - return this->data; - } - size_t GetSize() const - { - return this->size; - } - void SetUsername(const char* username, size_t len) - { - this->username.assign(username, len); - } - void SetPriority(uint32_t priority) - { - this->priority = priority; - } - void SetIceControlling(uint64_t iceControlling) - { - this->iceControlling = iceControlling; - } - void SetIceControlled(uint64_t iceControlled) - { - this->iceControlled = iceControlled; - } - void SetUseCandidate() - { - this->hasUseCandidate = true; - } - void SetXorMappedAddress(const struct sockaddr* xorMappedAddress) - { - this->xorMappedAddress = xorMappedAddress; - } - void SetErrorCode(uint16_t errorCode) - { - this->errorCode = errorCode; - } - void SetMessageIntegrity(const uint8_t* messageIntegrity) - { - this->messageIntegrity = messageIntegrity; - } - void SetFingerprint() - { - this->hasFingerprint = true; - } - const std::string& GetUsername() const - { - return this->username; - } - uint32_t GetPriority() const - { - return this->priority; - } - uint64_t GetIceControlling() const - { - return this->iceControlling; - } - uint64_t GetIceControlled() const - { - return this->iceControlled; - } - bool HasUseCandidate() const - { - return this->hasUseCandidate; - } - uint16_t GetErrorCode() const - { - return this->errorCode; - } - bool HasMessageIntegrity() const - { - return (this->messageIntegrity ? true : false); - } - bool HasFingerprint() const - { - return this->hasFingerprint; - } - Authentication CheckAuthentication( - const std::string& localUsername, const std::string& localPassword); - StunPacket* CreateSuccessResponse(); - StunPacket* CreateErrorResponse(uint16_t errorCode); - void Authenticate(const std::string& password); - void Serialize(uint8_t* buffer); + void Dump() const; + Class GetClass() const + { + return this->klass; + } + Method GetMethod() const + { + return this->method; + } + const uint8_t* GetData() const + { + return this->data; + } + size_t GetSize() const + { + return this->size; + } + void SetUsername(const char* username, size_t len) + { + this->username.assign(username, len); + } + void SetPriority(uint32_t priority) + { + this->priority = priority; + } + void SetIceControlling(uint64_t iceControlling) + { + this->iceControlling = iceControlling; + } + void SetIceControlled(uint64_t iceControlled) + { + this->iceControlled = iceControlled; + } + void SetUseCandidate() + { + this->hasUseCandidate = true; + } + void SetXorMappedAddress(const struct sockaddr* xorMappedAddress) + { + this->xorMappedAddress = xorMappedAddress; + } + void SetErrorCode(uint16_t errorCode) + { + this->errorCode = errorCode; + } + void SetMessageIntegrity(const uint8_t* messageIntegrity) + { + this->messageIntegrity = messageIntegrity; + } + void SetFingerprint() + { + this->hasFingerprint = true; + } + const std::string& GetUsername() const + { + return this->username; + } + uint32_t GetPriority() const + { + return this->priority; + } + uint64_t GetIceControlling() const + { + return this->iceControlling; + } + uint64_t GetIceControlled() const + { + return this->iceControlled; + } + bool HasUseCandidate() const + { + return this->hasUseCandidate; + } + uint16_t GetErrorCode() const + { + return this->errorCode; + } + bool HasMessageIntegrity() const + { + return (this->messageIntegrity ? true : false); + } + bool HasFingerprint() const + { + return this->hasFingerprint; + } + Authentication CheckAuthentication( + const std::string& localUsername, const std::string& localPassword); + StunPacket* CreateSuccessResponse(); + StunPacket* CreateErrorResponse(uint16_t errorCode); + void Authenticate(const std::string& password); + void Serialize(uint8_t* buffer); - private: - // Passed by argument. - Class klass; // 2 bytes. - Method method; // 2 bytes. - const uint8_t* transactionId{ nullptr }; // 12 bytes. - uint8_t* data{ nullptr }; // Pointer to binary data. - size_t size{ 0u }; // The full message size (including header). - // STUN attributes. - std::string username; // Less than 513 bytes. - uint32_t priority{ 0u }; // 4 bytes unsigned integer. - uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer. - uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer. - bool hasUseCandidate{ false }; // 0 bytes. - const uint8_t* messageIntegrity{ nullptr }; // 20 bytes. - bool hasFingerprint{ false }; // 4 bytes. - const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes. - uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase). - std::string password; - }; + private: + // Passed by argument. + Class klass; // 2 bytes. + Method method; // 2 bytes. + const uint8_t* transactionId{ nullptr }; // 12 bytes. + uint8_t* data{ nullptr }; // Pointer to binary data. + size_t size{ 0u }; // The full message size (including header). + // STUN attributes. + std::string username; // Less than 513 bytes. + uint32_t priority{ 0u }; // 4 bytes unsigned integer. + uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer. + uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer. + bool hasUseCandidate{ false }; // 0 bytes. + const uint8_t* messageIntegrity{ nullptr }; // 20 bytes. + bool hasFingerprint{ false }; // 4 bytes. + const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes. + uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase). + std::string password; + }; } // namespace RTC #endif