From ad107547793a17c097c184a9f3da0138d2db35c2 Mon Sep 17 00:00:00 2001 From: Henrique Date: Mon, 30 May 2011 00:11:12 -0300 Subject: [PATCH] init protocol login --- src/framework/net/connection.cpp | 91 +++++++++++++++++++-- src/framework/net/connection.h | 14 +++- src/framework/net/inputmessage.cpp | 2 +- src/framework/net/inputmessage.h | 16 +++- src/framework/net/outputmessage.cpp | 12 ++- src/framework/net/outputmessage.h | 18 ++++- src/framework/net/protocol.cpp | 120 +++++++++++++++++++++++++++- src/framework/net/protocol.h | 20 +++++ src/main.cpp | 5 ++ src/protocollogin.cpp | 101 +++++++++++++++++++++-- src/protocollogin.h | 9 ++- 11 files changed, 379 insertions(+), 29 deletions(-) diff --git a/src/framework/net/connection.cpp b/src/framework/net/connection.cpp index b838b54e..af6093f3 100644 --- a/src/framework/net/connection.cpp +++ b/src/framework/net/connection.cpp @@ -40,16 +40,26 @@ void Connection::poll() ioService.reset(); } -void Connection::connect(const std::string& host, uint16 port, const SimpleCallback& callback) +void Connection::connect(const std::string& host, uint16 port, const SimpleCallback& connectCallback) { - m_connectCallback = callback; + m_connectCallback = connectCallback; m_connectionState = CONNECTION_STATE_RESOLVING; boost::asio::ip::tcp::resolver::query query(host, convert_cast(port)); - m_resolver.async_resolve(query, boost::bind(&Connection::onResolve, this, boost::asio::placeholders::error, boost::asio::placeholders::iterator)); + m_resolver.async_resolve(query, boost::bind(&Connection::onResolve, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::iterator)); m_timer.expires_from_now(boost::posix_time::seconds(2)); - m_timer.async_wait(boost::bind(&Connection::onTimeout, this, boost::asio::placeholders::error)); + m_timer.async_wait(boost::bind(&Connection::onTimeout, shared_from_this(), boost::asio::placeholders::error)); +} + +void Connection::send(OutputMessage *outputMessage) +{ + boost::asio::async_write(m_socket, + boost::asio::buffer(outputMessage->getBuffer(), outputMessage->getMessageSize()), + boost::bind(&Connection::onSend, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + + m_timer.expires_from_now(boost::posix_time::seconds(2)); + m_timer.async_wait(boost::bind(&Connection::onTimeout, shared_from_this(), boost::asio::placeholders::error)); } void Connection::onTimeout(const boost::system::error_code& error) @@ -65,14 +75,15 @@ void Connection::onResolve(const boost::system::error_code& error, boost::asio:: m_timer.cancel(); if(error) { - g_dispatcher.addTask(boost::bind(m_errorCallback, error)); + if(m_errorCallback) + g_dispatcher.addTask(boost::bind(m_errorCallback, error)); return; } - m_socket.async_connect(*endpointIterator, boost::bind(&Connection::onConnect, this, boost::asio::placeholders::error)); + m_socket.async_connect(*endpointIterator, boost::bind(&Connection::onConnect, shared_from_this(), boost::asio::placeholders::error)); m_timer.expires_from_now(boost::posix_time::seconds(2)); - m_timer.async_wait(boost::bind(&Connection::onTimeout, this, boost::asio::placeholders::error)); + m_timer.async_wait(boost::bind(&Connection::onTimeout, shared_from_this(), boost::asio::placeholders::error)); } void Connection::onConnect(const boost::system::error_code& error) @@ -82,9 +93,73 @@ void Connection::onConnect(const boost::system::error_code& error) m_timer.cancel(); if(error) { - g_dispatcher.addTask(boost::bind(m_errorCallback, error)); + if(m_errorCallback) + g_dispatcher.addTask(boost::bind(m_errorCallback, error)); return; } g_dispatcher.addTask(m_connectCallback); + + // Start listening. + InputMessage *inputMessage = new InputMessage; + boost::asio::async_read(m_socket, + boost::asio::buffer(inputMessage->getBuffer(), InputMessage::HEADER_LENGTH), + boost::bind(&Connection::onRecvHeader, shared_from_this(), boost::asio::placeholders::error, inputMessage)); +} + +void Connection::onSend(const boost::system::error_code& error, size_t) +{ + logTrace(); + + m_timer.cancel(); + + if(error) { + if(m_errorCallback) + g_dispatcher.addTask(boost::bind(m_errorCallback, error)); + return; + } +} + +void Connection::onRecvHeader(const boost::system::error_code& error, InputMessage *inputMessage) +{ + logTrace(); + + if(error) { + if(m_errorCallback) + g_dispatcher.addTask(boost::bind(m_errorCallback, error)); + return; + } + + uint16 messageSize = inputMessage->getU16(); + inputMessage->setMessageSize(messageSize); + + boost::asio::async_read(m_socket, + boost::asio::buffer(inputMessage->getBuffer() + InputMessage::CHECKSUM_POS, messageSize), + boost::bind(&Connection::onRecvData, shared_from_this(), boost::asio::placeholders::error, inputMessage)); +} + +void Connection::onRecvData(const boost::system::error_code& error, InputMessage *inputMessage) +{ + logTrace(); + + if(error) { + if(m_errorCallback) + g_dispatcher.addTask(boost::bind(m_errorCallback, error)); + return; + } + + // call callback + if(m_recvCallback) + g_dispatcher.addTask(boost::bind(m_recvCallback, inputMessage)); + + // FIXME: + // TODO declare inside class? call onRecvHeader. + // this needs a remake + /*delete inputMessage; + + inputMessage = new InputMessage; + boost::asio::async_read(m_socket, + boost::asio::buffer(inputMessage->getBuffer(), InputMessage::HEADER_LENGTH), + boost::bind(&Connection::onRecvHeader, shared_from_this(), boost::asio::placeholders::error, inputMessage));*/ + } diff --git a/src/framework/net/connection.h b/src/framework/net/connection.h index 356f7613..a0ca43dc 100644 --- a/src/framework/net/connection.h +++ b/src/framework/net/connection.h @@ -25,24 +25,33 @@ #ifndef CONNECTION_H #define CONNECTION_H +#include +#include #include #include typedef boost::function ErrorCallback; +typedef boost::function RecvCallback; -class Connection +class Connection : public boost::enable_shared_from_this, boost::noncopyable { public: Connection(); static void poll(); - void connect(const std::string& host, uint16 port, const SimpleCallback& callback); + void connect(const std::string& host, uint16 port, const SimpleCallback& connectCallback); + void send(OutputMessage *outputMessage); + void setErrorCallback(const ErrorCallback& errorCallback) { m_errorCallback = errorCallback; } + void setRecvCallback(const RecvCallback& recvCallback) { m_recvCallback = recvCallback; } void onTimeout(const boost::system::error_code& error); void onResolve(const boost::system::error_code& error, boost::asio::ip::tcp::resolver::iterator endpointIterator); void onConnect(const boost::system::error_code& error); + void onSend(const boost::system::error_code& error, size_t); + void onRecvHeader(const boost::system::error_code& error, InputMessage *inputMessage); + void onRecvData(const boost::system::error_code& error, InputMessage *inputMessage); enum ConnectionState_t { CONNECTION_STATE_IDLE = 0, @@ -53,6 +62,7 @@ public: private: ErrorCallback m_errorCallback; + RecvCallback m_recvCallback; SimpleCallback m_connectCallback; ConnectionState_t m_connectionState; diff --git a/src/framework/net/inputmessage.cpp b/src/framework/net/inputmessage.cpp index d779a122..5b7b72b0 100644 --- a/src/framework/net/inputmessage.cpp +++ b/src/framework/net/inputmessage.cpp @@ -32,7 +32,7 @@ InputMessage::InputMessage() void InputMessage::reset() { m_readPos = 0; - m_messageSize = 0; + m_messageSize = 2; } uint8 InputMessage::getU8() diff --git a/src/framework/net/inputmessage.h b/src/framework/net/inputmessage.h index a81169ba..3404f9b8 100644 --- a/src/framework/net/inputmessage.h +++ b/src/framework/net/inputmessage.h @@ -32,8 +32,12 @@ class InputMessage public: enum { BUFFER_MAXSIZE = 256, + HEADER_POS = 0, HEADER_LENGTH = 2, - CHECKSUM_LENGTH = 4 + CHECKSUM_POS = 2, + CHECKSUM_LENGTH = 4, + DATA_POS = 6, + UNENCRYPTED_DATA_POS = 8 }; InputMessage(); @@ -46,13 +50,17 @@ public: uint64 getU64(); std::string getString(); + uint8 *getBuffer() { return m_buffer; } + uint16 getMessageSize() { return m_messageSize; } + void setMessageSize(uint16 messageSize) { m_messageSize = messageSize; } + bool end() { return m_readPos == m_messageSize; } private: bool canRead(int bytes); - uint16_t m_readPos; - uint16_t m_messageSize; - uint8_t m_buffer[BUFFER_MAXSIZE]; + uint16 m_readPos; + uint16 m_messageSize; + uint8 m_buffer[BUFFER_MAXSIZE]; }; #endif diff --git a/src/framework/net/outputmessage.cpp b/src/framework/net/outputmessage.cpp index 300b4e68..dc0cdba1 100644 --- a/src/framework/net/outputmessage.cpp +++ b/src/framework/net/outputmessage.cpp @@ -31,7 +31,7 @@ OutputMessage::OutputMessage() void OutputMessage::reset() { - m_writePos = 0; + m_writePos = DATA_POS; m_messageSize = 0; } @@ -92,6 +92,16 @@ void OutputMessage::addString(const std::string &value) addString(value.c_str()); } +void OutputMessage::addPaddingBytes(int bytes, uint8 byte) +{ + if(!canWrite(bytes)) + return; + + memset((void*)&m_buffer[m_writePos], byte, bytes); + m_writePos += bytes; + m_messageSize += bytes; +} + bool OutputMessage::canWrite(int bytes) { return (m_writePos + bytes <= BUFFER_MAXSIZE); diff --git a/src/framework/net/outputmessage.h b/src/framework/net/outputmessage.h index c3533c65..5721d4fa 100644 --- a/src/framework/net/outputmessage.h +++ b/src/framework/net/outputmessage.h @@ -31,7 +31,12 @@ class OutputMessage { public: enum { - BUFFER_MAXSIZE = 1024 + BUFFER_MAXSIZE = 1024, + HEADER_POS = 0, + HEADER_LENGTH = 2, + CHECKSUM_POS = 2, + CHECKSUM_LENGTH = 4, + DATA_POS = 6 }; OutputMessage(); @@ -44,13 +49,18 @@ public: void addU64(uint64 value); void addString(const char* value); void addString(const std::string &value); + void addPaddingBytes(int bytes, uint8 byte = 0); + + uint8 *getBuffer() { return m_buffer; } + uint16 getMessageSize() { return m_messageSize; } + void setWritePos(uint16 writePos) { m_writePos = writePos; } private: bool canWrite(int bytes); - uint16_t m_writePos; - uint16_t m_messageSize; - uint8_t m_buffer[BUFFER_MAXSIZE]; + uint16 m_writePos; + uint16 m_messageSize; + uint8 m_buffer[BUFFER_MAXSIZE]; }; #endif diff --git a/src/framework/net/protocol.cpp b/src/framework/net/protocol.cpp index 01ab5764..d0adde0d 100644 --- a/src/framework/net/protocol.cpp +++ b/src/framework/net/protocol.cpp @@ -28,6 +28,8 @@ Protocol::Protocol() : m_connection(new Connection) { m_connection->setErrorCallback(boost::bind(&Protocol::onError, this, _1)); + m_connection->setRecvCallback(boost::bind(&Protocol::onRecv, this, _1)); + m_xteaEncryptionEnabled = false; } void Protocol::connect(const std::string& host, uint16 port, const SimpleCallback& callback) @@ -35,12 +37,128 @@ void Protocol::connect(const std::string& host, uint16 port, const SimpleCallbac m_connection->connect(host, port, callback); } +void Protocol::send(OutputMessage *outputMessage) +{ + // Encrypt + if(m_xteaEncryptionEnabled) + xteaEncrypt(outputMessage); + + // Set checksum + uint32 checksum = getAdlerChecksum(outputMessage->getBuffer() + OutputMessage::DATA_POS, outputMessage->getMessageSize()); + outputMessage->setWritePos(OutputMessage::CHECKSUM_POS); + outputMessage->addU32(checksum); + + // Set size + uint16 messageSize = outputMessage->getMessageSize(); + outputMessage->setWritePos(OutputMessage::HEADER_POS); + outputMessage->addU16(messageSize); + + // Send + m_connection->send(outputMessage); +} + +void Protocol::onRecv(InputMessage *inputMessage) +{ + uint32 checksum = getAdlerChecksum(inputMessage->getBuffer() + InputMessage::DATA_POS, inputMessage->getMessageSize() - InputMessage::CHECKSUM_LENGTH); + if(inputMessage->getU32() != checksum) { + // error + logError("Checksum is invalid."); + return; + } + + if(m_xteaEncryptionEnabled) + xteaDecrypt(inputMessage); +} + void Protocol::onError(const boost::system::error_code& error) { - flogError("PROTOCOL ERROR: ", error.message()); + flogError("PROTOCOL ERROR: %s", error.message()); // invalid hostname // connection timeouted // displays a dialog, finish protocol } + +bool Protocol::xteaDecrypt(InputMessage *inputMessage) +{ + // FIXME: this function has not been tested yet + uint16 messageSize = inputMessage->getMessageSize() - InputMessage::CHECKSUM_LENGTH; + if(messageSize % 8 != 0) { + //LOG_TRACE_DEBUG("not valid encrypted message size") + return false; + } + + uint32 *buffer = (uint32*)(inputMessage->getBuffer() + InputMessage::DATA_POS); + int readPos = 0; + + while(readPos < messageSize/4) { + uint32 v0 = buffer[readPos], v1 = buffer[readPos + 1]; + uint32 delta = 0x61C88647; + uint32 sum = 0xC6EF3720; + + for(int32 i = 0; i < 32; i++) { + v1 -= ((v0 << 4 ^ v0 >> 5) + v0) ^ (sum + m_xteaKey[sum>>11 & 3]); + sum += delta; + v0 -= ((v1 << 4 ^ v1 >> 5) + v1) ^ (sum + m_xteaKey[sum & 3]); + } + buffer[readPos] = v0; buffer[readPos + 1] = v1; + readPos = readPos + 2; + } + + int tmp = inputMessage->getU16(); + if(tmp > inputMessage->getMessageSize() - 4) { + //LOG_TRACE_DEBUG("not valid unencrypted message size") + return false; + } + + inputMessage->setMessageSize(tmp + InputMessage::UNENCRYPTED_DATA_POS); + return true; +} + +void Protocol::xteaEncrypt(OutputMessage *outputMessage) +{ + uint16 messageLength = outputMessage->getMessageSize(); + + //add bytes until reach 8 multiple + if((messageLength % 8) != 0) { + uint16 n = 8 - (messageLength % 8); + outputMessage->addPaddingBytes(n); + messageLength += n; + } + + int readPos = 0; + uint32 *buffer = (uint32*)outputMessage->getBuffer() + OutputMessage::DATA_POS; + while(readPos < messageLength / 4) { + uint32 v0 = buffer[readPos], v1 = buffer[readPos + 1]; + uint32 delta = 0x61C88647; + uint32 sum = 0; + + for(int32 i = 0; i < 32; i++) { + v0 += ((v1 << 4 ^ v1 >> 5) + v1) ^ (sum + m_xteaKey[sum & 3]); + sum -= delta; + v1 += ((v0 << 4 ^ v0 >> 5) + v0) ^ (sum + m_xteaKey[sum>>11 & 3]); + } + buffer[readPos] = v0; buffer[readPos + 1] = v1; + readPos = readPos + 2; + } +} + +uint32 Protocol::getAdlerChecksum(uint8 *buffer, uint16 size) +{ + uint32 a = 1, b = 0; + while (size > 0) { + size_t tlen = size > 5552 ? 5552 : size; + size -= tlen; + do { + a += *buffer++; + b += a; + } while (--tlen); + + a %= 65521; + b %= 65521; + } + + return (b << 16) | a; +} + diff --git a/src/framework/net/protocol.h b/src/framework/net/protocol.h index 45727a9f..9801e89c 100644 --- a/src/framework/net/protocol.h +++ b/src/framework/net/protocol.h @@ -26,6 +26,16 @@ #define PROTOCOL_H #include +#include +#include + +#define CIPSOFT_PUBLIC_RSA "1321277432058722840622950990822933849527763264961655079678763618" \ + "4334395343554449668205332383339435179772895415509701210392836078" \ + "6959821132214473291575712138800495033169914814069637740318278150" \ + "2907336840325241747827401343576296990629870233111328210165697754" \ + "88792221429527047321331896351555606801473202394175817" + +//#define RSA "109120132967399429278860960508995541528237502902798129123468757937266291492576446330739696001110603907230888610072655818825358503429057592827629436413108566029093628212635953836686562675849720620786279431090218017681061521755056710823876476444260558147179707119674283982419152118103759076030616683978566631413" class Protocol { @@ -33,10 +43,20 @@ public: Protocol(); void connect(const std::string& host, uint16 port, const SimpleCallback& callback); + void send(OutputMessage *outputMessage); + virtual void onRecv(InputMessage *inputMessage); virtual void onError(const boost::system::error_code& error); +protected: + uint32 m_xteaKey[4]; + bool m_xteaEncryptionEnabled; + private: + bool xteaDecrypt(InputMessage *inputMessage); + void xteaEncrypt(OutputMessage *outputMessage); + uint32 getAdlerChecksum(uint8 *buffer, uint16 size); + ConnectionPtr m_connection; }; diff --git a/src/main.cpp b/src/main.cpp index c8ad2028..ed21a90a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -32,6 +32,8 @@ #include