From d13c1e6d6812aa568b6cc7215100828016b92582 Mon Sep 17 00:00:00 2001 From: WolverinDEV Date: Mon, 27 Jan 2020 02:21:39 +0100 Subject: [PATCH] A lot of updates --- CMakeLists.txt | 10 +- src/Properties.h | 4 +- src/converters/converter.h | 8 +- src/misc/endianness.h | 2 +- src/protocol/AcknowledgeManager.cpp | 10 +- src/protocol/AcknowledgeManager.h | 2 +- src/protocol/CompressionHandler.cpp | 107 +++-- src/protocol/CompressionHandler.h | 7 + src/protocol/CryptHandler.cpp | 321 +++++++++++++++ .../{CryptionHandler.h => CryptHandler.h} | 42 +- src/protocol/CryptionHandler.cpp | 380 ------------------ src/protocol/Packet.cpp | 24 ++ src/protocol/Packet.h | 43 ++ src/protocol/generation.cpp | 36 ++ src/protocol/generation.h | 26 ++ src/query/command3.h | 4 +- src/query/command_handler.h | 2 +- test/generationTest.cpp | 105 +++++ 18 files changed, 690 insertions(+), 443 deletions(-) create mode 100644 src/protocol/CryptHandler.cpp rename src/protocol/{CryptionHandler.h => CryptHandler.h} (58%) delete mode 100644 src/protocol/CryptionHandler.cpp create mode 100644 src/protocol/generation.cpp create mode 100644 src/protocol/generation.h create mode 100644 test/generationTest.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b75f62a..584f8ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,6 +102,7 @@ set(SOURCE_FILES src/query/Command.cpp src/query/escape.cpp + src/protocol/generation.cpp src/protocol/Packet.cpp src/protocol/buffers.cpp src/protocol/buffers_allocator_c.cpp @@ -109,7 +110,7 @@ set(SOURCE_FILES src/Properties.cpp src/BasicChannel.cpp src/Error.cpp - src/protocol/CryptionHandler.cpp + src/protocol/CryptHandler.cpp src/protocol/CompressionHandler.cpp src/Variable.cpp src/linked_helper.cpp @@ -148,7 +149,7 @@ set(HEADER_FILES src/BasicChannel.h src/Definitions.h src/Error.h - src/protocol/CryptionHandler.h + src/protocol/CryptHandler.h src/Variable.h src/misc/queue.h @@ -269,7 +270,7 @@ if(BUILD_TESTS) add_executable(PorpertyTest test/PropertyTest.cpp ${SOURCE_FILES}) target_link_libraries(PorpertyTest ${TEST_LIBRARIES}) - add_executable(BBTest test/BBTest.cpp ${SOURCE_FILES} src/query/command_unused.h src/converters/converter.cpp src/query/command_handler.h) + add_executable(BBTest test/BBTest.cpp ${SOURCE_FILES} src/query/command_unused.h) target_link_libraries(BBTest ${TEST_LIBRARIES}) add_executable(LinkedTest test/LinkedTest.cpp ${SOURCE_FILES}) @@ -277,5 +278,8 @@ if(BUILD_TESTS) add_executable(PermissionTest test/PermissionTest.cpp ${SOURCE_FILES}) target_link_libraries(PermissionTest ${TEST_LIBRARIES}) + + add_executable(GenerationTest test/generationTest.cpp ${SOURCE_FILES} ../server/MySQLLibSSLFix.c) + target_link_libraries(GenerationTest ${TEST_LIBRARIES}) endif() endif() diff --git a/src/Properties.h b/src/Properties.h index ea4689b..7fd76ec 100644 --- a/src/Properties.h +++ b/src/Properties.h @@ -636,7 +636,7 @@ namespace ts { if(data_ptr->casted_value.type() == typeid(T)) return std::any_cast(data_ptr->casted_value); - data_ptr->casted_value = ts::converter::from_string(data_ptr->value); + data_ptr->casted_value = ts::converter::from_string_view(std::string_view{data_ptr->value}); return std::any_cast(data_ptr->casted_value); } }; @@ -696,7 +696,7 @@ namespace ts { if(this->data_ptr->casted_value.type() == typeid(T)) return std::any_cast(this->data_ptr->casted_value); - this->data_ptr->casted_value = ts::converter::from_string(this->data_ptr->value); + this->data_ptr->casted_value = ts::converter::from_string_view(this->data_ptr->value); return std::any_cast(this->data_ptr->casted_value); } catch(std::exception&) { return 0; diff --git a/src/converters/converter.h b/src/converters/converter.h index aba21b7..58330b7 100644 --- a/src/converters/converter.h +++ b/src/converters/converter.h @@ -13,7 +13,7 @@ namespace ts { static constexpr bool supported = false; static constexpr std::string(*to_string)(const std::any&) = nullptr; - static constexpr T(*from_string)(const std::string&) = nullptr; + static constexpr T(*from_string_view)(const std::string_view&) = nullptr; }; #define DECLARE_CONVERTER(type, decode, encode) \ @@ -22,7 +22,7 @@ namespace ts { static constexpr bool supported = true; \ \ static constexpr std::string(*to_string)(const std::any&) = encode; \ - static constexpr type(*from_string)(const std::string_view&) = decode; \ + static constexpr type(*from_string_view)(const std::string_view&) = decode; \ }; #define CONVERTER_METHOD_DECODE(type, name) type name(const std::string_view& str) @@ -79,8 +79,8 @@ namespace ts { static constexpr std::string(*to_string)(const std::any&) = [](const std::any& val) { \ return std::to_string(std::any_cast(val)); \ }; \ - static constexpr class(*from_string)(const std::string&) = [](const std::string& val) { \ - return ((class(*)(const std::string&)) ts::converter::from_string)(val); \ + static constexpr class(*from_string_view)(const std::string_view&) = [](const std::string_view& val) { \ + return ((class(*)(const std::string_view&)) ts::converter::from_string_view)(val); \ }; \ }; \ } diff --git a/src/misc/endianness.h b/src/misc/endianness.h index 016cda5..bfe6c90 100644 --- a/src/misc/endianness.h +++ b/src/misc/endianness.h @@ -24,7 +24,7 @@ template ::type, char>::value || \ std::is_same::type, unsigned char>::value \ , int>::type = 0, typename ResultType = uint ##size ##_t> \ -inline ResultType be2le ##size(BufferType* buffer,T offset = 0, T* offsetCounter = nullptr){ \ +inline ResultType be2le ##size(const BufferType* buffer,T offset = 0, T* offsetCounter = nullptr){ \ ResultType result = 0; \ convert; \ if(offsetCounter) *offsetCounter += (size) / 8; \ diff --git a/src/protocol/AcknowledgeManager.cpp b/src/protocol/AcknowledgeManager.cpp index 779f332..06a93f4 100644 --- a/src/protocol/AcknowledgeManager.cpp +++ b/src/protocol/AcknowledgeManager.cpp @@ -59,13 +59,15 @@ void AcknowledgeManager::process_packet(ts::protocol::BasicPacket &packet) { } } -bool AcknowledgeManager::process_acknowledge(const ts::protocol::BasicPacket &packet, std::string& error) { +bool AcknowledgeManager::process_acknowledge(uint8_t packet_type, const pipes::buffer_view& payload, std::string& error) { + if(payload.length() < 2) return false; + PacketType target_type = PacketType::UNDEFINED; uint16_t target_id = 0; - if(packet.type().type() == PacketType::ACK_LOW) target_type = PacketType::COMMAND_LOW; - else if(packet.type().type() == PacketType::ACK) target_type = PacketType::COMMAND; - target_id = be2le16((char*) packet.data().data_ptr()); + if(packet_type == protocol::ACK_LOW) target_type = PacketType::COMMAND_LOW; + else if(packet_type == protocol::ACK) target_type = PacketType::COMMAND; + target_id = be2le16((char*) payload.data_ptr()); //debugMessage(0, "Got ack for {} {}", target_type, target_id); if(target_type == PacketType::UNDEFINED) { diff --git a/src/protocol/AcknowledgeManager.h b/src/protocol/AcknowledgeManager.h index 1736797..52e9c82 100644 --- a/src/protocol/AcknowledgeManager.h +++ b/src/protocol/AcknowledgeManager.h @@ -31,7 +31,7 @@ namespace ts { void reset(); void process_packet(ts::protocol::BasicPacket& /* packet */); - bool process_acknowledge(const ts::protocol::BasicPacket& /* packet */, std::string& /* error */); + bool process_acknowledge(uint8_t packet_type, const pipes::buffer_view& /* payload */, std::string& /* error */); ssize_t execute_resend( const std::chrono::system_clock::time_point& /* now */, diff --git a/src/protocol/CompressionHandler.cpp b/src/protocol/CompressionHandler.cpp index ea12821..b85c460 100644 --- a/src/protocol/CompressionHandler.cpp +++ b/src/protocol/CompressionHandler.cpp @@ -10,54 +10,103 @@ using namespace ts; using namespace ts::connection; using namespace std; +namespace ts::compression { + class thread_buffer { + public: + void* get_buffer(size_t size) { + if(size > 1024 * 1024 *5) /* we don't want to keep such big buffers in memory */ + return malloc(size); + + if(this->buffer_length < size) { + free(this->buffer_ptr); + + size = std::max(size, (size_t) 1024); + this->buffer_ptr = malloc(size); + this->buffer_length = size; + } + return buffer_ptr; + } + + void free_buffer(void* ptr) { + if(ptr == this->buffer_ptr) return; + free(ptr); + } + + ~thread_buffer() { + free(this->buffer_ptr); + } + private: + void* buffer_ptr{nullptr}; + size_t buffer_length{0}; + }; + thread_local thread_buffer qlz_buffer{}; + + size_t qlz_decompressed_size(const void* payload, size_t payload_length) { + return qlz_size_decompressed((char*) payload) + 400; + } + + bool qlz_decompress_payload(const void* payload, void* buffer, size_t* buffer_size) { + assert(payload != buffer); + + qlz_state_decompress state{}; + size_t data_length = qlz_decompress((char*) payload, (char*) buffer, &state); + if(data_length <= 0) + return false; + + /* test for overflow */ + if(data_length > *buffer_size) terminate(); + *buffer_size = data_length; + return true; + } + + size_t qlz_compressed_size(const void* payload, size_t payload_length) { + //// "Always allocate size + 400 bytes for the destination buffer when compressing." <= http://www.quicklz.com/manual.html + return max(min(payload_length * 2, (size_t) (payload_length + 400ULL)), (size_t) 24ULL); /* at least 12 bytes (QLZ header) */ + } + + bool qlz_compress_payload(const void* payload, size_t payload_length, void* buffer, size_t* buffer_length) { + assert(payload != buffer); + assert(*buffer_length >= qlz_compressed_size(payload, payload_length)); + + qlz_state_compress state{}; + size_t compressed_length = qlz_compress(payload, (char*) buffer, payload_length, &state); + if(compressed_length > *buffer_length) terminate(); + + if(compressed_length <= 0) + return false; + *buffer_length = compressed_length; + return true; + } +} + bool CompressionHandler::compress(protocol::BasicPacket* packet, std::string &error) { - //// "Always allocate size + 400 bytes for the destination buffer when compressing." <= http://www.quicklz.com/manual.html auto packet_payload = packet->data(); auto header_length = packet->length() - packet_payload.length(); - size_t max_compressed_payload_size = max(min(packet_payload.length() * 2, (size_t) (packet_payload.length() + 400ULL)), (size_t) 24ULL); /* at least 12 bytes (QLZ header) */ + size_t max_compressed_payload_size = compression::qlz_compressed_size(packet_payload.data_ptr(), packet_payload.length()); auto target_buffer = buffer::allocate_buffer(max_compressed_payload_size + header_length); - qlz_state_compress state_compress{}; - size_t actual_length = qlz_compress(packet_payload.data_ptr(), (char*) &target_buffer[header_length], packet_payload.length(), &state_compress); - if(actual_length > max_compressed_payload_size) { - logCritical(0, "Buffer overflow! Compressed data is longer than expected. (Expected: {}, Written: {}, Allocated block size: {})", - max_compressed_payload_size, - actual_length, - target_buffer.capacity() - ); - error = "overflow"; - return false; - } - if(actual_length <= 0){ - error = "Cloud not compress packet"; - return false; - } + size_t compressed_size{max_compressed_payload_size}; + if(!compression::qlz_compress_payload(packet_payload.data_ptr(), packet_payload.length(), &target_buffer[header_length], &compressed_size)) return false; memcpy(target_buffer.data_ptr(), packet->buffer().data_ptr(), header_length); - packet->buffer(target_buffer.range(0, actual_length + header_length)); + packet->buffer(target_buffer.range(0, compressed_size + header_length)); return true; } bool CompressionHandler::decompress(protocol::BasicPacket* packet, std::string &error) { - qlz_state_decompress state_decompress{}; - - size_t expected_length = qlz_size_decompressed((char*) packet->data().data_ptr()); + auto expected_length = compression::qlz_decompressed_size(packet->data().data_ptr(), packet->data().length()); if(expected_length > this->max_packet_size){ //Max 16MB. (97% Compression!) error = "Invalid packet size. (Calculated target length of " + to_string(expected_length) + ". Max length: " + to_string(this->max_packet_size) + ")"; return false; } - auto header_length = packet->header().length() + packet->mac().length(); auto buffer = buffer::allocate_buffer(expected_length + header_length); - size_t data_length = qlz_decompress((char*) packet->data().data_ptr(), &buffer[header_length], &state_decompress); - if(data_length <= 0){ - error = "Could not decompress packet."; - return false; - } - memcpy(buffer.data_ptr(), packet->buffer().data_ptr(), header_length); - packet->buffer(buffer.range(0, data_length + header_length)); + size_t compressed_size{expected_length}; + if(!compression::qlz_compress_payload(packet->data().data_ptr(), packet->data().length(), &buffer[header_length], &compressed_size)) return false; + + packet->buffer(buffer.range(0, compressed_size + header_length)); return true; } diff --git a/src/protocol/CompressionHandler.h b/src/protocol/CompressionHandler.h index cae4328..b76d655 100644 --- a/src/protocol/CompressionHandler.h +++ b/src/protocol/CompressionHandler.h @@ -3,6 +3,13 @@ #include "Packet.h" namespace ts { + namespace compression { + size_t qlz_decompressed_size(const void* payload, size_t payload_length); + bool qlz_decompress_payload(const void* payload, void* buffer, size_t* buffer_size); //Attention: payload & buffer must be differen! + + size_t qlz_compressed_size(const void* payload, size_t payload_length); + bool qlz_compress_payload(const void* payload, size_t payload_length, void* buffer, size_t* buffer_length); + } namespace connection { class CompressionHandler { public: diff --git a/src/protocol/CryptHandler.cpp b/src/protocol/CryptHandler.cpp new file mode 100644 index 0000000..0fe1a41 --- /dev/null +++ b/src/protocol/CryptHandler.cpp @@ -0,0 +1,321 @@ +//#define NO_OPEN_SSL /* because we're lazy and dont want to build this lib extra for the TeaClient */ +#define FIXEDINT_H_INCLUDED /* else it will be included by ge */ + +#include "misc/endianness.h" +#include +#include +#include +#include "misc/memtracker.h" +#include "misc/digest.h" +#include "CryptHandler.h" +#include "../misc/sassert.h" + +using namespace std; +using namespace ts; +using namespace ts::connection; +using namespace ts::protocol; + + +CryptHandler::CryptHandler() { + memtrack::allocated(this); +} + +CryptHandler::~CryptHandler() { + memtrack::freed(this); +} + +void CryptHandler::reset() { + this->useDefaultChipherKeyNonce = true; + this->iv_struct_length = 0; + memset(this->iv_struct, 0, sizeof(this->iv_struct)); + memcpy(this->current_mac, CryptHandler::default_mac, sizeof(CryptHandler::default_mac)); + + for(auto& cache : this->cache_key_client) + cache.generation = 0xFFEF; + for(auto& cache : this->cache_key_server) + cache.generation = 0xFFEF; +} + +#define SHARED_KEY_BUFFER_LENGTH (256) +bool CryptHandler::setupSharedSecret(const std::string& alpha, const std::string& beta, ecc_key *publicKey, ecc_key *ownKey, std::string &error) { + size_t buffer_length = SHARED_KEY_BUFFER_LENGTH; + uint8_t buffer[SHARED_KEY_BUFFER_LENGTH]; + int err; + if((err = ecc_shared_secret(ownKey, publicKey, buffer, (unsigned long*) &buffer_length)) != CRYPT_OK){ + error = "Could not calculate shared secret. Message: " + string(error_to_string(err)); + return false; + } + + auto result = this->setupSharedSecret(alpha, beta, string((const char*) buffer, buffer_length), error); + return result; +} + +bool CryptHandler::setupSharedSecret(const std::string& alpha, const std::string& beta, const std::string& sharedKey, std::string &error) { + auto secret_hash = digest::sha1(sharedKey); + assert(secret_hash.length() == SHA_DIGEST_LENGTH); + + uint8_t iv_buffer[SHA_DIGEST_LENGTH]; + memcpy(iv_buffer, alpha.data(), 10); + memcpy(&iv_buffer[10], beta.data(), 10); + + for (int index = 0; index < SHA_DIGEST_LENGTH; index++) { + iv_buffer[index] ^= (uint8_t) secret_hash[index]; + } + + { + lock_guard lock(this->cache_key_lock); + memcpy(this->iv_struct, iv_buffer, SHA_DIGEST_LENGTH); + this->iv_struct_length = SHA_DIGEST_LENGTH; + + uint8_t mac_buffer[SHA_DIGEST_LENGTH]; + digest::sha1((const char*) iv_buffer, SHA_DIGEST_LENGTH, mac_buffer); + memcpy(this->current_mac, mac_buffer, 8); + + this->useDefaultChipherKeyNonce = false; + } + + return true; +} + +void _fe_neg(fe h, const fe f) { + int32_t f0 = f[0]; + int32_t f1 = f[1]; + int32_t f2 = f[2]; + int32_t f3 = f[3]; + int32_t f4 = f[4]; + int32_t f5 = f[5]; + int32_t f6 = f[6]; + int32_t f7 = f[7]; + int32_t f8 = f[8]; + int32_t f9 = f[9]; + int32_t h0 = -f0; + int32_t h1 = -f1; + int32_t h2 = -f2; + int32_t h3 = -f3; + int32_t h4 = -f4; + int32_t h5 = -f5; + int32_t h6 = -f6; + int32_t h7 = -f7; + int32_t h8 = -f8; + int32_t h9 = -f9; + + h[0] = h0; + h[1] = h1; + h[2] = h2; + h[3] = h3; + h[4] = h4; + h[5] = h5; + h[6] = h6; + h[7] = h7; + h[8] = h8; + h[9] = h9; +} + +inline void keyMul(uint8_t(& target_buffer)[32], const uint8_t* publicKey /* compressed */, const uint8_t* privateKey /* uncompressed */, bool negate){ + ge_p3 keyA{}; + ge_p2 result{}; + + ge_frombytes_negate_vartime(&keyA, publicKey); + if(negate) { + _fe_neg(*(fe*) &keyA.X, *(const fe*) &keyA.X); /* undo negate */ + _fe_neg(*(fe*) &keyA.T, *(const fe*) &keyA.T); /* undo negate */ + } + ge_scalarmult_vartime(&result, privateKey, &keyA); + + ge_tobytes(target_buffer, &result); +} + +bool CryptHandler::setupSharedSecretNew(const std::string &alpha, const std::string &beta, const char* privateKey /* uncompressed */, const char* publicKey /* compressed */) { + if(alpha.length() != 10 || beta.length() != 54) + return false; + + uint8_t shared[32]; + uint8_t shared_iv[64]; + + ed25519_key_exchange(shared, (uint8_t*) publicKey, (uint8_t*) privateKey); + keyMul(shared, reinterpret_cast(publicKey), reinterpret_cast(privateKey), true); //Remote key get negated + digest::sha512((char*) shared, 32, shared_iv); + + auto xor_key = alpha + beta; + for(int i = 0; i < 64; i++) + shared_iv[i] ^= (uint8_t) xor_key[i]; + + { + lock_guard lock(this->cache_key_lock); + memcpy(this->iv_struct, shared_iv, 64); + this->iv_struct_length = 64; + + uint8_t mac_buffer[SHA_DIGEST_LENGTH]; + digest::sha1((char*) this->iv_struct, 64, mac_buffer); + memcpy(this->current_mac, mac_buffer, 8); + this->useDefaultChipherKeyNonce = false; + } + + return true; +} + +#define GENERATE_BUFFER_LENGTH (128) +bool CryptHandler::generate_key_nonce( + bool to_server, /* its from the client to the server */ + uint8_t type, + uint16_t packet_id, + uint16_t generation, + CryptHandler::key_t& key, + CryptHandler::nonce_t& nonce +) { + auto& key_cache_array = to_server ? this->cache_key_client : this->cache_key_server; + if(type < 0 || type >= key_cache_array.max_size()) { + logError(0, "Tried to generate a crypt key with invalid type ({})!", type); + return false; + } + + { + std::lock_guard lock{this->cache_key_lock}; + auto& key_cache = key_cache_array[type]; + if(key_cache.generation != generation) { + const size_t buffer_length = 6 + this->iv_struct_length; + sassert(buffer_length < GENERATE_BUFFER_LENGTH); + + char buffer[GENERATE_BUFFER_LENGTH]; + memset(buffer, 0, buffer_length); + + if (to_server) { + buffer[0] = 0x31; + } else { + buffer[0] = 0x30; + } + buffer[1] = (char) (type & 0xF); + + le2be32(generation, buffer, 2); + memcpy(&buffer[6], this->iv_struct, this->iv_struct_length); + digest::sha256(buffer, buffer_length, key_cache.key_nonce); + + key_cache.generation = generation; + } + + memcpy(key.data(), key_cache.key, 16); + memcpy(nonce.data(), key_cache.nonce, 16); + } + + //Xor the key + key[0] ^= (uint8_t) ((packet_id >> 8) & 0xFFU); + key[1] ^=(packet_id & 0xFFU); + + return true; +} + +bool CryptHandler::verify_encryption(const pipes::buffer_view &packet, uint16_t packet_id, uint16_t generation) { + int err; + int success = false; + + key_t key{}; + nonce_t nonce{}; + if(!generate_key_nonce(true, (protocol::PacketType) (packet[12] & 0xF), packet_id, generation, key, nonce)) + return false; + + auto mac = packet.view(0, 8); + auto header = packet.view(8, 5); + auto data = packet.view(13); + + auto length = data.length(); + + /* static shareable void buffer */ + const static unsigned long void_target_length = 2048; + static uint8_t void_target_buffer[2048]; + if(void_target_length < length) + return false; + + //TODO: Cache find_cipher + err = eax_decrypt_verify_memory(find_cipher("rijndael"), + (uint8_t *) key.data(), /* the key */ + (size_t) key.size(), /* key is 16 bytes */ + (uint8_t *) nonce.data(), /* the nonce */ + (size_t) nonce.size(), /* nonce is 16 bytes */ + (uint8_t *) header.data_ptr(), /* example header */ + (unsigned long) header.length(), /* header length */ + (const unsigned char *) data.data_ptr(), + (unsigned long) data.length(), + (unsigned char *) void_target_buffer, + (unsigned char *) mac.data_ptr(), + (unsigned long) mac.length(), + &success + ); + + return err == CRYPT_OK && success; +} + +#define tmp_buffer_size (2048) +bool CryptHandler::decrypt(const void *header, size_t header_length, void *payload, size_t payload_length, const void *mac, const key_t &key, const nonce_t &nonce, std::string &error) { + if(tmp_buffer_size < payload_length) { + error = "buffer too large"; + return false; + } + + uint8_t tmp_buffer[tmp_buffer_size]; + int success; + + //TODO: Cache cipher + auto err = eax_decrypt_verify_memory(find_cipher("rijndael"), + (const uint8_t *) key.data(), /* the key */ + (unsigned long) key.size(), /* key is 16 bytes */ + (const uint8_t *) nonce.data(), /* the nonce */ + (unsigned long) nonce.size(), /* nonce is 16 bytes */ + (const uint8_t *) header, /* example header */ + (unsigned long) header_length, /* header length */ + (const unsigned char *) payload, + (unsigned long) payload_length, + (unsigned char *) tmp_buffer, + (unsigned char *) mac, + (unsigned long) 8, + &success + ); + if(err != CRYPT_OK) { + error = "decrypt returned " + std::string{error_to_string(err)}; + return false; + } + + if(!success) { + error = "failed to verify packet"; + return false; + } + + memcpy(payload, tmp_buffer, payload_length); + return true; +} + +bool CryptHandler::encrypt( + const void *header, size_t header_length, + void *payload, size_t payload_length, + void *mac, + const key_t &key, const nonce_t &nonce, std::string &error) { + if(tmp_buffer_size < payload_length) { + error = "buffer too large"; + return false; + } + + uint8_t tmp_buffer[tmp_buffer_size], tag_length{8}; + uint8_t tag_buffer[16]; + auto err = eax_encrypt_authenticate_memory(find_cipher("rijndael"), + (uint8_t *) key.data(), /* the key */ + (unsigned long) key.size(), /* key is 16 bytes */ + (uint8_t *) nonce.data(), /* the nonce */ + (unsigned long) nonce.size(), /* nonce is 16 bytes */ + (uint8_t *) header, /* example header */ + (unsigned long) header_length, /* header length */ + (uint8_t *) payload, /* The plain text */ + (unsigned long) payload_length, /* Plain text length */ + (uint8_t *) tmp_buffer, /* The result buffer */ + (uint8_t *) tag_buffer, + (unsigned long *) &tag_length + ); + //assert(tag_length == 8); + + if(err != CRYPT_OK) { + error = "encrypt returned " + std::string{error_to_string(err)}; + return false; + } + + memcpy(mac, tag_buffer, 8); + memcpy(payload, tmp_buffer, payload_length); + return true; +} diff --git a/src/protocol/CryptionHandler.h b/src/protocol/CryptHandler.h similarity index 58% rename from src/protocol/CryptionHandler.h rename to src/protocol/CryptHandler.h index 54638b7..b241c14 100644 --- a/src/protocol/CryptionHandler.h +++ b/src/protocol/CryptHandler.h @@ -8,7 +8,7 @@ namespace ts { namespace connection { - class CryptionHandler { + class CryptHandler { enum Methode { TEAMSPEAK_3_1, TEAMSPEAK_3 @@ -24,8 +24,10 @@ namespace ts { }; }; public: - CryptionHandler(); - ~CryptionHandler(); + typedef std::array key_t; + typedef std::array nonce_t; + CryptHandler(); + ~CryptHandler(); void reset(); @@ -36,32 +38,40 @@ namespace ts { //TeamSpeak new bool setupSharedSecretNew(const std::string& alpha, const std::string& beta, const char privateKey[32], const char publicKey[32]); - bool progressPacketOut(protocol::BasicPacket*, std::string&, bool use_default); - bool progressPacketIn(protocol::BasicPacket*, std::string&, bool use_default); + /* mac must be 8 bytes long! */ + bool encrypt( + const void* /* header */, size_t /* header length */, + void* /* payload */, size_t /* payload length */, + void* /* mac */, + const key_t& /* key */, const nonce_t& /* nonce */, + std::string& /* error */); + /* mac must be 8 bytes long! */ + bool decrypt( + const void* /* header */, size_t /* header length */, + void* /* payload */, size_t /* payload length */, + const void* /* mac */, + const key_t& /* key */, const nonce_t& /* nonce */, + std::string& /* error */); + + bool generate_key_nonce(bool /* to server */, uint8_t /* packet type */, uint16_t /* packet id */, uint16_t /* generation */, key_t& /* key */, nonce_t& /* nonce */); bool verify_encryption(const pipes::buffer_view& data, uint16_t packet_id, uint16_t generation); - bool block(){ blocked = true; return true; } - bool unblock(){ blocked = false; return true; } - bool isBlocked(){ return blocked; } + inline void write_default_mac(void* buffer) { + memcpy(buffer, this->current_mac, 8); + } - bool use_default() { return this->useDefaultChipherKeyNonce; } + static constexpr key_t default_key{'c', ':', '\\', 'w', 'i', 'n', 'd', 'o', 'w', 's', '\\', 's', 'y', 's', 't', 'e'}; //c:\windows\syste + static constexpr nonce_t default_nonce{'m', '\\', 'f', 'i', 'r', 'e', 'w', 'a', 'l', 'l', '3', '2', '.', 'c', 'p', 'l'}; //m\firewall32.cpl private: - static constexpr char default_key[16] = {'c', ':', '\\', 'w', 'i', 'n', 'd', 'o', 'w', 's', '\\', 's', 'y', 's', 't', 'e'}; //c:\windows\syste - static constexpr char default_nonce[16] = {'m', '\\', 'f', 'i', 'r', 'e', 'w', 'a', 'l', 'l', '3', '2', '.', 'c', 'p', 'l'}; //m\firewall32.cpl static constexpr char default_mac[8] = {'T', 'S', '3', 'I', 'N', 'I', 'T', '1'}; //TS3INIT1 - bool decryptPacket(protocol::BasicPacket *, std::string &, bool use_default); - bool encryptPacket(protocol::BasicPacket *, std::string &, bool use_default); - - bool generate_key_nonce(bool /* to server */, protocol::PacketType /* type */, uint16_t /* packet id */, uint16_t /* generation */, bool /* use default */, uint8_t(&)[16] /* key */, uint8_t(&)[16] /* nonce */); bool generate_key_nonce(protocol::BasicPacket* packet, bool use_default, uint8_t(&)[16] /* key */, uint8_t(&)[16] /* nonce */); //The default key and nonce bool useDefaultChipherKeyNonce = true; - bool blocked = false; /* for the old protocol SHA1 length for the new 64 bytes */ uint8_t iv_struct[64]; diff --git a/src/protocol/CryptionHandler.cpp b/src/protocol/CryptionHandler.cpp deleted file mode 100644 index 372aed6..0000000 --- a/src/protocol/CryptionHandler.cpp +++ /dev/null @@ -1,380 +0,0 @@ -//#define NO_OPEN_SSL /* because we're lazy and dont want to build this lib extra for the TeaClient */ -#define FIXEDINT_H_INCLUDED /* else it will be included by ge */ - -#include "misc/endianness.h" -#include -#include -#include -#include "misc/memtracker.h" -#include "misc/digest.h" -#include "CryptionHandler.h" -#include "../misc/sassert.h" - -using namespace std; -using namespace ts; -using namespace ts::connection; -using namespace ts::protocol; - - -CryptionHandler::CryptionHandler() { - memtrack::allocated(this); -} - -CryptionHandler::~CryptionHandler() { - memtrack::freed(this); -} - -void CryptionHandler::reset() { - this->useDefaultChipherKeyNonce = true; - this->iv_struct_length = 0; - memset(this->iv_struct, 0, sizeof(this->iv_struct)); - memcpy(this->current_mac, CryptionHandler::default_mac, sizeof(CryptionHandler::default_mac)); - - for(auto& cache : this->cache_key_client) - cache.generation = 0xFFEF; - for(auto& cache : this->cache_key_server) - cache.generation = 0xFFEF; -} - -#define SHARED_KEY_BUFFER_LENGTH (256) -bool CryptionHandler::setupSharedSecret(const std::string& alpha, const std::string& beta, ecc_key *publicKey, ecc_key *ownKey, std::string &error) { - size_t buffer_length = SHARED_KEY_BUFFER_LENGTH; - uint8_t buffer[SHARED_KEY_BUFFER_LENGTH]; - int err; - if((err = ecc_shared_secret(ownKey, publicKey, buffer, (unsigned long*) &buffer_length)) != CRYPT_OK){ - error = "Could not calculate shared secret. Message: " + string(error_to_string(err)); - return false; - } - - auto result = this->setupSharedSecret(alpha, beta, string((const char*) buffer, buffer_length), error); - return result; -} - -bool CryptionHandler::setupSharedSecret(const std::string& alpha, const std::string& beta, const std::string& sharedKey, std::string &error) { - auto secret_hash = digest::sha1(sharedKey); - assert(secret_hash.length() == SHA_DIGEST_LENGTH); - - uint8_t iv_buffer[SHA_DIGEST_LENGTH]; - memcpy(iv_buffer, alpha.data(), 10); - memcpy(&iv_buffer[10], beta.data(), 10); - - for (int index = 0; index < SHA_DIGEST_LENGTH; index++) { - iv_buffer[index] ^= (uint8_t) secret_hash[index]; - } - - { - lock_guard lock(this->cache_key_lock); - memcpy(this->iv_struct, iv_buffer, SHA_DIGEST_LENGTH); - this->iv_struct_length = SHA_DIGEST_LENGTH; - - uint8_t mac_buffer[SHA_DIGEST_LENGTH]; - digest::sha1((const char*) iv_buffer, SHA_DIGEST_LENGTH, mac_buffer); - memcpy(this->current_mac, mac_buffer, 8); - - this->useDefaultChipherKeyNonce = false; - } - - return true; -} - -void _fe_neg(fe h, const fe f) { - int32_t f0 = f[0]; - int32_t f1 = f[1]; - int32_t f2 = f[2]; - int32_t f3 = f[3]; - int32_t f4 = f[4]; - int32_t f5 = f[5]; - int32_t f6 = f[6]; - int32_t f7 = f[7]; - int32_t f8 = f[8]; - int32_t f9 = f[9]; - int32_t h0 = -f0; - int32_t h1 = -f1; - int32_t h2 = -f2; - int32_t h3 = -f3; - int32_t h4 = -f4; - int32_t h5 = -f5; - int32_t h6 = -f6; - int32_t h7 = -f7; - int32_t h8 = -f8; - int32_t h9 = -f9; - - h[0] = h0; - h[1] = h1; - h[2] = h2; - h[3] = h3; - h[4] = h4; - h[5] = h5; - h[6] = h6; - h[7] = h7; - h[8] = h8; - h[9] = h9; -} - -inline void keyMul(uint8_t(& target_buffer)[32], const uint8_t* publicKey /* compressed */, const uint8_t* privateKey /* uncompressed */, bool negate){ - ge_p3 keyA{}; - ge_p2 result{}; - - ge_frombytes_negate_vartime(&keyA, publicKey); - if(negate) { - _fe_neg(*(fe*) &keyA.X, *(const fe*) &keyA.X); /* undo negate */ - _fe_neg(*(fe*) &keyA.T, *(const fe*) &keyA.T); /* undo negate */ - } - ge_scalarmult_vartime(&result, privateKey, &keyA); - - ge_tobytes(target_buffer, &result); -} - -bool CryptionHandler::setupSharedSecretNew(const std::string &alpha, const std::string &beta, const char* privateKey /* uncompressed */, const char* publicKey /* compressed */) { - if(alpha.length() != 10 || beta.length() != 54) - return false; - - uint8_t shared[32]; - uint8_t shared_iv[64]; - - ed25519_key_exchange(shared, (uint8_t*) publicKey, (uint8_t*) privateKey); - keyMul(shared, reinterpret_cast(publicKey), reinterpret_cast(privateKey), true); //Remote key get negated - digest::sha512((char*) shared, 32, shared_iv); - - auto xor_key = alpha + beta; - for(int i = 0; i < 64; i++) - shared_iv[i] ^= (uint8_t) xor_key[i]; - - { - lock_guard lock(this->cache_key_lock); - memcpy(this->iv_struct, shared_iv, 64); - this->iv_struct_length = 64; - - uint8_t mac_buffer[SHA_DIGEST_LENGTH]; - digest::sha1((char*) this->iv_struct, 64, mac_buffer); - memcpy(this->current_mac, mac_buffer, 8); - this->useDefaultChipherKeyNonce = false; - } - - return true; -} - -bool CryptionHandler::generate_key_nonce(protocol::BasicPacket* packet, bool use_default, uint8_t(& key)[16], uint8_t(& nonce)[16]){ - return this->generate_key_nonce( - dynamic_cast(packet) != nullptr, - packet->type().type(), - packet->packetId(), - packet->generationId(), - use_default, - key, - nonce - ); -} - -#define GENERATE_BUFFER_LENGTH (128) -bool CryptionHandler::generate_key_nonce( - bool to_server, /* its from the client to the server */ - protocol::PacketType type, - uint16_t packet_id, - uint16_t generation, - bool use_default, - uint8_t (& key)[16], - uint8_t (& nonce)[16] -) { - if (this->useDefaultChipherKeyNonce || use_default) { - memcpy(key, CryptionHandler::default_key, 16); - memcpy(nonce, CryptionHandler::default_nonce, 16); - return true; - } - - auto& key_cache_array = to_server ? this->cache_key_client : this->cache_key_server; - if(type < 0 || type >= key_cache_array.max_size()) { - logError(0, "Tried to generate a crypt key with invalid type ({})!", type); - return false; - } - - auto& key_cache = key_cache_array[type]; - if(key_cache.generation != generation) { - const size_t buffer_length = 6 + this->iv_struct_length; - sassert(buffer_length < GENERATE_BUFFER_LENGTH); - - char buffer[GENERATE_BUFFER_LENGTH]; - memset(buffer, 0, buffer_length); - - if (to_server) { - buffer[0] = 0x31; - } else { - buffer[0] = 0x30; - } - buffer[1] = (char) (type & 0xF); - - le2be32(generation, buffer, 2); - memcpy(&buffer[6], this->iv_struct, this->iv_struct_length); - digest::sha256(buffer, buffer_length, key_cache.key_nonce); - - key_cache.generation = generation; - } - - memcpy(key, key_cache.key, 16); - memcpy(nonce, key_cache.nonce, 16); - - //Xor the key - key[0] ^= (uint8_t) ((packet_id >> 8) & 0xFF); - key[1] ^=(packet_id & 0xFF); - - return true; -} - -bool CryptionHandler::verify_encryption(const pipes::buffer_view &packet, uint16_t packet_id, uint16_t generation) { - int err; - int success = false; - - uint8_t key[16], nonce[16]; - if(!generate_key_nonce(true, (protocol::PacketType) (packet[12] & 0xF), packet_id, generation, false, key, nonce)) - return false; - - auto mac = packet.view(0, 8); - auto header = packet.view(8, 5); - auto data = packet.view(13); - - auto length = data.length(); - - /* static shareable void buffer */ - const static unsigned long void_target_length = 2048; - static uint8_t void_target_buffer[2048]; - if(void_target_length < length) - return false; - - err = eax_decrypt_verify_memory(find_cipher("rijndael"), - (uint8_t *) key, /* the key */ - (size_t) 16, /* key is 16 bytes */ - (uint8_t *) nonce, /* the nonce */ - (size_t) 16, /* nonce is 16 bytes */ - (uint8_t *) header.data_ptr(), /* example header */ - (unsigned long) header.length(), /* header length */ - (const unsigned char *) data.data_ptr(), - (unsigned long) data.length(), - (unsigned char *) void_target_buffer, - (unsigned char *) mac.data_ptr(), - (unsigned long) mac.length(), - &success - ); - - return err == CRYPT_OK && success; -} - -bool CryptionHandler::decryptPacket(protocol::BasicPacket *packet, std::string &error, bool use_default) { - int err; - int success = false; - - auto header = packet->header(); - auto data = packet->data(); - - uint8_t key[16], nonce[16]; - if(!generate_key_nonce(packet, use_default, key, nonce)) { - error = "Could not generate key/nonce"; - return false; - } - - size_t target_length = 2048; - uint8_t target_buffer[2048]; - auto length = data.length(); - if(target_length < length) { - error = "buffer too large"; - return false; - } - - - err = eax_decrypt_verify_memory(find_cipher("rijndael"), - (uint8_t *) key, /* the key */ - (unsigned long) 16, /* key is 16 bytes */ - (uint8_t *) nonce, /* the nonce */ - (unsigned long) 16, /* nonce is 16 bytes */ - (uint8_t *) header.data_ptr(), /* example header */ - (unsigned long) header.length(), /* header length */ - (const unsigned char *) data.data_ptr(), - (unsigned long) data.length(), - (unsigned char *) target_buffer, - (unsigned char *) packet->mac().data_ptr(), - (unsigned long) packet->mac().length(), - &success - ); - - if((err) != CRYPT_OK){ - error = "eax_decrypt_verify_memory(...) returned " + to_string(err) + "/" + error_to_string(err); - return false; - } - if(!success){ - error = "memory verify failed!"; - return false; - } - - packet->data(pipes::buffer_view{target_buffer, length}); - packet->setEncrypted(false); - return true; -} - -bool CryptionHandler::encryptPacket(protocol::BasicPacket *packet, std::string &error, bool use_default) { - uint8_t key[16], nonce[16]; - if(!generate_key_nonce(packet, use_default, key, nonce)) { - error = "Could not generate key/nonce"; - return false; - } - - size_t length = packet->data().length(); - - size_t tag_length = 8; - char tag_buffer[8]; - - size_t target_length = 2048; - uint8_t target_buffer[2048]; - if(target_length < length) { - error = "buffer too large"; - return false; - } - - int err; - if((err = eax_encrypt_authenticate_memory(find_cipher("rijndael"), - (uint8_t *) key, /* the key */ - (unsigned long) 16, /* key is 16 bytes */ - (uint8_t *) nonce, /* the nonce */ - (unsigned long) 16, /* nonce is 16 bytes */ - (uint8_t *) packet->header().data_ptr(), /* example header */ - (unsigned long) packet->header().length(), /* header length */ - (uint8_t *) packet->data().data_ptr(), /* The plain text */ - (unsigned long) packet->data().length(), /* Plain text length */ - (uint8_t *) target_buffer, /* The result buffer */ - (uint8_t *) tag_buffer, - (unsigned long *) &tag_length - )) != CRYPT_OK){ - error = "eax_encrypt_authenticate_memory(...) returned " + to_string(err) + "/" + error_to_string(err); - return false; - } - assert(tag_length == 8); - - packet->data(pipes::buffer_view{target_buffer, length}); - packet->mac().write(tag_buffer, tag_length); - packet->setEncrypted(true); - return true; -} - -bool CryptionHandler::progressPacketIn(protocol::BasicPacket* packet, std::string& error, bool use_default) { - while(blocked) - this_thread::yield(); - - if(packet->isEncrypted()){ - bool success = decryptPacket(packet, error, use_default); - if(success) packet->setEncrypted(false); - return success; - } - return true; -} - -bool CryptionHandler::progressPacketOut(protocol::BasicPacket* packet, std::string& error, bool use_default) { - while(blocked) - this_thread::yield(); - - if(packet->has_flag(PacketFlag::Unencrypted)) { - packet->mac().write(this->current_mac, 8); - } else { - bool success = encryptPacket(packet, error, use_default); - if(success) packet->setEncrypted(true); - return success; - } - return true; -} diff --git a/src/protocol/Packet.cpp b/src/protocol/Packet.cpp index bf9e8aa..f863a37 100644 --- a/src/protocol/Packet.cpp +++ b/src/protocol/Packet.cpp @@ -214,5 +214,29 @@ namespace ts { this->header()[2] = clId >> 8; this->header()[3] = clId & 0xFF; } + + + uint16_t IncomingClientPacketParser::packet_id() const { return be2le16(this->_buffer.data_ptr(), IncomingClientPacketParser::kHeaderOffset + 0); } + uint16_t IncomingClientPacketParser::client_id() const { return be2le16(this->_buffer.data_ptr(), IncomingClientPacketParser::kHeaderOffset + 2); } + uint8_t IncomingClientPacketParser::type() const { return this->_buffer[IncomingClientPacketParser::kHeaderOffset + 4] & 0xF; } + uint8_t IncomingClientPacketParser::flags() const { return this->_buffer[IncomingClientPacketParser::kHeaderOffset + 4] & 0xF0; } + + bool IncomingClientPacketParser::is_encrypted() const { + if(this->decrypted) return false; + + return (this->flags() & PacketFlag::Unencrypted) == 0; + } + + bool IncomingClientPacketParser::is_compressed() const { + if(this->uncompressed) return false; + + return (this->flags() & PacketFlag::Compressed) > 0; + } + + bool IncomingClientPacketParser::is_fragmented() const { + if(this->defragmented) return false; + + return (this->flags() & PacketFlag::Fragmented) > 0; + } } } \ No newline at end of file diff --git a/src/protocol/Packet.h b/src/protocol/Packet.h index d95f28f..8729393 100644 --- a/src/protocol/Packet.h +++ b/src/protocol/Packet.h @@ -286,6 +286,49 @@ namespace ts { void setPacketId(uint16_t, uint16_t) override; }; + class IncomingClientPacketParser { + public: + constexpr static auto kHeaderOffset = 8; + constexpr static auto kHeaderLength = CLIENT_HEADER_SIZE; + + constexpr static auto kPayloadOffset = kHeaderOffset + CLIENT_HEADER_SIZE; + explicit IncomingClientPacketParser(pipes::buffer_view buffer) : _buffer{std::move(buffer)} {} + IncomingClientPacketParser(const IncomingClientPacketParser&) = delete; + + [[nodiscard]] inline bool valid() const { + if(this->_buffer.length() < kPayloadOffset) return false; + return this->type() <= 8; + } + + [[nodiscard]] inline const void* data_ptr() const { return this->_buffer.data_ptr(); } + [[nodiscard]] inline void* mutable_data_ptr() { return (void*) this->_buffer.data_ptr(); } + + [[nodiscard]] inline pipes::buffer_view buffer() const { return this->_buffer; } + [[nodiscard]] inline const pipes::buffer_view mac() const { return this->_buffer.view(0, 8); } + [[nodiscard]] inline const pipes::buffer_view payload() const { return this->_buffer.view(kPayloadOffset); } + [[nodiscard]] inline size_t payload_length() const { return this->_buffer.length() - kPayloadOffset; } + + [[nodiscard]] uint16_t client_id() const; + [[nodiscard]] uint16_t packet_id() const; + [[nodiscard]] uint8_t type() const; + [[nodiscard]] uint8_t flags() const; + + [[nodiscard]] bool is_encrypted() const; + [[nodiscard]] bool is_compressed() const; + [[nodiscard]] bool is_fragmented() const; + + [[nodiscard]] uint16_t estimated_generation() const { return this->generation; } + void set_estimated_generation(uint16_t generation) { this->generation = generation; } + + inline void set_decrypted() { this->decrypted = true; } + inline void set_uncompressed() { this->uncompressed = true; } + inline void set_defragmented() { this->defragmented = true; } + private: + uint16_t generation{}; + bool decrypted{false}, uncompressed{false}, defragmented{false}; + pipes::buffer_view _buffer{}; + }; + /** * Packet from the server */ diff --git a/src/protocol/generation.cpp b/src/protocol/generation.cpp new file mode 100644 index 0000000..4bb8687 --- /dev/null +++ b/src/protocol/generation.cpp @@ -0,0 +1,36 @@ +#include "./generation.h" + +using namespace ts::protocol; + +generation_estimator::generation_estimator() { + this->reset(); +} + +void generation_estimator::reset() { + this->last_generation = 0; + this->last_packet_id = 0; +} + +uint16_t generation_estimator::visit_packet(uint16_t packet_id) { + if(this->last_packet_id >= generation_estimator::overflow_area_begin) { + if(packet_id > this->last_packet_id) { + this->last_packet_id = packet_id; + return this->last_generation; + } else { + this->last_packet_id = packet_id; + return ++this->last_generation; + } + } else if(this->last_packet_id <= generation_estimator::overflow_area_end) { + if(packet_id >= generation_estimator::overflow_area_begin) /* old packet */ + return this->last_generation - 1; + this->last_packet_id = packet_id; + return this->last_generation; + } else { + this->last_packet_id = packet_id; + return this->last_generation; + } +} + +uint16_t generation_estimator::generation() const { + return this->last_generation; +} \ No newline at end of file diff --git a/src/protocol/generation.h b/src/protocol/generation.h new file mode 100644 index 0000000..8c7429a --- /dev/null +++ b/src/protocol/generation.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace ts::protocol { + class generation_estimator { + public: + generation_estimator(); + + void reset(); + uint16_t visit_packet(uint16_t /* packet id */); + uint16_t generation() const; + + void set_last_state(uint16_t last_packet, uint16_t generation) { + this->last_packet_id = last_packet; + this->last_generation = generation; + } + private: + constexpr static uint16_t overflow_window{1024}; + constexpr static uint16_t overflow_area_begin{0xFFFF - overflow_window}; + constexpr static uint16_t overflow_area_end{overflow_window}; + + uint16_t last_generation{0}; + uint16_t last_packet_id{0}; + }; +} \ No newline at end of file diff --git a/src/query/command3.h b/src/query/command3.h index 3812cde..53e44d4 100644 --- a/src/query/command3.h +++ b/src/query/command3.h @@ -70,9 +70,9 @@ namespace ts { template [[nodiscard]] inline T value_as(const std::string_view& key) const { static_assert(converter::supported, "Target type isn't supported!"); - static_assert(!converter::supported || converter::from_string, "Target type dosn't support parsing"); + static_assert(!converter::supported || converter::from_string_view, "Target type dosn't support parsing"); - return converter::from_string(this->value(key)); + return converter::from_string_view(this->value(key)); } protected: diff --git a/src/query/command_handler.h b/src/query/command_handler.h index ca5e26c..58e25f5 100644 --- a/src/query/command_handler.h +++ b/src/query/command_handler.h @@ -229,7 +229,7 @@ namespace ts { struct field : public function_parameter, public bulk_extend { friend struct command_invoker>; static_assert(converter::supported, "Target type isn't supported!"); - static_assert(!converter::supported || converter::from_string, "Target type dosn't support parsing"); + static_assert(!converter::supported || converter::from_string_view, "Target type dosn't support parsing"); public: template diff --git a/test/generationTest.cpp b/test/generationTest.cpp new file mode 100644 index 0000000..2bd2374 --- /dev/null +++ b/test/generationTest.cpp @@ -0,0 +1,105 @@ +#include +#include +#include +#include + +using namespace ts::protocol; + + +typedef std::vector> test_vector_t; + +test_vector_t generate_test_vector(size_t size, int loss) { + test_vector_t result{}; + result.reserve(size); + + + for(size_t i = 0; i < size; i++) { + if ((rand() % 100) < loss) continue; + result.emplace_back(i & 0xFFFFU, i >> 16U); + } + + return result; +} + +test_vector_t swap_elements(test_vector_t vector, int per, int max_distance) { + for(size_t index = 0; index < vector.size() - max_distance; index++) { + if ((rand() % 100) < per) { + //lets switch + auto offset = rand() % max_distance; + if(!offset) offset = 1; + + std::swap(vector[index], vector[index + offset]); + } + } + + return vector; +} + +bool test_vector(const std::string_view& name, const test_vector_t& vector) { + generation_estimator gen{}; + + size_t last_value{0}; + for(auto [id, exp] : vector) { + if(auto val = gen.visit_packet(id); val != exp) { + std::cout << "[" << name << "] failed for " << id << " -> " << exp << " | " << val << ". Last value: " << last_value << "\n"; + return false; + } + last_value = id; + } + return true; +} + +template +bool test_vector(generation_estimator& generator, const std::array& packet_ids, const std::array& expected) { + for(size_t index = 0; index < N; index++) { + auto result = generator.visit_packet(packet_ids[index]); + if(result != expected[index]) { + std::cout << "failed to packet id " << packet_ids[index] << " (" << index << "). Result: " << result << " Expected: " << expected[index] << "\n"; + std::cout << "----- fail\n"; + return false; + } + + std::cout << "PacketID: " << packet_ids[index] << " -> " << result << "\n"; + } + std::cout << "----- pass\n"; + + return true; +} + +int main() { + generation_estimator gen{}; + + { + test_vector("00 loss", generate_test_vector(0x3000, 0)); + test_vector("10 loss", generate_test_vector(0x3000, 10)); + test_vector("25 loss", generate_test_vector(0x3000, 25)); + test_vector("50 loss", generate_test_vector(0x3000, 50)); + test_vector("80 loss", generate_test_vector(0x3000, 80)); + } + + { + auto base = generate_test_vector(0x3000, 0); + test_vector("swap 30:20", swap_elements(base, 30, 20)); + test_vector("swap 30:1000", swap_elements(base, 30, 200)); + test_vector("swap 80:1000", swap_elements(base, 80, 200)); + } + + if(false) { + test_vector("10 loss", generate_test_vector(0x3000, 10)); + test_vector("25 loss", generate_test_vector(0x3000, 25)); + test_vector("50 loss", generate_test_vector(0x3000, 50)); + test_vector("80 loss", generate_test_vector(0x3000, 80)); + } + + gen.set_last_state(0, 0); + test_vector<6>(gen, {0, 1, 2, 4, 3, 5}, {0, 0, 0, 0, 0, 0}); + + gen.set_last_state(0xFF00, 0); + test_vector<6>(gen, {0, 1, 2, 4, 3, 5}, {1, 1, 1, 1, 1, 1}); + + gen.set_last_state(0xFF00, 0); + test_vector<6>(gen, {0, 1, 2, 0xFF00, 3, 5}, {1, 1, 1, 0, 1, 1}); + + gen.set_last_state(0xFF00, 0); + test_vector<6>(gen, {0xFFFE, 0xFFFF, 0, 1, 0xFFFC, 2}, {0, 0, 1, 1, 0, 1}); +} \ No newline at end of file