Older/ToolKit/Util/SSLUtil.cpp
amass 9de3af15eb
All checks were successful
Deploy / PullDocker (push) Successful in 12s
Deploy / Build (push) Successful in 1m51s
add ZLMediaKit code for learning.
2024-09-28 23:55:00 +08:00

387 lines
11 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SSLUtil.h"
#include "onceToken.h"
#include "logger.h"
#if defined(ENABLE_OPENSSL)
#include <openssl/bio.h>
#include <openssl/ossl_typ.h>
#include <openssl/pkcs12.h>
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/conf.h>
#endif //defined(ENABLE_OPENSSL)
using namespace std;
namespace toolkit {
/**
 * Pop the earliest entry from OpenSSL's error queue (ERR_get_error) and
 * format it as readable text.
 * @return the formatted OpenSSL error string, or "No error" when the queue is
 *         empty or the library was built without ENABLE_OPENSSL.
 */
std::string SSLUtil::getLastError() {
#if defined(ENABLE_OPENSSL)
    const unsigned long code = ERR_get_error();
    if (code == 0) {
        return "No error";
    }
    char text[256] = {0};
    // ERR_error_string_n always NUL-terminates within the given length.
    ERR_error_string_n(code, text, sizeof(text));
    return std::string(text);
#else
    return "No error";
#endif //defined(ENABLE_OPENSSL)
}
#if defined(ENABLE_OPENSSL)
/**
 * Read one X509 certificate from `bio`, optionally auto-detecting its encoding.
 *
 * @param bio    source to read from. When `type` is 0 it is rewound with BIO_reset()
 *               before every format probe; for a non-zero `type` reading continues
 *               from the current position, which lets a caller walk a certificate chain.
 * @param passwd password handed to PKCS12_parse() for PKCS#12 input.
 * @param x509   out: the certificate on success (the caller owns it and must release it
 *               with X509_free), nullptr otherwise.
 * @param type   0 = try PEM, then DER, then PKCS#12; 1 = PEM only; 2 = DER only;
 *               3 = PKCS#12 only.
 * @return the format that produced a certificate (1 = PEM, 2 = DER, 3 = PKCS#12),
 *         or 0 when nothing could be read.
 */
static int getCerType(BIO *bio, const char *passwd, X509 **x509, int type) {
    // Never leave a stale pointer (e.g. a previously returned certificate) in the out-param.
    *x509 = nullptr;

    // Try PEM
    if (type == 1 || type == 0) {
        if (type == 0) {
            BIO_reset(bio);
        }
        *x509 = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
        if (*x509) {
            return 1;
        }
    }

    // Try DER
    if (type == 2 || type == 0) {
        if (type == 0) {
            BIO_reset(bio);
        }
        *x509 = d2i_X509_bio(bio, nullptr);
        if (*x509) {
            return 2;
        }
    }

    // Try PKCS#12
    if (type == 3 || type == 0) {
        if (type == 0) {
            BIO_reset(bio);
        }
        PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
        if (p12) {
            EVP_PKEY *pkey = nullptr;
            X509 *cert = nullptr;
            const int parsed = PKCS12_parse(p12, passwd, &pkey, &cert, nullptr);
            PKCS12_free(p12);
            // Only touch pkey/cert when parsing succeeded: on failure some OpenSSL releases
            // (1.0.x) free them without resetting the pointers, so freeing or returning them
            // here would be a double free / use-after-free.
            if (parsed == 1) {
                // Only the certificate is wanted; release the private key bundled with it.
                if (pkey) {
                    EVP_PKEY_free(pkey);
                }
                if (cert) {
                    *x509 = cert;
                    return 3;
                }
            }
        }
    }
    return 0;
}
#endif //defined(ENABLE_OPENSSL)
/**
 * Load every X509 certificate contained in a file or an in-memory buffer.
 *
 * The encoding is auto-detected on the first read (PEM, then DER, then PKCS#12). Later
 * reads keep that encoding, so a concatenated PEM/DER chain yields all of its
 * certificates in order. For PKCS#12 input, only the certificate returned by
 * PKCS12_parse is collected; CA certificates are not requested.
 *
 * @param file_path_or_data path to open when isFile is true, otherwise the raw certificate bytes
 * @param passwd            password used for PKCS#12 input
 * @param isFile            selects how file_path_or_data is interpreted
 * @return the loaded certificates (each released with X509_free), or an empty vector
 */
vector<shared_ptr<X509> > SSLUtil::loadPublicKey(const string &file_path_or_data, const string &passwd, bool isFile) {
    vector<shared_ptr<X509> > ret;
#if defined(ENABLE_OPENSSL)
    BIO *bio = isFile ? BIO_new_file((char *) file_path_or_data.data(), "r") :
               BIO_new_mem_buf((char *) file_path_or_data.data(), file_path_or_data.size());
    if (!bio) {
        WarnL << (isFile ? "BIO_new_file" : "BIO_new_mem_buf") << " failed: " << getLastError();
        return ret;
    }
    onceToken token0(nullptr, [&]() {
        BIO_free(bio);
    });

    int cer_type = 0;
    X509 *x509 = nullptr;
    do {
        // The first call auto-detects (type 0); later calls stick to the detected format.
        cer_type = getCerType(bio, passwd.data(), &x509, cer_type);
        if (cer_type) {
            ret.push_back(shared_ptr<X509>(x509, [](X509 *ptr) { X509_free(ptr); }));
        }
    } while (cer_type != 0);

    if (!ret.empty()) {
        // The loop always ends on a failed read (end of input), and auto-detection may have
        // failed on other formats before succeeding. Those errors are expected; clear them so
        // they are not reported later by getLastError() or misinterpreted by SSL_get_error().
        // When nothing was loaded, the errors are kept for the caller to inspect.
        ERR_clear_error();
    }
    return ret;
#else
    return ret;
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Load a private key from a file or an in-memory buffer.
 *
 * PEM is tried first (using `passwd` as the pass phrase). If that fails, the data is
 * re-read as PKCS#12 and the key bundled in it is returned.
 *
 * @param file_path_or_data path to open when isFile is true, otherwise the raw key bytes
 * @param passwd            pass phrase for encrypted PEM keys / PKCS#12 files
 * @param isFile            selects how file_path_or_data is interpreted
 * @return the key (released with EVP_PKEY_free), or nullptr on failure. On failure the
 *         OpenSSL error queue is left untouched for the caller.
 */
shared_ptr<EVP_PKEY> SSLUtil::loadPrivateKey(const string &file_path_or_data, const string &passwd, bool isFile) {
#if defined(ENABLE_OPENSSL)
    BIO *bio = isFile ?
               BIO_new_file((char *) file_path_or_data.data(), "r") :
               BIO_new_mem_buf((char *) file_path_or_data.data(), file_path_or_data.size());
    if (!bio) {
        WarnL << (isFile ? "BIO_new_file" : "BIO_new_mem_buf") << " failed: " << getLastError();
        return nullptr;
    }
    // PEM pass-phrase callback: copies the password (truncated to the buffer size) and
    // returns the number of bytes written.
    pem_password_cb *cb = [](char *buf, int size, int rwflag, void *userdata) -> int {
        const string *passwd = (const string *) userdata;
        size = size < (int) passwd->size() ? size : (int) passwd->size();
        memcpy(buf, passwd->data(), size);
        return size;
    };
    onceToken token0(nullptr, [&]() {
        BIO_free(bio);
    });

    // Try PEM
    EVP_PKEY *evp_key = PEM_read_bio_PrivateKey(bio, nullptr, cb, (void *) &passwd);
    if (!evp_key) {
        // Try PKCS#12
        BIO_reset(bio);
        PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
        if (!p12) {
            return nullptr;
        }
        X509 *x509 = nullptr;
        const int parsed = PKCS12_parse(p12, passwd.data(), &evp_key, &x509, nullptr);
        PKCS12_free(p12);
        if (parsed != 1) {
            // On failure some OpenSSL releases (1.0.x) free *pkey / *cert without resetting
            // the pointers, so evp_key and x509 must not be used or freed here.
            return nullptr;
        }
        // Only the key is wanted; release the certificate bundled with it.
        if (x509) {
            X509_free(x509);
        }
        if (!evp_key) {
            return nullptr;
        }
        // The PEM attempt failing was expected for PKCS#12 input; drop its queued errors so
        // they do not surface later as unrelated failures.
        ERR_clear_error();
    }
    return shared_ptr<EVP_PKEY>(evp_key, [](EVP_PKEY *ptr) {
        EVP_PKEY_free(ptr);
    });
#else
    return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Create an SSL_CTX and install the given certificate chain and private key.
 *
 * @param cers       certificate chain; the first entry is the leaf certificate, the rest
 *                   are added as extra chain certificates (may be empty)
 * @param key        private key matching cers.front(), or nullptr
 * @param serverMode selects the server or client method for the context
 * @param checkKey   when true, require a matching key/certificate pair even if `key` is
 *                   nullptr (with no key installed, that check rejects the context)
 * @return the configured context (released with SSL_CTX_free), or nullptr on any failure
 */
shared_ptr<SSL_CTX> SSLUtil::makeSSLContext(const vector<shared_ptr<X509> > &cers, const shared_ptr<EVP_PKEY> &key, bool serverMode, bool checkKey) {
#if defined(ENABLE_OPENSSL)
    SSL_CTX *raw_ctx = SSL_CTX_new(serverMode ? SSLv23_server_method() : SSLv23_client_method());
    if (!raw_ctx) {
        WarnL << "SSL_CTX_new " << (serverMode ? "SSLv23_server_method" : "SSLv23_client_method") << " failed: " << getLastError();
        return nullptr;
    }
    // Take ownership immediately so every early return below releases the context.
    shared_ptr<SSL_CTX> ctx(raw_ctx, [](SSL_CTX *ptr) { SSL_CTX_free(ptr); });

    if (!cers.empty()) {
        // Leaf certificate: SSL_CTX_use_certificate takes its own reference (X509_up_ref),
        // so no X509_dup is needed here.
        if (SSL_CTX_use_certificate(ctx.get(), cers.front().get()) != 1) {
            WarnL << "SSL_CTX_use_certificate failed: " << getLastError();
            return nullptr;
        }
        // Remaining certificates form the chain. SSL_CTX_add_extra_chain_cert takes ownership
        // of the pointer it is given, so hand it a private copy; the caller's object stays
        // owned by its shared_ptr.
        for (auto it = cers.begin() + 1; it != cers.end(); ++it) {
            X509 *copy = X509_dup(it->get());
            if (!copy || SSL_CTX_add_extra_chain_cert(ctx.get(), copy) != 1) {
                WarnL << "SSL_CTX_add_extra_chain_cert failed: " << getLastError();
                // Ownership only transfers on success, so release the copy ourselves.
                if (copy) {
                    X509_free(copy);
                }
                return nullptr;
            }
        }
    }

    if (key) {
        // A private key was supplied.
        if (SSL_CTX_use_PrivateKey(ctx.get(), key.get()) != 1) {
            WarnL << "SSL_CTX_use_PrivateKey failed: " << getLastError();
            return nullptr;
        }
    }
    if (key || checkKey) {
        // Verify that the installed private key matches the installed certificate.
        if (SSL_CTX_check_private_key(ctx.get()) != 1) {
            WarnL << "SSL_CTX_check_private_key failed: " << getLastError();
            return nullptr;
        }
    }
    // Either the key and certificate match, or no key check was requested.
    return ctx;
#else
    return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Create a new SSL session object from `ctx`.
 * @param ctx context the session is created from (passed straight to SSL_new)
 * @return the session (released with SSL_free), or nullptr when SSL_new fails or the
 *         library was built without ENABLE_OPENSSL
 */
shared_ptr<SSL> SSLUtil::makeSSL(SSL_CTX *ctx) {
#if defined(ENABLE_OPENSSL)
    SSL *session = SSL_new(ctx);
    // Only wrap a successfully created session; an empty shared_ptr signals failure.
    return session ? shared_ptr<SSL>(session, [](SSL *ptr) { SSL_free(ptr); })
                   : shared_ptr<SSL>();
#else
    return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Point `ctx` at OpenSSL's default CA locations (SSL_CTX_set_default_verify_paths).
 * @param ctx context to configure; nullptr is rejected
 * @return true on success; false on failure, for a null ctx, or without ENABLE_OPENSSL
 */
bool SSLUtil::loadDefaultCAs(SSL_CTX *ctx) {
#if defined(ENABLE_OPENSSL)
    if (ctx == nullptr) {
        return false;
    }
    const bool loaded = SSL_CTX_set_default_verify_paths(ctx) == 1;
    if (!loaded) {
        WarnL << "SSL_CTX_set_default_verify_paths failed: " << getLastError();
    }
    return loaded;
#else
    return false;
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Add `cer` to the trusted certificate store of `ctx`.
 * @param ctx context whose certificate store is extended; nullptr is rejected
 * @param cer certificate to trust; nullptr is rejected
 * @return true when the certificate is in the store afterwards (including when it was
 *         already present); false otherwise or without ENABLE_OPENSSL
 */
bool SSLUtil::trustCertificate(SSL_CTX *ctx, X509 *cer) {
#if defined(ENABLE_OPENSSL)
    // SSL_CTX_get_cert_store dereferences ctx, so reject nullptr before calling it.
    if (!ctx || !cer) {
        return false;
    }
    X509_STORE *store = SSL_CTX_get_cert_store(ctx);
    if (!store) {
        return false;
    }
    if (X509_STORE_add_cert(store, cer) != 1) {
        const unsigned long err = ERR_peek_last_error();
        if (ERR_GET_LIB(err) == ERR_LIB_X509 && ERR_GET_REASON(err) == X509_R_CERT_ALREADY_IN_HASH_TABLE) {
            // Older OpenSSL releases report re-adding an already trusted certificate as an
            // error. The certificate is trusted either way, so treat it as success.
            ERR_clear_error();
            return true;
        }
        WarnL << "X509_STORE_add_cert failed: " << getLastError();
        return false;
    }
    return true;
#else
    return false;
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Verify `cer` against an explicit set of trusted CA certificates.
 * @param cer certificate to verify
 * @param ... trusted CA certificates as X509*, terminated by nullptr
 * @return true when X509_verify_cert succeeds; false on verification failure, on any
 *         allocation/initialisation failure, or without ENABLE_OPENSSL
 */
bool SSLUtil::verifyX509(X509 *cer, ...) {
#if defined(ENABLE_OPENSSL)
    X509_STORE *store = X509_STORE_new();
    if (!store) {
        WarnL << "X509_STORE_new failed: " << getLastError();
        return false;
    }

    // Collect the nullptr-terminated list of trusted CA certificates.
    va_list args;
    va_start(args, cer);
    for (X509 *ca = va_arg(args, X509 *); ca != nullptr; ca = va_arg(args, X509 *)) {
        X509_STORE_add_cert(store, ca);
    }
    va_end(args);

    X509_STORE_CTX *store_ctx = X509_STORE_CTX_new();
    if (!store_ctx) {
        WarnL << "X509_STORE_CTX_new failed: " << getLastError();
        X509_STORE_free(store);
        return false;
    }

    bool verified = false;
    if (X509_STORE_CTX_init(store_ctx, store, cer, nullptr) != 1) {
        WarnL << "X509_STORE_CTX_init failed: " << getLastError();
    } else if (X509_verify_cert(store_ctx) == 1) {
        verified = true;
    } else {
        int depth = X509_STORE_CTX_get_error_depth(store_ctx);
        int err = X509_STORE_CTX_get_error(store_ctx);
        WarnL << "X509_verify_cert failed, depth: " << depth << ", err: " << X509_verify_cert_error_string(err);
    }
    X509_STORE_CTX_free(store_ctx);
    X509_STORE_free(store);
    return verified;
#else
    WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
    return false;
#endif //defined(ENABLE_OPENSSL)
}
#ifdef ENABLE_OPENSSL
#ifndef X509_F_X509_PUBKEY_GET0
// Fallback for OpenSSL releases lacking X509_get0_pubkey (detected by the absence of its
// error-function code). X509_get_pubkey returns a counted reference; releasing it at once
// yields a borrowed pointer. This assumes the certificate keeps its own reference to the
// decoded key, so the pointer stays usable while `x` is alive.
EVP_PKEY *X509_get0_pubkey(X509 *x) {
    EVP_PKEY *borrowed = X509_get_pubkey(x);
    if (borrowed != nullptr) {
        EVP_PKEY_free(borrowed);
    }
    return borrowed;
}
#endif //X509_F_X509_PUBKEY_GET0
#ifndef EVP_F_EVP_PKEY_GET0_RSA
// Fallback for EVP_PKEY_get0_RSA built the same way: take a counted reference with
// EVP_PKEY_get1_RSA, drop it again, and return the pointer still referenced by `pkey`.
RSA *EVP_PKEY_get0_RSA(EVP_PKEY *pkey) {
    RSA *borrowed = EVP_PKEY_get1_RSA(pkey);
    if (borrowed != nullptr) {
        RSA_free(borrowed);
    }
    return borrowed;
}
#endif //EVP_F_EVP_PKEY_GET0_RSA
#endif //ENABLE_OPENSSL
/**
 * RSA-encrypt (enc_or_dec == true) or RSA-decrypt (false) `in_str` with the public key of
 * `cer`, using RSA_PKCS1_PADDING.
 * @param cer        certificate providing the RSA public key; nullptr is rejected
 * @param in_str     input bytes; must not exceed RSA_size() of the key
 * @param enc_or_dec true selects RSA_public_encrypt, false selects RSA_public_decrypt
 * @return the output bytes, or "" on any failure (or without ENABLE_OPENSSL)
 */
string SSLUtil::cryptWithRsaPublicKey(X509 *cer, const string &in_str, bool enc_or_dec) {
#if defined(ENABLE_OPENSSL)
    if (!cer) {
        return "";
    }
    // Borrowed pointer: the key stays owned by `cer` and must not be freed here.
    EVP_PKEY *public_key = X509_get0_pubkey(cer);
    if (!public_key) {
        return "";
    }
    // EVP_PKEY_get1_RSA returns a counted reference; the unique_ptr releases it on every path.
    unique_ptr<RSA, decltype(&RSA_free)> rsa(EVP_PKEY_get1_RSA(public_key), &RSA_free);
    if (!rsa) {
        return "";
    }
    const char *op = enc_or_dec ? "RSA_public_encrypt" : "RSA_public_decrypt";
    string out_str(RSA_size(rsa.get()), '\0');
    // RSA never accepts more than RSA_size() input bytes. Rejecting larger input up front
    // also prevents the length from being silently truncated by the int conversion below.
    if (in_str.size() > out_str.size()) {
        WarnL << op << " failed: input too large";
        return "";
    }
    const int in_len = static_cast<int>(in_str.size());
    const auto *from = reinterpret_cast<const uint8_t *>(in_str.data());
    auto *to = reinterpret_cast<uint8_t *>(&out_str[0]);
    const int ret = enc_or_dec ? RSA_public_encrypt(in_len, from, to, rsa.get(), RSA_PKCS1_PADDING)
                               : RSA_public_decrypt(in_len, from, to, rsa.get(), RSA_PKCS1_PADDING);
    if (ret > 0) {
        out_str.resize(ret);
        return out_str;
    }
    WarnL << op << " failed: " << getLastError();
    return "";
#else
    WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
    return "";
#endif //defined(ENABLE_OPENSSL)
}
/**
 * RSA-encrypt (enc_or_dec == true) or RSA-decrypt (false) `in_str` with `private_key`,
 * using RSA_PKCS1_PADDING.
 * @param private_key RSA private key; nullptr is rejected
 * @param in_str      input bytes; must not exceed RSA_size() of the key
 * @param enc_or_dec  true selects RSA_private_encrypt, false selects RSA_private_decrypt
 * @return the output bytes, or "" on any failure (or without ENABLE_OPENSSL)
 */
string SSLUtil::cryptWithRsaPrivateKey(EVP_PKEY *private_key, const string &in_str, bool enc_or_dec) {
#if defined(ENABLE_OPENSSL)
    if (!private_key) {
        return "";
    }
    // EVP_PKEY_get1_RSA returns a counted reference; the unique_ptr releases it on every path.
    unique_ptr<RSA, decltype(&RSA_free)> rsa(EVP_PKEY_get1_RSA(private_key), &RSA_free);
    if (!rsa) {
        return "";
    }
    const char *op = enc_or_dec ? "RSA_private_encrypt" : "RSA_private_decrypt";
    string out_str(RSA_size(rsa.get()), '\0');
    // RSA never accepts more than RSA_size() input bytes. Rejecting larger input up front
    // also prevents the length from being silently truncated by the int conversion below.
    if (in_str.size() > out_str.size()) {
        WarnL << op << " failed: input too large";
        return "";
    }
    const int in_len = static_cast<int>(in_str.size());
    const auto *from = reinterpret_cast<const uint8_t *>(in_str.data());
    auto *to = reinterpret_cast<uint8_t *>(&out_str[0]);
    const int ret = enc_or_dec ? RSA_private_encrypt(in_len, from, to, rsa.get(), RSA_PKCS1_PADDING)
                               : RSA_private_decrypt(in_len, from, to, rsa.get(), RSA_PKCS1_PADDING);
    if (ret > 0) {
        out_str.resize(ret);
        return out_str;
    }
    WarnL << op << " failed: " << getLastError();
    return "";
#else
    WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
    return "";
#endif //defined(ENABLE_OPENSSL)
}
/**
 * Return the commonName (CN) of the certificate's subject, which is typically the host
 * name the certificate was issued for.
 * @param cer certificate to inspect; nullptr yields ""
 * @return the CN text (at most 255 bytes), or "" when absent, when cer is nullptr, or when
 *         OpenSSL / SNI support (SSL_CTRL_SET_TLSEXT_HOSTNAME) is unavailable
 */
string SSLUtil::getServerName(X509 *cer) {
#if defined(ENABLE_OPENSSL) && defined(SSL_CTRL_SET_TLSEXT_HOSTNAME)
    if (cer == nullptr) {
        return "";
    }
    // Zero-filled so a missing CN (the call writes nothing) still yields an empty string.
    char common_name[256] = {0};
    X509_NAME_get_text_by_NID(X509_get_subject_name(cer), NID_commonName, common_name, sizeof(common_name));
    return string(common_name);
#else
    return "";
#endif
}
}//namespace toolkit