From c117799cc9b391ec8072db35bbd0558aba7dff53 Mon Sep 17 00:00:00 2001 From: WolverinDEV Date: Thu, 5 Sep 2019 12:30:07 +0200 Subject: [PATCH] Fixed/improved MySQL --- CMakeLists.txt | 7 +- src/protocol/Packet.cpp | 9 +- src/protocol/Packet.h | 7 +- src/sql/mysql/MySQL.cpp | 850 +++++++++++++++++++++++++++++++--------- src/sql/mysql/MySQL.h | 61 ++- test/SQLTest.cpp | 39 +- 6 files changed, 727 insertions(+), 246 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e5f926..3d42ec1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,7 +58,7 @@ if (MSVC) CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE - ) + ) foreach(CompilerFlag ${CompilerFlags}) string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}") endforeach() @@ -200,12 +200,13 @@ set(TEST_LIBRARIES ${LIBRARY_PATH_ED255} ${LIBRARY_PATH_JSON} ${LIBRARY_TOM_CRYPT} - ${LIBRARY_TOM_MATH} - ${LIBRARY_PATH_JDBC} + mysqlclient.a ${LIBRARY_PATH_DATA_PIPES} ${LIBRARY_PATH_BORINGSSL_SSL} ${LIBRARY_PATH_BORINGSSL_CRYPTO} #Crypto must be linked after + dl + z ) include_directories(src/) option(BUILD_TESTS "Enable/disable test building" ON) diff --git a/src/protocol/Packet.cpp b/src/protocol/Packet.cpp index 15cfd89..f0b9502 100644 --- a/src/protocol/Packet.cpp +++ b/src/protocol/Packet.cpp @@ -13,12 +13,19 @@ using namespace std; namespace ts { namespace protocol { - PacketTypeInfo::PacketTypeInfo(std::string name, PacketType type, bool ack, int max_length) noexcept { + PacketTypeInfo::PacketTypeInfo(const std::string& name, PacketType type, bool ack, int max_length) noexcept { this->data = new PacketTypeProperties{name, type, max_length, ack}; + this->owns_data = true; + if(type < 0x0F) types.insert({type, *this}); } + PacketTypeInfo::~PacketTypeInfo() { + if(this->owns_data) + delete this->data; + } + PacketTypeInfo::PacketTypeInfo(PacketTypeInfo &red) : data(red.data) { } PacketTypeInfo::PacketTypeInfo(const PacketTypeInfo &red) : data(red.data) { } diff --git a/src/protocol/Packet.h b/src/protocol/Packet.h index 7d866d9..5728c64 100644 --- a/src/protocol/Packet.h +++ b/src/protocol/Packet.h @@ -67,10 +67,13 @@ namespace ts { PacketTypeInfo(PacketTypeInfo&); PacketTypeInfo(const PacketTypeInfo&); PacketTypeInfo(PacketTypeInfo&& remote) : data(remote.data) {} + + ~PacketTypeInfo(); private: static std::map types; - PacketTypeInfo(std::string, PacketType, bool, int) noexcept; - PacketTypeProperties* data; + PacketTypeInfo(const std::string&, PacketType, bool, int) noexcept; + PacketTypeProperties* data; + bool owns_data = false; }; struct PacketIdManagerData { diff --git a/src/sql/mysql/MySQL.cpp b/src/sql/mysql/MySQL.cpp index 5cc8033..3413044 100644 --- a/src/sql/mysql/MySQL.cpp +++ b/src/sql/mysql/MySQL.cpp @@ -10,10 +10,16 @@ #include +#define CR_CONNECTION_ERROR (2002) +#define CR_SERVER_GONE_ERROR (2006) +#define CR_SERVER_LOST (2013) + using namespace std; using namespace sql; using namespace sql::mysql; +//TODO: Cache statements in general any only reapply the values + MySQLManager::MySQLManager() : SqlManager(SqlType::TYPE_MYSQL) {} MySQLManager::~MySQLManager() {} @@ -21,9 +27,11 @@ MySQLManager::~MySQLManager() {} //mysql://[host][:port]/[database][?propertyName1=propertyValue1[&propertyName2=propertyValue2]...] #define MYSQL_PREFIX "mysql://" -inline result parse_url(const string& url, ConnectOptionsMap& connect_map) { +inline result parse_url(const string& url, std::map& connect_map) { string target_url; - if(url.find(MYSQL_PREFIX) != 0) return {ERROR_MYSQL_INVLID_URL, "Missing mysql:// at begin"}; + if(url.find(MYSQL_PREFIX) != 0) + return {ERROR_MYSQL_INVLID_URL, "Missing mysql:// at begin"}; + auto index_parms = url.find('?'); if(index_parms == string::npos) { target_url = "tcp://" + url.substr(strlen(MYSQL_PREFIX)); @@ -51,50 +59,186 @@ inline result parse_url(const string& url, ConnectOptionsMap& connect_map) { return result::success; } +//mysql://[host][:port]/[database][?propertyName1=propertyValue1[&propertyName2=propertyValue2]...] +inline bool parse_mysql_data(const string& url, string& error, string& host, uint16_t& port, string& database, map& properties) { + size_t parse_index = 0; + /* parse the scheme */ + { + auto index = url.find("://", parse_index); + if(index == -1 || url.substr(parse_index, index - parse_index) != "mysql") { + error = "missing/invalid URL scheme"; + return false; + } + + parse_index = index + 3; + if(parse_index >= url.length()) { + error = "unexpected EOL after scheme"; + return false; + } + } + + /* parse host[:port]*/ + { + auto index = url.find('/', parse_index); + if(index == -1) { + error = "missing host/port"; + return false; + } + + auto host_port = url.substr(parse_index, index - parse_index); + + auto port_index = host_port.find(':'); + if(port_index == -1) { + host = host_port; + } else { + host = host_port.substr(0, port_index); + auto port_str = host_port.substr(port_index + 1); + try { + port = stol(port_str); + } catch(std::exception& ex) { + error = "failed to parse port"; + return false; + } + } + if(host.empty()) { + error = "host is empty"; + return false; + } + + parse_index = index + 1; + if(parse_index >= url.length()) { + error = "unexpected EOL after host/port"; + return false; + } + } + + /* the database */ + { + auto index = url.find('?', parse_index); + if(index == -1) { + database = url.substr(parse_index); + parse_index = url.length(); + } else { + database = url.substr(parse_index, index - parse_index); + parse_index = index + 1; + } + + if(database.empty()) { + error = "database is empty"; + return false; + } + } + + /* properties */ + string full_property, property_key, property_value; + while(parse_index < url.length()){ + /* "read" the next property */ + { + auto index = url.find('&', parse_index); /* next entry */ + if(index == -1) { + full_property = url.substr(parse_index); + parse_index = url.length(); + } else { + full_property = url.substr(parse_index, index - parse_index); + parse_index = index + 1; + } + } + + /* parse it */ + { + auto index = full_property.find('='); + if(index == -1) { + error = "invalid property format (missing '=')"; + return false; + } + + property_key = full_property.substr(0, index); + property_value = full_property.substr(index + 1); + if(property_key.empty() || property_value.empty()) { + error = "invalid property key/value (empty)"; + return false; + } + + properties[property_key] = http::decode_url(property_value); + } + } + return true; +} + +mysql::Connection::~Connection() { + { + lock_guard lock(this->used_lock); + assert(!this->used); + } + + if(this->handle) { + mysql_close(this->handle); + this->handle = nullptr; + } +} + result MySQLManager::connect(const std::string &url) { this->disconnecting = false; + string error; - ConnectOptionsMap connect_map; - connect_map["connections"] = "1"; - auto res = parse_url(url, connect_map); - if(!res) return res; + map properties; + string host, database; + uint16_t port; - this->driver = get_driver_instance(); - if(!this->driver) return {ERROR_MYSQL_MISSING_DRIVER, "Missing driver!"}; + if(!parse_mysql_data(url, error, host, port, database, properties)) { + error = "URL parsing failed: " + error; + return {ERROR_MYSQL_INVLID_URL, error}; + } - try { - auto entry = connect_map["connections"]; - auto connections = std::stoll(connect_map["connections"].get()->asStdString()); - if(connections < 1) return {ERROR_MYSQL_INVLID_PROPERTIES, "Invalid connection count"}; - - for(int i = 0; i < connections; i++) { - auto connection = unique_ptr(this->driver->connect(connect_map)); - if(!connection) return {ERROR_MYSQL_INVLID_CONNECT, "Could not spawn new connection"}; - if(!connection->isValid()) return {ERROR_MYSQL_INVLID_CONNECT, "Could not validate connection"}; - if(connection->getSchema().length() == 0) return {ERROR_MYSQL_INVLID_CONNECT, "Missing schema!"}; - - this->connections.push_back(shared_ptr(new ConnectionEntry(std::move(connection), false))); + size_t connections = 4; + if(properties.count("connections") > 0) { + try { + connections = stol(properties["connections"]); + } catch(std::exception& ex) { + return {ERROR_MYSQL_INVLID_PROPERTIES, "could not parse connection count"}; } - } catch (sql::SQLException& ex) { - return {ERROR_MYSQL_INVLID_CONNECT, ex.what()}; + } + + string username, password; + if(properties.count("userName") > 0) username = properties["userName"]; + if(properties.count("username") > 0) username = properties["username"]; + if(username.empty()) return {ERROR_MYSQL_INVLID_PROPERTIES, "missing username property"}; + + if(properties.count("password") > 0) password = properties["password"]; + if(password.empty()) return {ERROR_MYSQL_INVLID_PROPERTIES, "missing password property"}; + + //debugMessage(LOG_GENERAL, R"([MYSQL] Starting {} connections to {}:{} with database "{}" as user "{}")", connections, host, port, database, username); + + for(size_t index = 0; index < connections; index++) { + auto connection = make_shared(); + connection->handle = mysql_init(nullptr); + if(!connection->handle) + return {-1, "failed to allocate connection " + to_string(index)}; + + { + my_bool reconnect = 1; + mysql_options(connection->handle, MYSQL_OPT_RECONNECT, &reconnect); + } + + auto result = mysql_real_connect(connection->handle, host.c_str(), username.c_str(), password.c_str(), database.c_str(), port, nullptr, 0); //CLIENT_MULTI_RESULTS | CLIENT_MULTI_STATEMENTS + if(!result) + return {-1, "failed to connect to server with connection " + to_string(index) + ": " + mysql_error(connection->handle)}; + + connection->used = false; + this->connections.push_back(connection); } return result::success; } bool MySQLManager::connected() { lock_guard lock(this->connections_lock); - for(const auto& conn : this->connections) - if(conn->used || conn->connection->isValid()) return true; - return false; + return !this->connections.empty(); } result MySQLManager::disconnect() { lock_guard lock(this->connections_lock); this->disconnecting = true; - for(const auto& entry : this->connections) - entry->connection->close(); - this->connections.clear(); this->connections_condition.notify_all(); @@ -102,6 +246,40 @@ result MySQLManager::disconnect() { return result::success; } +struct StatementGuard { + MYSQL_STMT* stmt; + + ~StatementGuard() { + mysql_stmt_close(this->stmt); + } +}; + +struct ResultGuard { + MYSQL_RES* result; + + ~ResultGuard() { + mysql_free_result(this->result); + } +}; + +template +struct FreeGuard { + T* ptr; + + ~FreeGuard() { + if(this->ptr) ::free(this->ptr); + } +}; + +template +struct DeleteAGuard { + T* ptr; + + ~DeleteAGuard() { + delete[] this->ptr; + } +}; + std::shared_ptr MySQLManager::allocateCommandData() { return make_shared(); } @@ -119,27 +297,7 @@ std::shared_ptr MySQLManager::copyCommandData(std::shared_ptrclose(); - delete stmt; -}; -typedef unique_ptr PreparedStatementHandle; - -void statement_release(sql::Statement* stmt) { - if(stmt) stmt->close(); - delete stmt; -}; -typedef unique_ptr StatementHandle; - -void result_release(sql::ResultSet* set) { - if(set && !set->isClosed()) set->close(); - delete set; -} -typedef unique_ptr ResultHandle; - - -namespace sql { - namespace mysql { +namespace sql::mysql { bool evaluate_sql_query(string& sql, const std::vector& vars, std::vector& result) { char quote = 0; for(int index = 0; index < sql.length(); index++) { @@ -169,66 +327,313 @@ namespace sql { break; } if(!insert) - result.push_back(variable{}); + result.emplace_back(); } return true; } - inline bool bind_parms(const PreparedStatementHandle& stmt, const std::vector& vars) { - uint32_t index = 1; - for(const auto& var : vars) { - switch (var.type()) { - case VARTYPE_NULL: - stmt->setNull(index, 0); - break; - case VARTYPE_BOOLEAN: - stmt->setBoolean(index, var.as()); - break; - case VARTYPE_INT: - stmt->setInt(index, var.as()); - break; - case VARTYPE_LONG: - stmt->setInt64(index, var.as()); - break; - case VARTYPE_DOUBLE: - stmt->setDouble(index, var.as()); - break; - case VARTYPE_FLOAT: - stmt->setDouble(index, var.as()); - break; - case VARTYPE_TEXT: - stmt->setString(index, var.value()); - break; - default: - cerr << "[MySQL] Invalid var type (" << var.type() << ")" << endl; - break; + struct BindMemory { }; + + /* memory must be freed via ::free! */ + bool create_bind(BindMemory*& memory, const std::vector& variables) { + size_t required_bytes = sizeof(MYSQL_BIND) * variables.size(); + + /* first lets calculate the required memory */ + { + for(auto& variable : variables) { + switch (variable.type()) { + case VARTYPE_NULL: + break; + case VARTYPE_BOOLEAN: + required_bytes += sizeof(bool); + break; + case VARTYPE_INT: + required_bytes += sizeof(int32_t); + break; + case VARTYPE_LONG: + required_bytes += sizeof(int64_t); + break; + case VARTYPE_DOUBLE: + required_bytes += sizeof(double); + break; + case VARTYPE_FLOAT: + required_bytes += sizeof(float); + break; + case VARTYPE_TEXT: + //TODO: Use a direct pointer to the variable's value instead of copying it + required_bytes += sizeof(unsigned long*) + variable.value().length(); + break; + default: + return false; /* unknown variable type */ + } } - index++; } + if(!required_bytes) { + memory = nullptr; + return true; + } + + //logTrace(LOG_GENERAL, "[MYSQL] Allocated {} bytes for parameters", required_bytes); + memory = (BindMemory*) malloc(required_bytes); + if(!memory) + return false; + + memset(memory, 0, required_bytes); + /* lets fill the values */ + { + size_t memory_index = variables.size() * sizeof(MYSQL_BIND); + auto bind_ptr = (MYSQL_BIND*) memory; + auto payload_ptr = (char*) memory + sizeof(MYSQL_BIND) * variables.size(); + + for(size_t index = 0; index < variables.size(); index++) { + bind_ptr->buffer = payload_ptr; + + auto& variable = variables[index]; + switch (variable.type()) { + case VARTYPE_NULL: + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_NULL; + break; + case VARTYPE_BOOLEAN: + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_TINY; + bind_ptr->buffer_length = sizeof(bool); + *(bool*) payload_ptr = variable.as(); + break; + case VARTYPE_INT: + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_LONG; + bind_ptr->buffer_length = sizeof(int32_t); + *(int32_t*) payload_ptr = variable.as(); + break; + case VARTYPE_LONG: + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_LONGLONG; + bind_ptr->buffer_length = sizeof(int64_t); + *(int64_t*) payload_ptr = variable.as(); + break; + case VARTYPE_DOUBLE: + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_DOUBLE; + bind_ptr->buffer_length = sizeof(double); + *(double*) payload_ptr = variable.as(); + break; + case VARTYPE_FLOAT: + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_FLOAT; + bind_ptr->buffer_length = sizeof(float); + *(float*) payload_ptr = variable.as(); + break; + case VARTYPE_TEXT: { + auto value = variable.value(); + + //TODO: Use a direct pointer to the variable's value instead of copying it + //May use a string object allocated on the memory_ptr? (Special deinit needed then!) + bind_ptr->buffer_type = enum_field_types::MYSQL_TYPE_STRING; + bind_ptr->buffer_length = value.length(); + + bind_ptr->length = (unsigned long*) payload_ptr; + *bind_ptr->length = bind_ptr->buffer_length; + + payload_ptr += sizeof(unsigned long*); + memory_index += sizeof(unsigned long*); + + memcpy(payload_ptr, value.data(), value.length()); + bind_ptr->buffer = payload_ptr; + break; + } + default: + return false; /* unknown variable type */ + } + + payload_ptr += bind_ptr->buffer_length; + bind_ptr++; + assert(memory_index <= required_bytes); + } + } + + return true; + } + + struct ResultBindDescriptor { + size_t primitive_size = 0; + + void(*destroy)(char*& /* primitive ptr */) = nullptr; + bool(*create)(const MYSQL_FIELD& /* field */, MYSQL_BIND& /* bind */, char*& /* primitive ptr */) = nullptr; + + bool(*get_as_string)(MYSQL_BIND& /* bind */, std::string& /* result */) = nullptr; + }; + + /* memory to primitive string */ + template + bool m2ps(MYSQL_BIND& bind, string& str) { + if(bind.error_value || (bind.error && *bind.error)) return false; + str = std::to_string(*(T*) bind.buffer); + return true; + } + + template + void _do_destroy_primitive(char*& primitive_ptr) { + primitive_ptr += size; + } + + template + bool _do_bind_primitive(const MYSQL_FIELD&, MYSQL_BIND& bind, char*& primitive_ptr) { + bind.buffer = (void*) primitive_ptr; + bind.buffer_length = size; + bind.buffer_type = (enum_field_types) type; + primitive_ptr += size; + return true; + } + + #define CREATE_PRIMATIVE_BIND_DESCRIPTOR(mysql_type, c_type, size) \ + case mysql_type:\ + static ResultBindDescriptor _ ##mysql_type = {\ + size,\ + _do_destroy_primitive,\ + _do_bind_primitive,\ + m2ps\ + };\ + return &_ ##mysql_type; + + const ResultBindDescriptor* get_bind_descriptor(enum_field_types type) { + switch (type) { + case MYSQL_TYPE_NULL: + static ResultBindDescriptor _null = { + /* primitive_size */ 0, + /* destroy */ _do_destroy_primitive<0>, + /* create */ _do_bind_primitive, + /* get_as_string */ [](MYSQL_BIND&, string& str) { str.clear(); return true; } + }; + return &_null; + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_TINY, int8_t, 1); + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_SHORT, int16_t, 2); + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_INT24, int32_t, 4); + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_LONG, int32_t, 4); + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_LONGLONG, int64_t, 8); + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_DOUBLE, double, sizeof(double)); + CREATE_PRIMATIVE_BIND_DESCRIPTOR(MYSQL_TYPE_FLOAT, float, sizeof(float)); + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_BLOB: + static ResultBindDescriptor _string = { + /* primitive_size */ sizeof(void*) + sizeof(unsigned long*), /* we store the allocated buffer in the primitive types buffer and the length */ + /* destroy */ [](char*& primitive) { ::free(*(void**) primitive); primitive += sizeof(void*); primitive += sizeof(unsigned long*); }, + /* create */ [](const MYSQL_FIELD& field, MYSQL_BIND& bind, char*& primitive) { + bind.buffer_length = field.max_length > 0 ? field.max_length : min(field.length, 5UL * 1024UL * 1024UL); + bind.buffer = malloc(bind.buffer_length); + bind.buffer_type = MYSQL_TYPE_BLOB; + + *(void**) primitive = bind.buffer; + primitive += sizeof(void*); + + bind.length = (unsigned long*) primitive; + primitive += sizeof(unsigned long*); + + return bind.buffer != nullptr; + }, + /* get_as_string */ [](MYSQL_BIND& bind, std::string& result) { + auto length = bind.length ? *bind.length : bind.length_value; + result.reserve(length); + result.assign((const char*) bind.buffer, length); + return true; + } + }; + return &_string; + default: + return nullptr; + } + } + + #undef CREATE_PRIMATIVE_BIND_DESCRIPTOR + + struct ResultBind { + size_t field_count = 0; + BindMemory* memory = nullptr; + const ResultBindDescriptor** descriptors = nullptr; + + ~ResultBind() { + if(memory) { + auto memory_ptr = (char*) this->memory + (sizeof(MYSQL_BIND) * field_count); + + for(size_t index = 0; index < this->field_count; index++) + this->descriptors[index]->destroy(memory_ptr); + + ::free(memory); + } + delete[] descriptors; + } + + ResultBind(const ResultBind&) = delete; + ResultBind(ResultBind&&) = default; + + inline bool get_as_string(size_t column, string& result) { + if(!descriptors) return false; + + auto& bind_ptr = *(MYSQL_BIND*) ((char*) this->memory + sizeof(MYSQL_BIND) * column); + return this->descriptors[column]->get_as_string(bind_ptr, result); + } + + inline bool get_as_string(string* results) { + if(!descriptors) return false; + + auto bind_ptr = (MYSQL_BIND*) this->memory; + for(int index = 0; index < this->field_count; index++) + if(!this->descriptors[index]->get_as_string(*bind_ptr, results[index])) + return false; + else + bind_ptr++; + return true; + } + }; + + bool create_result_bind(size_t field_count, MYSQL_FIELD* fields, ResultBind& result) { + size_t required_bytes = sizeof(MYSQL_BIND) * field_count; + + assert(!result.field_count); + assert(!result.descriptors); + assert(!result.memory); + result.descriptors = new const ResultBindDescriptor*[field_count]; + result.field_count = field_count; + + for(size_t index = 0; index < field_count; index++) { + result.descriptors[index] = get_bind_descriptor(fields[index].type); + if(!result.descriptors[index]) return false; + + required_bytes += result.descriptors[index]->primitive_size; + } + + if(!required_bytes) { + result.memory = nullptr; + return true; + } + + logTrace(LOG_GENERAL, "[MYSQL] Allocated {} bytes for response", required_bytes); + result.memory = (BindMemory*) malloc(required_bytes); + if(!result.memory) + return false; + + memset(result.memory, 0, required_bytes); + auto memory_ptr = (char*) result.memory + (sizeof(MYSQL_BIND) * field_count); + auto bind_ptr = (MYSQL_BIND*) result.memory; + for(size_t index = 0; index < field_count; index++) { + if(!result.descriptors[index]->create(fields[index], *bind_ptr, memory_ptr)) return false; + bind_ptr->buffer_type = fields[index].type; + bind_ptr++; + } + assert(memory_ptr == ((char*) result.memory + required_bytes)); /* Overflow check */ return true; } } -} -LocalConnection::LocalConnection(MySQLManager* mgr, const std::shared_ptr& entry) : _mgr(mgr), _connection(entry) { - _mgr->driver->threadInit(); - _connection->used = true; - //logMessage(LOG_GENERAL, "Allocate local connection {} and thread {}", (void*) _connection.get(), (void*) threads::self::id()); -} - -LocalConnection::~LocalConnection() { - //logMessage(LOG_GENERAL, "Deallocate local connection {} and thread {}", (void*) this->_connection.get(), (void*) threads::self::id()); - _mgr->driver->threadEnd(); - _connection->used = false; +AcquiredConnection::AcquiredConnection(MySQLManager* owner, const std::shared_ptr &connection) : owner(owner), connection(connection) { } +AcquiredConnection::~AcquiredConnection() { + { + lock_guard lock{this->connection->used_lock}; + this->connection->used = false; + } { - lock_guard lock(_mgr->connections_lock); - _mgr->connections_condition.notify_all(); + lock_guard lock(this->owner->connections_lock); + this->owner->connections_condition.notify_one(); } } - -std::unique_ptr MySQLManager::next_connection() { - unique_ptr result; +std::unique_ptr MySQLManager::next_connection() { + unique_ptr result; { unique_lock connections_lock(this->connections_lock); @@ -236,9 +641,14 @@ std::unique_ptr MySQLManager::next_connection() { size_t available_connections = 0; for(const auto& connection : this->connections) { available_connections++; - if(connection->used) continue; - result = std::make_unique(this, connection); + { + lock_guard use_lock(connection->used_lock); + if(connection->used) continue; + connection->used = true; + } + + result = std::make_unique(this, connection); break; } @@ -249,138 +659,190 @@ std::unique_ptr MySQLManager::next_connection() { this->disconnect(); return nullptr; } + this->connections_condition.wait(connections_lock); /* wait for the next connection */ } } } + //TODO: Test if the connection hasn't been used for a longer while if so use mysql_ping() to verify the connection - if(!result->_connection->connection->isValid()) { - try { - logError(0, "MySQL connection is invalid! Closing connection!"); - result->_connection->connection->close(); - } catch(sql::SQLException& ex) {} + return result; +} + +void MySQLManager::connection_closed(const std::shared_ptr &connection) { + bool call_disconnect = false; + { + unique_lock connections_lock(this->connections_lock); + auto index = find(this->connections.begin(), this->connections.end(), connection); + if(index == this->connections.end()) return; + + this->connections.erase(index); + call_disconnect = this->connections.empty(); } - if(result->_connection->connection->isClosed()) { - logError(0, "MySQL connection was closed! Attempt reconnect!"); - try { - if(!result->_connection->connection->reconnect()) { - logError(0, "MySQL connection reconnect attempt failed! Dropping connection!"); - { - lock_guard connections_lock(this->connections_lock); - auto index = find(this->connections.begin(), this->connections.end(), result->_connection); - if(index != this->connections.end()) - this->connections.erase(index); - } - return this->next_connection(); - } - } catch (sql::SQLException& ex) { - logError(0, "Got an exception while reconnecting! Message: " + string(ex.what())); - logError(0, "Dropping connection!"); - { - lock_guard connections_lock(this->connections_lock); - auto index = find(this->connections.begin(), this->connections.end(), result->_connection); - if(index != this->connections.end()) - this->connections.erase(index); - } - return this->next_connection(); - } - } - return result; + + auto dl = this->listener_disconnected; + if(call_disconnect && dl) + dl(this->disconnecting); } result MySQLManager::executeCommand(std::shared_ptr _ptr) { auto ptr = static_pointer_cast(_ptr); - std::lock_guard command_lock(ptr->lock); + if(!ptr) { return {-1, "invalid command handle"}; } + + std::lock_guard lock(ptr->lock); auto command = ptr->sql_command; + auto variables = ptr->variables; vector mapped_variables; if(!sql::mysql::evaluate_sql_query(command, variables, mapped_variables)) return {ptr->sql_command, -1, "Could not map sqlite vars to mysql!"}; - unique_ptr connection = this->next_connection(); + FreeGuard bind_parameter_memory{nullptr}; + if(!sql::mysql::create_bind(bind_parameter_memory.ptr, mapped_variables)) return {ptr->sql_command, -1, "Failed to allocate bind memory!"}; + + ResultBind bind_result_data{0, nullptr, nullptr}; + + auto connection = this->next_connection(); if(!connection) return {ptr->sql_command, -1, "Could not get a valid connection!"}; - try { - PreparedStatementHandle stmt(connection->_connection->connection->prepareStatement(command), prepared_statement_release); - //logMessage(LOG_GENERAL, "Deleting prepered statement {} and thread {}", (void*) stmt.get(), (void*) threads::self::id()); - if(!stmt) return {ptr->sql_command, -1, "Could not span a prepared statement"}; - if(!sql::mysql::bind_parms(stmt, mapped_variables)) return {ptr->sql_command, -1, "Could not bind variables!"}; + StatementGuard stmt_guard{mysql_stmt_init(connection->connection->handle)}; + if(!stmt_guard.stmt) + return {ptr->sql_command, -1, "failed to allocate statement"}; - auto update_count = stmt->executeUpdate(); - if(update_count < 0) - return {ptr->sql_command, -1, "Could not execute update. Code: " + to_string(update_count)}; + if(mysql_stmt_prepare(stmt_guard.stmt, command.c_str(), command.length())) { + auto errc = mysql_stmt_errno(stmt_guard.stmt); + if(errc == CR_SERVER_GONE_ERROR || errc == CR_SERVER_LOST || errc == CR_CONNECTION_ERROR) + this->connection_closed(connection->connection); - stmt.reset(); - return result::success; - } catch (sql::SQLException& ex) { - logError(0, "SQL Error: {}", ex.what()); - return {ptr->sql_command, -1, ex.what()}; + return {ptr->sql_command, -1, "failed to prepare statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; } + + /* validate all parameters */ + auto parameter_count = mysql_stmt_param_count(stmt_guard.stmt); + if(parameter_count != mapped_variables.size()) + return {ptr->sql_command, -1, "invalid parameter count. Statement contains " + to_string(parameter_count) + " parameters but only " + to_string(mapped_variables.size()) + " are given."}; + + if(bind_parameter_memory.ptr) { + if(mysql_stmt_bind_param(stmt_guard.stmt, (MYSQL_BIND*) bind_parameter_memory.ptr)) + return {ptr->sql_command, -1, "failed to bind parameters to statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; + } else if(parameter_count > 0) + return {ptr->sql_command, -1, "invalid parameter count. Statement contains " + to_string(parameter_count) + " parameters but only " + to_string(mapped_variables.size()) + " are given (bind nullptr)."}; + + + if(mysql_stmt_execute(stmt_guard.stmt)) { + auto errc = mysql_stmt_errno(stmt_guard.stmt); + if(errc == CR_SERVER_GONE_ERROR || errc == CR_SERVER_LOST || errc == CR_CONNECTION_ERROR) + this->connection_closed(connection->connection); + + return {ptr->sql_command, -1, "failed to execute query statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; + } + + return result::success; } result MySQLManager::queryCommand(shared_ptr _ptr, const QueryCallback &fn) { auto ptr = static_pointer_cast(_ptr); + if(!ptr) { return {-1, "invalid command handle"}; } + std::lock_guard lock(ptr->lock); auto command = ptr->sql_command; + auto variables = ptr->variables; vector mapped_variables; if(!sql::mysql::evaluate_sql_query(command, variables, mapped_variables)) return {ptr->sql_command, -1, "Could not map sqlite vars to mysql!"}; - unique_ptr connection = this->next_connection(); - if(!connection) return {ptr->sql_command, -1, "Could not get a valid connection!"}; - try { - PreparedStatementHandle stmt(connection->_connection->connection->prepareStatement(command), prepared_statement_release); - //logMessage(LOG_GENERAL, "Deleting prepered statement {} and thread {}", (void*) stmt.get(), (void*) threads::self::id()); - if(!sql::mysql::bind_parms(stmt, mapped_variables)) return {ptr->sql_command, -1, "Could not bind variables!"}; - ResultHandle result(stmt->executeQuery(), result_release); + FreeGuard bind_parameter_memory{nullptr}; + if(!sql::mysql::create_bind(bind_parameter_memory.ptr, mapped_variables)) return {ptr->sql_command, -1, "Failed to allocate bind memory!"}; - auto column_count = result->getMetaData()->getColumnCount(); - std::string columnNames[column_count]; - std::string columnValues[column_count]; + ResultBind bind_result_data{0, nullptr, nullptr}; - for(int index = 0; index < column_count; index++) - columnNames[index] = result->getMetaData()->getColumnName(index + 1); - - bool userQuit = false; - while(result->next() && !userQuit) { - for(int index = 0; index < column_count; index++) - columnValues[index] = result->getString(index + 1); - if(fn(column_count, columnValues, columnNames) != 0) { - userQuit = true; - break; - } - } - - stmt.reset(); - return result::success; - } catch (sql::SQLException& ex) { - return {ptr->sql_command, -1, ex.what()}; - } -} - -result MySQLManager::execute_raw(const std::string &command) { auto connection = this->next_connection(); - if(!connection) - return {command, -1, "no connection available"}; + if(!connection) return {ptr->sql_command, -1, "Could not get a valid connection!"}; - auto& mysql_connection = connection->_connection->connection; - mysql_connection->clearWarnings(); + StatementGuard stmt_guard{mysql_stmt_init(connection->connection->handle)}; + if(!stmt_guard.stmt) + return {ptr->sql_command, -1, "failed to allocate statement"}; - StatementHandle statement(mysql_connection->createStatement(), statement_release); - statement->clearWarnings(); + if(mysql_stmt_prepare(stmt_guard.stmt, command.c_str(), command.length())) { + auto errc = mysql_stmt_errno(stmt_guard.stmt); + if(errc == CR_SERVER_GONE_ERROR || errc == CR_SERVER_LOST || errc == CR_CONNECTION_ERROR) + this->connection_closed(connection->connection); - try { - auto result = statement->execute(command); - if(statement->getWarnings()) { - cerr << "Got some warnings: " << endl; - while(statement->getWarnings()) { - cerr << " - " << statement->getWarnings()->getMessage() << endl; - statement->getWarnings()->setNextWarning(statement->getWarnings()->getNextWarning()); - } - } - if(result) - return result::success; - return {command, -1, "return false"}; - } catch (sql::SQLException& ex) { - return {command, -1, ex.what()}; + return {ptr->sql_command, -1, "failed to prepare statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; } + + /* validate all parameters */ + { + auto parameter_count = mysql_stmt_param_count(stmt_guard.stmt); + if(parameter_count != mapped_variables.size()) + return {ptr->sql_command, -1, "invalid parameter count. Statement contains " + to_string(parameter_count) + " parameters but only " + to_string(mapped_variables.size()) + " are given."}; + } + + if(bind_parameter_memory.ptr) { + if(mysql_stmt_bind_param(stmt_guard.stmt, (MYSQL_BIND*) bind_parameter_memory.ptr)) + return {ptr->sql_command, -1, "failed to bind parameters to statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; + } + + if(mysql_stmt_execute(stmt_guard.stmt)) { + auto errc = mysql_stmt_errno(stmt_guard.stmt); + if(errc == CR_SERVER_GONE_ERROR || errc == CR_SERVER_LOST || errc == CR_CONNECTION_ERROR) + this->connection_closed(connection->connection); + + return {ptr->sql_command, -1, "failed to execute query statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; + } + + //if(mysql_stmt_store_result(stmt_guard.stmt)) + // return {ptr->sql_command, -1, "failed to store query result: " + string(mysql_stmt_error(stmt_guard.stmt))}; + + ResultGuard result_guard{mysql_stmt_result_metadata(stmt_guard.stmt)}; + if(!result_guard.result) + return {ptr->sql_command, -1, "failed to query result metadata: " + string(mysql_stmt_error(stmt_guard.stmt))}; + + + auto field_count = mysql_num_fields(result_guard.result); + DeleteAGuard field_names{new string[field_count]}; + DeleteAGuard field_values{new string[field_count]}; + + { + auto field_meta = mysql_fetch_fields(result_guard.result); + if(!field_meta && field_count > 0) + return {ptr->sql_command, -1, "failed to fetch field meta"}; + + if(!sql::mysql::create_result_bind(field_count, field_meta, bind_result_data)) + return {ptr->sql_command, -1, "failed to allocate result buffer"}; + + if(mysql_stmt_bind_result(stmt_guard.stmt, (MYSQL_BIND*) bind_result_data.memory)) + return {ptr->sql_command, -1, "failed to bind response buffer to statement: " + string(mysql_stmt_error(stmt_guard.stmt))}; + + for(size_t index = 0; index < field_count; index++) { + field_names.ptr[index] = field_meta[index].name; // field_meta cant be null because it has been checked above + //cout << field_names.ptr[index] << " - " << field_meta[index].max_length << endl; + } + } + + bool user_quit = false; + int stmt_code, row_id = 0; + while(!(stmt_code = mysql_stmt_fetch(stmt_guard.stmt))) { + bind_result_data.get_as_string(field_values.ptr); + + if(fn(field_count, field_values.ptr, field_names.ptr) != 0) { + user_quit = true; + break; + } + + row_id++; + } + + if(!user_quit) { + if(stmt_code == 1) { + auto errc = mysql_stmt_errno(stmt_guard.stmt); + if(errc == CR_SERVER_GONE_ERROR || errc == CR_SERVER_LOST || errc == CR_CONNECTION_ERROR) + this->connection_closed(connection->connection); + + return {ptr->sql_command, -1, "failed to fetch response row " + to_string(row_id) + ": " + string(mysql_stmt_error(stmt_guard.stmt))}; + } else if(stmt_code == MYSQL_NO_DATA) + ; + else if(stmt_code == MYSQL_DATA_TRUNCATED) + return {ptr->sql_command, -1, "response data has been truncated"}; + } + return result::success; } \ No newline at end of file diff --git a/src/sql/mysql/MySQL.h b/src/sql/mysql/MySQL.h index a7d4132..b7de7a3 100644 --- a/src/sql/mysql/MySQL.h +++ b/src/sql/mysql/MySQL.h @@ -1,21 +1,13 @@ #pragma once -#if __cplusplus >= 201703L - /* MySQL override. This needed to be inclided before cppconn/exception.h to define them */ - #include - #include - #include - - /* Now remove the trow */ - #define throw(...) - #include - #undef throw /* reset */ -#endif - +#include +#include +#include #include #include "sql/SqlQuery.h" -#include -#include + +#include "../../misc/spin_lock.h" +#include #define ERROR_MYSQL_MISSING_DRIVER -1 #define ERROR_MYSQL_INVLID_CONNECT -2 @@ -25,38 +17,33 @@ namespace sql { namespace mysql { class MySQLManager; - struct LocalConnection; bool evaluate_sql_query(std::string& sql, const std::vector& vars, std::vector& result); class MySQLCommand : public CommandData { }; - struct ConnectionEntry { - friend class MySQLManager; - friend class LocalConnection; - public: - typedef std::function DisconnectListener; + struct Connection { + MYSQL* handle = nullptr; - private: - ConnectionEntry(std::unique_ptr&& connection, bool used) : connection(std::move(connection)), used(used) {} + spin_lock used_lock; + bool used = false; - std::auto_ptr save_point; - std::unique_ptr connection; - bool used = false; + ~Connection(); }; - struct LocalConnection { - LocalConnection(MySQLManager* mgr, const std::shared_ptr& entry); - ~LocalConnection(); - MySQLManager* _mgr; - std::shared_ptr _connection; + struct AcquiredConnection { + MySQLManager* owner; + std::shared_ptr connection; + + AcquiredConnection(MySQLManager* owner, const std::shared_ptr&); + ~AcquiredConnection(); }; class MySQLManager : public SqlManager { - friend class LocalConnection; + friend struct AcquiredConnection; public: - typedef std::function&)> ListenerConnectionDisconnect; - typedef std::function&)> ListenerConnectionCreated; + //typedef std::function&)> ListenerConnectionDisconnect; + //typedef std::function&)> ListenerConnectionCreated; typedef std::function ListenerConnected; typedef std::function ListenerDisconnected; @@ -69,8 +56,6 @@ namespace sql { result disconnect() override; ListenerDisconnected listener_disconnected; - - result execute_raw(const std::string& /* command */); protected: std::shared_ptr copyCommandData(std::shared_ptr ptr) override; std::shared_ptr allocateCommandData() override; @@ -78,13 +63,13 @@ namespace sql { result queryCommand(std::shared_ptr ptr, const QueryCallback &fn) override; public: - inline std::unique_ptr next_connection(); + std::unique_ptr next_connection(); + void connection_closed(const std::shared_ptr& /* connection */); std::mutex connections_lock; std::condition_variable connections_condition; - std::deque> connections; + std::deque> connections; - sql::Driver* driver = nullptr; bool disconnecting = false; }; } diff --git a/test/SQLTest.cpp b/test/SQLTest.cpp index c7d0f15..d2f6b7d 100644 --- a/test/SQLTest.cpp +++ b/test/SQLTest.cpp @@ -2,6 +2,8 @@ #include #include #include +#include +#include using namespace sql; using namespace std; @@ -148,18 +150,39 @@ int main() { sql::command((SqlManager*) nullptr, std::string("SELECT *"), {":hello", "world"}, {":numeric", 2}); #endif - sql::mysql::MySQLManager manager; - if(!manager.connect("mysql://localhost:3306/test?userName=root&password=markus&connections=1")) { + sql::mysql::MySQLManager manager; + manager.listener_disconnected = [](bool x){ + cout << "Disconnect: " << x << endl; + }; + + if(!manager.connect("mysql://localhost:3306/teaspeak?userName=root&password=markus&connections=1")) { cerr << "failed to connect" << endl; return 1; } - auto result = sql::command(&manager, "SELECT * FROM `level_miner`").query([](int length, std::string* values, std::string* names) { - cout << "-- entry" << endl; - for(int index = 0; index < length; index++) { - cout << " " << names[index] << " => " << values[index] << endl; + /* + auto result = sql::command(&manager,"INSERT INTO `level_miner` (`username`, `a`) VALUES (:username, :value)", variable{":username", "Hello"}, variable{":value", "TEST!"}).execute(); + if(!result) cout << result.fmtStr() << endl; + + while(true) { + result = sql::command(&manager, "SELECT * FROM `level_miner`").query([](int length, std::string* values, std::string* names) { + cout << "-- entry" << endl; + for(int index = 0; index < length; index++) { + cout << " " << names[index] << " => " << values[index] << endl; + } + }); + cout << result.fmtStr() << endl; + this_thread::sleep_for(chrono::seconds(1)); + } + */ + + sql::command(&manager, "SELECT `cldbid`,`firstConnect`,`connections` FROM `clients` WHERE `serverId` = :sid AND `clientUid`=:uid LIMIT 1", variable{":sid", 0}, variable{":uid", "serveradmin"}).query([&](void* cl, int length, string* values, string* names){ + for (int index = 0; index < length; index++) { + logTrace(0, "Reading client property from client database table. (Key: " + names[index] + ", Value: " + values[index] + ")"); } - }); - cout << result.fmtStr() << endl; + return 0; + }, (void*) nullptr); + + mysql_library_end(); } \ No newline at end of file