From 4e0642c814281ffe53864f838cbb1cc99bce8188 Mon Sep 17 00:00:00 2001 From: Henrique Santiago Date: Fri, 3 Aug 2012 13:01:45 -0300 Subject: [PATCH] Removed GMP and added OpenSSL. Rsa class might be moved to Crypto later. --- src/framework/CMakeLists.txt | 6 +- src/framework/luafunctions.cpp | 2 + src/framework/net/inputmessage.cpp | 7 +- src/framework/net/inputmessage.h | 2 +- src/framework/net/outputmessage.cpp | 3 +- src/framework/util/crypt.cpp | 45 ++++++++++ src/framework/util/crypt.h | 2 + src/framework/util/rsa.cpp | 132 ++++++++++------------------ src/framework/util/rsa.h | 22 ++++- 9 files changed, 123 insertions(+), 98 deletions(-) diff --git a/src/framework/CMakeLists.txt b/src/framework/CMakeLists.txt index 1a1aa3b0..8b2047fb 100644 --- a/src/framework/CMakeLists.txt +++ b/src/framework/CMakeLists.txt @@ -187,14 +187,14 @@ endif() message(STATUS "LuaJIT: " ${LUAJIT}) find_package(PhysFS REQUIRED) -find_package(GMP REQUIRED) +find_package(OpenSSL REQUIRED) find_package(ZLIB REQUIRED) set(framework_LIBRARIES ${framework_LIBRARIES} ${Boost_LIBRARIES} ${LUA_LIBRARY} ${PHYSFS_LIBRARY} - ${GMP_LIBRARY} + ${OPENSSL_LIBRARIES} ${ZLIB_LIBRARY} ) @@ -203,7 +203,7 @@ set(framework_INCLUDE_DIRS ${framework_INCLUDE_DIRS} ${OPENGL_INCLUDE_DIR} ${LUA_INCLUDE_DIR} ${PHYSFS_INCLUDE_DIR} - ${GMP_INCLUDE_DIR} + ${OpenSSL_INCLUDE_DIR} ${ZLIB_INCLUDE_DIR} ) diff --git a/src/framework/luafunctions.cpp b/src/framework/luafunctions.cpp index 2aea5485..d291d91e 100644 --- a/src/framework/luafunctions.cpp +++ b/src/framework/luafunctions.cpp @@ -85,6 +85,8 @@ void Application::registerLuaFunctions() g_lua.registerSingletonClass("g_crypt"); g_lua.bindClassStaticFunction("g_crypt", "encrypt", Crypt::encrypt); g_lua.bindClassStaticFunction("g_crypt", "decrypt", Crypt::decrypt); + g_lua.bindClassStaticFunction("g_crypt", "sha1Encode", Crypt::sha1Encode); + g_lua.bindClassStaticFunction("g_crypt", "md5Encode", Crypt::md5Encode); // Clock g_lua.registerSingletonClass("g_clock"); diff --git a/src/framework/net/inputmessage.cpp b/src/framework/net/inputmessage.cpp index f43b5c46..334edec8 100644 --- a/src/framework/net/inputmessage.cpp +++ b/src/framework/net/inputmessage.cpp @@ -86,10 +86,13 @@ std::string InputMessage::getString() return std::string(v, stringLength); } -bool InputMessage::decryptRsa(int size, const std::string& p, const std::string& q, const std::string& d) +bool InputMessage::decryptRsa(int size, const std::string& key, const std::string& p, const std::string& q, const std::string& d) { checkRead(size); - RSA::decrypt((char*)m_buffer + m_readPos, size, p.c_str(), q.c_str(), d.c_str()); + g_rsa.setPublic(key.c_str(), "65537"); + g_rsa.setPrivate(p.c_str(), q.c_str(), d.c_str()); + g_rsa.check(); + g_rsa.decrypt((unsigned char*)m_buffer + m_readPos, size); return (getU8() == 0x00); } diff --git a/src/framework/net/inputmessage.h b/src/framework/net/inputmessage.h index ca911144..7daf3ee2 100644 --- a/src/framework/net/inputmessage.h +++ b/src/framework/net/inputmessage.h @@ -52,7 +52,7 @@ public: uint32 peekU32() { uint32 v = getU32(); m_readPos-=4; return v; } uint64 peekU64() { uint64 v = getU64(); m_readPos-=8; return v; } - bool decryptRsa(int size, const std::string& p, const std::string& q, const std::string& d); + bool decryptRsa(int size, const std::string& key, const std::string& p, const std::string& q, const std::string& d); int getReadSize() { return m_readPos - m_headerPos; } int getReadPos() { return m_readPos; } diff --git a/src/framework/net/outputmessage.cpp b/src/framework/net/outputmessage.cpp index f3a4725a..ca0432f2 100644 --- a/src/framework/net/outputmessage.cpp +++ b/src/framework/net/outputmessage.cpp @@ -94,7 +94,8 @@ void OutputMessage::encryptRsa(int size, const std::string& key) if(m_messageSize < size) throw stdext::exception("insufficient bytes in buffer to encrypt"); - RSA::encrypt((char*)m_buffer + m_writePos - size, size, key.c_str()); + g_rsa.setPublic(key.c_str(), "65537"); + g_rsa.encrypt((unsigned char*)m_buffer + m_writePos - size, size); } void OutputMessage::writeChecksum() diff --git a/src/framework/util/crypt.cpp b/src/framework/util/crypt.cpp index 9a1ea024..73ec1522 100644 --- a/src/framework/util/crypt.cpp +++ b/src/framework/util/crypt.cpp @@ -26,6 +26,9 @@ #include #include +#include +#include + static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; static inline bool is_base64(unsigned char c) { return (isalnum(c) || (c == '+') || (c == '/')); } @@ -159,3 +162,45 @@ std::string Crypt::decrypt(const std::string& encrypted_string) } return std::string(); } + +std::string Crypt::md5Encode(std::string decoded_string, bool upperCase) +{ + MD5_CTX c; + MD5_Init(&c); + MD5_Update(&c, decoded_string.c_str(), decoded_string.length()); + + uint8_t md[MD5_DIGEST_LENGTH]; + MD5_Final(md, &c); + + char output[(MD5_DIGEST_LENGTH << 1) + 1]; + for(int32_t i = 0; i < (int32_t)sizeof(md); ++i) + sprintf(output + (i << 1), "%.2X", md[i]); + + std::string result = output; + if(upperCase) + return result; + + std::transform(result.begin(), result.end(), result.begin(), tolower); + return result; +} + +std::string Crypt::sha1Encode(std::string decoded_string, bool upperCase) +{ + SHA_CTX c; + SHA1_Init(&c); + SHA1_Update(&c, decoded_string.c_str(), decoded_string.length()); + + uint8_t md[SHA_DIGEST_LENGTH]; + SHA1_Final(md, &c); + + char output[(SHA_DIGEST_LENGTH << 1) + 1]; + for(int32_t i = 0; i < (int32_t)sizeof(md); ++i) + sprintf(output + (i << 1), "%.2X", md[i]); + + std::string result = output; + if(upperCase) + return result; + + std::transform(result.begin(), result.end(), result.begin(), tolower); + return result; +} diff --git a/src/framework/util/crypt.h b/src/framework/util/crypt.h index 1f4d7899..0902320c 100644 --- a/src/framework/util/crypt.h +++ b/src/framework/util/crypt.h @@ -33,6 +33,8 @@ namespace Crypt { std::string genUUIDKey(); std::string encrypt(const std::string& decrypted_string); std::string decrypt(const std::string& encrypted_string); + std::string md5Encode(std::string decoded_string, bool upperCase); + std::string sha1Encode(std::string decoded_string, bool upperCase); } #endif diff --git a/src/framework/util/rsa.cpp b/src/framework/util/rsa.cpp index 4f055edc..65e87fb0 100644 --- a/src/framework/util/rsa.cpp +++ b/src/framework/util/rsa.cpp @@ -21,104 +21,62 @@ */ #include "rsa.h" -#include +#include +#include -void RSA::encrypt(char *msg, int size, const char* key) -{ - assert(size <= 128); - - mpz_t plain, c; - mpz_init2(plain, 1024); - mpz_init2(c, 1024); - - mpz_t e; - mpz_init(e); - mpz_set_ui(e,65537); - - mpz_t mod; - mpz_init2(mod, 1024); - mpz_set_str(mod, key, 10); - - mpz_import(plain, size, 1, 1, 0, 0, msg); - mpz_powm(c, plain, e, mod); +Rsa g_rsa; - size_t count = (mpz_sizeinbase(c, 2) + 7)/8; - memset(msg, 0, size - count); - mpz_export(&msg[size - count], NULL, 1, 1, 0, 0, c); - - mpz_clear(c); - mpz_clear(plain); - mpz_clear(e); - mpz_clear(mod); +Rsa::Rsa() +{ + m_rsa = RSA_new(); } -void RSA::decrypt(char *msg, int size, const char *p, const char *q, const char *d) +Rsa::~Rsa() { - assert(size <= 128); - - mpz_t mp, mq, md, u, dp, dq, mod, c, v1, v2, u2, tmp; - mpz_init2(mp, 1024); - mpz_init2(mq, 1024); - mpz_init2(md, 1024); - mpz_init2(u, 1024); - mpz_init2(dp, 1024); - mpz_init2(dq, 1024); - mpz_init2(mod, 1024); - mpz_init2(c, 1024); - mpz_init2(v1, 1024); - mpz_init2(v2, 1024); - mpz_init2(u2, 1024); - mpz_init2(tmp, 1024); - - mpz_set_str(mp, p, 10); - mpz_set_str(mq, q, 10); - mpz_set_str(md, d, 10); + RSA_free(m_rsa); +} - mpz_t pm1,qm1; - mpz_init2(pm1, 520); - mpz_init2(qm1, 520); +void Rsa::setPublic(const char *n, const char *e) +{ + BN_dec2bn(&m_rsa->n, n); + BN_dec2bn(&m_rsa->e, e); +} - mpz_sub_ui(pm1, mp, 1); - mpz_sub_ui(qm1, mq, 1); - mpz_invert(u, mp, mq); - mpz_mod(dp, md, pm1); - mpz_mod(dq, md, qm1); +void Rsa::setPrivate(const char *p, const char *q, const char *d) +{ + BN_dec2bn(&m_rsa->p, p); + BN_dec2bn(&m_rsa->q, q); + BN_dec2bn(&m_rsa->d, d); +} - mpz_mul(mod, mp, mq); +bool Rsa::check() // only used by server, that sets both public and private +{ + if(RSA_check_key(m_rsa)) { + BN_CTX *ctx = BN_CTX_new(); + BN_CTX_start(ctx); - mpz_import(c, size, 1, 1, 0, 0, msg); + BIGNUM *r1 = BN_CTX_get(ctx), *r2 = BN_CTX_get(ctx); + BN_mod(m_rsa->dmp1, m_rsa->d, r1, ctx); + BN_mod(m_rsa->dmq1, m_rsa->d, r2, ctx); - mpz_mod(tmp, c, mp); - mpz_powm(v1, tmp, dp, mp); - mpz_mod(tmp, c, mq); - mpz_powm(v2, tmp, dq, mq); - mpz_sub(u2, v2, v1); - mpz_mul(tmp, u2, u); - mpz_mod(u2, tmp, mq); - if(mpz_cmp_si(u2, 0) < 0) { - mpz_add(tmp, u2, mq); - mpz_set(u2, tmp); + BN_mod_inverse(m_rsa->iqmp, m_rsa->q, m_rsa->p, ctx); + return true; + } + else { + ERR_load_crypto_strings(); + g_logger.error(stdext::format("RSA check failed - %s", ERR_error_string(ERR_get_error(), NULL))); + return false; } - mpz_mul(tmp, u2, mp); - mpz_set_ui(c, 0); - mpz_add(c, v1, tmp); +} - size_t count = (mpz_sizeinbase(c, 2) + 7)/8; - memset(msg, 0, size - count); - mpz_export(&msg[size - count], NULL, 1, 1, 0, 0, c); +bool Rsa::encrypt(unsigned char *msg, int size) +{ + assert(size <= 128); + return RSA_public_encrypt(size, msg, msg, m_rsa, RSA_NO_PADDING) != -1; +} - mpz_clear(c); - mpz_clear(v1); - mpz_clear(v2); - mpz_clear(u2); - mpz_clear(tmp); - mpz_clear(pm1); - mpz_clear(qm1); - mpz_clear(mp); - mpz_clear(mq); - mpz_clear(md); - mpz_clear(u); - mpz_clear(dp); - mpz_clear(dq); - mpz_clear(mod); +bool Rsa::decrypt(unsigned char *msg, int size) +{ + assert(size <= 128); + return RSA_private_decrypt(size, msg, msg, m_rsa, RSA_NO_PADDING) != -1; } diff --git a/src/framework/util/rsa.h b/src/framework/util/rsa.h index d0a4b2a8..cae4bac2 100644 --- a/src/framework/util/rsa.h +++ b/src/framework/util/rsa.h @@ -24,11 +24,25 @@ #define RSA_H #include +#include -namespace RSA +class Rsa { - void encrypt(char *msg, int size, const char *key); - void decrypt(char *msg, int size, const char *p, const char *q, const char *d); +public: + Rsa(); + ~Rsa(); + + void setPublic(const char *n, const char *e); + void setPrivate(const char *p, const char *q, const char *d); + bool check(); + + bool encrypt(unsigned char *msg, int size); + bool decrypt(unsigned char *msg, int size); + +private: + RSA *m_rsa; }; -#endif \ No newline at end of file +extern Rsa g_rsa; + +#endif