Totally fucked up...

This commit is contained in:
root 2020-03-17 12:08:33 +01:00
parent d6f483a019
commit d2ba1b4eee
3 changed files with 488 additions and 488 deletions

View File

@ -1,9 +1,9 @@
The general namespace prefix is ts:: The general namespace prefix is ts::
TeaSpeak - Server: ts::server TeaSpeak - Server: ts::server
Basic: ts::server Basic: ts::server
Sub-Server: Sub-Server:
Query: ts::server::server::query Query: ts::server::server::query
Voice: ts::server::server::udp Voice: ts::server::server::udp
File: ts::server::server::file File: ts::server::server::file
Web: ts::server::server::web Web: ts::server::server::web

View File

@ -1,384 +1,384 @@
// //
// Created by WolverinDEV on 11/03/2020. // Created by WolverinDEV on 11/03/2020.
// //
#include "./QueryClientConnection.h" #include "./QueryClientConnection.h"
#include <netinet/tcp.h> #include <netinet/tcp.h>
#include <log/LogUtils.h> #include <log/LogUtils.h>
#include <pipes/errors.h> #include <pipes/errors.h>
#include <src/InstanceHandler.h> #include <src/InstanceHandler.h>
#include "./QueryClient.h" #include "./QueryClient.h"
#include "../ConnectedClient.h" #include "../ConnectedClient.h"
#include "../../server/QueryServer.h" #include "../../server/QueryServer.h"
#include "QueryClientConnection.h" #include "QueryClientConnection.h"
using namespace ts::server::server::query; using namespace ts::server::server::query;
#if defined(TCP_CORK) && !defined(TCP_NOPUSH) #if defined(TCP_CORK) && !defined(TCP_NOPUSH)
#define TCP_NOPUSH TCP_CORK #define TCP_NOPUSH TCP_CORK
#endif #endif
namespace ts::server::server::query { namespace ts::server::server::query {
/* will be set by the event loop */ /* will be set by the event loop */
thread_local bool thread_is_event_loop{false}; thread_local bool thread_is_event_loop{false};
} }
QueryClientConnection::QueryClientConnection(ts::server::QueryClient *client, int fd) : client_handle{client}, file_descriptor_{fd} { QueryClientConnection::QueryClientConnection(ts::server::QueryClient *client, int fd) : client_handle{client}, file_descriptor_{fd} {
TAILQ_INIT(&this->write_queue); TAILQ_INIT(&this->write_queue);
} }
QueryClientConnection::~QueryClientConnection() { QueryClientConnection::~QueryClientConnection() {
this->finalize(true); this->finalize(true);
} }
bool QueryClientConnection::initialize(std::string &error) { bool QueryClientConnection::initialize(std::string &error) {
assert(this->client_handle); assert(this->client_handle);
int enabled{1}; int enabled{1};
int disabled{0}; int disabled{0};
setsockopt(this->file_descriptor_, SOL_SOCKET, SO_KEEPALIVE, &enabled, sizeof(enabled)); setsockopt(this->file_descriptor_, SOL_SOCKET, SO_KEEPALIVE, &enabled, sizeof(enabled));
if(setsockopt(this->file_descriptor_, IPPROTO_TCP, TCP_NOPUSH, &disabled, sizeof disabled) < 0) if(setsockopt(this->file_descriptor_, IPPROTO_TCP, TCP_NOPUSH, &disabled, sizeof disabled) < 0)
logError(LOG_QUERY, "Could not disable nopush for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this->client_handle), errno, strerror(errno)); logError(LOG_QUERY, "Could not disable nopush for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this->client_handle), errno, strerror(errno));
if(setsockopt(this->file_descriptor_, IPPROTO_TCP, TCP_NODELAY, &enabled, sizeof enabled) < 0) if(setsockopt(this->file_descriptor_, IPPROTO_TCP, TCP_NODELAY, &enabled, sizeof enabled) < 0)
logError(LOG_QUERY, "[Query] Could not disable no delay for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this->client_handle), errno, strerror(errno)); logError(LOG_QUERY, "[Query] Could not disable no delay for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this->client_handle), errno, strerror(errno));
auto query_server = this->client_handle->getQueryServer(); auto query_server = this->client_handle->getQueryServer();
this->readEvent = event_new(query_server->io_event_loop(), this->file_descriptor_, EV_READ | EV_PERSIST, [](int a1, short a2, void* _this) { this->readEvent = event_new(query_server->io_event_loop(), this->file_descriptor_, EV_READ | EV_PERSIST, [](int a1, short a2, void* _this) {
reinterpret_cast<QueryClientConnection*>(_this)->handle_event_read(a1, a2); reinterpret_cast<QueryClientConnection*>(_this)->handle_event_read(a1, a2);
}, this); }, this);
this->writeEvent = event_new(query_server->io_event_loop(), this->file_descriptor_, EV_WRITE, [](int a1, short a2, void* _this){ this->writeEvent = event_new(query_server->io_event_loop(), this->file_descriptor_, EV_WRITE, [](int a1, short a2, void* _this){
reinterpret_cast<QueryClientConnection*>(_this)->handle_event_write(a1, a2); reinterpret_cast<QueryClientConnection*>(_this)->handle_event_write(a1, a2);
}, this); }, this);
this->connection_state = ConnectionState::INITIALIZING; this->connection_state = ConnectionState::INITIALIZING;
if(ts::config::query::sslMode == 0) { if(ts::config::query::sslMode == 0) {
this->connection_state = ConnectionState::CONNECTED; this->connection_state = ConnectionState::CONNECTED;
this->connection_type_ = ConnectionType::PLAIN_TEXT; this->connection_type_ = ConnectionType::PLAIN_TEXT;
this->client_handle->handle_connection_initialized(); this->client_handle->handle_connection_initialized();
} }
return true; return true;
} }
void QueryClientConnection::add_read_event() { void QueryClientConnection::add_read_event() {
std::lock_guard elock{this->event_mutex}; std::lock_guard elock{this->event_mutex};
if(this->readEvent) event_add(this->readEvent, nullptr); if(this->readEvent) event_add(this->readEvent, nullptr);
} }
void QueryClientConnection::finalize(bool is_destructor_call) { void QueryClientConnection::finalize(bool is_destructor_call) {
auto old_state = this->connection_state; auto old_state = this->connection_state;
this->connection_state = ConnectionState::DISCONNECTED; this->connection_state = ConnectionState::DISCONNECTED;
/* unregister event handling */ /* unregister event handling */
{ {
std::unique_lock elock{this->event_mutex}; std::unique_lock elock{this->event_mutex};
auto wevent = std::exchange(this->writeEvent, nullptr); auto wevent = std::exchange(this->writeEvent, nullptr);
auto revent = std::exchange(this->readEvent, nullptr); auto revent = std::exchange(this->readEvent, nullptr);
elock.unlock(); elock.unlock();
if(revent) { if(revent) {
if(thread_is_event_loop) if(thread_is_event_loop)
event_del_noblock(revent); event_del_noblock(revent);
else else
event_del_block(revent); /* may calls finalize() while we're waiting. But thats okey. */ event_del_block(revent); /* may calls finalize() while we're waiting. But thats okey. */
event_free(revent); event_free(revent);
} }
if(wevent) { if(wevent) {
if(thread_is_event_loop) if(thread_is_event_loop)
event_del_noblock(wevent); event_del_noblock(wevent);
else else
event_del_block(wevent); /* may calls finalize() while we're waiting. But thats okey. */ event_del_block(wevent); /* may calls finalize() while we're waiting. But thats okey. */
event_free(wevent); event_free(wevent);
} }
} }
{ {
std::lock_guard block{this->buffer_lock}; std::lock_guard block{this->buffer_lock};
/* Free the entire tail queue. */ /* Free the entire tail queue. */
while (auto buffer = TAILQ_FIRST(&this->write_queue)) { while (auto buffer = TAILQ_FIRST(&this->write_queue)) {
TAILQ_REMOVE(&this->write_queue, buffer, tq); TAILQ_REMOVE(&this->write_queue, buffer, tq);
free(buffer->original_ptr); free(buffer->original_ptr);
delete buffer; delete buffer;
} }
TAILQ_INIT(&this->write_queue); /* just ensures a valid tailq */ TAILQ_INIT(&this->write_queue); /* just ensures a valid tailq */
::free(this->read_buffer.buffer); ::free(this->read_buffer.buffer);
this->read_buffer.buffer = nullptr; this->read_buffer.buffer = nullptr;
this->read_buffer.length = 0; this->read_buffer.length = 0;
this->read_buffer.fill_count = 0; this->read_buffer.fill_count = 0;
} }
if(!is_destructor_call && old_state != ConnectionState::DISCONNECTED) if(!is_destructor_call && old_state != ConnectionState::DISCONNECTED)
this->client_handle->handle_connection_finalized(); this->client_handle->handle_connection_finalized();
} }
void QueryClientConnection::handle_event_read(int fd, short events) { void QueryClientConnection::handle_event_read(int fd, short events) {
constexpr auto buffer_length{1024 * 4}; constexpr auto buffer_length{1024 * 4};
uint8_t buffer[buffer_length]; uint8_t buffer[buffer_length];
auto length = read(fd, (void *) buffer, buffer_length); auto length = read(fd, (void *) buffer, buffer_length);
if (length <= 0) { if (length <= 0) {
if (errno == EINTR || errno == EAGAIN) if (errno == EINTR || errno == EAGAIN)
return; return;
else if (length == 0) { else if (length == 0) {
logMessage(LOG_QUERY, "{} Connection closed (r). Client disconnected.", logMessage(LOG_QUERY, "{} Connection closed (r). Client disconnected.",
CLIENT_STR_LOG_PREFIX_(this->client_handle)); CLIENT_STR_LOG_PREFIX_(this->client_handle));
} else { } else {
logError(LOG_QUERY, "{} Failed to read! Code: {} errno: {} message: {}", logError(LOG_QUERY, "{} Failed to read! Code: {} errno: {} message: {}",
CLIENT_STR_LOG_PREFIX_(this->client_handle), length, errno, strerror(errno)); CLIENT_STR_LOG_PREFIX_(this->client_handle), length, errno, strerror(errno));
} }
event_del_noblock(this->readEvent); event_del_noblock(this->readEvent);
this->close_connection(std::chrono::system_clock::time_point{}); this->close_connection(std::chrono::system_clock::time_point{});
return; return;
} }
if (this->connection_type_ == ConnectionType::PLAIN_TEXT) { if (this->connection_type_ == ConnectionType::PLAIN_TEXT) {
plain_text_buffer_insert: plain_text_buffer_insert:
this->handle_decoded_message(buffer, length); this->handle_decoded_message(buffer, length);
} else if (this->connection_type_ == ConnectionType::SSL_ENCRYPTED) { } else if (this->connection_type_ == ConnectionType::SSL_ENCRYPTED) {
ssl_buffer_insert:; ssl_buffer_insert:;
this->ssl_handler.process_incoming_data(pipes::buffer_view{(const char*) buffer, (size_t) length});; this->ssl_handler.process_incoming_data(pipes::buffer_view{(const char*) buffer, (size_t) length});;
} else { } else {
if (config::query::sslMode != 0 && pipes::SSL::isSSLHeader(std::string{(const char *) buffer, (size_t) length})) { if (config::query::sslMode != 0 && pipes::SSL::isSSLHeader(std::string{(const char *) buffer, (size_t) length})) {
if(!this->initialize_ssl()) return; if(!this->initialize_ssl()) return;
/* /*
* - Content * - Content
* \x16 * \x16
* -Version (1) * -Version (1)
* \x03 \x00 * \x03 \x00
* - length (2) * - length (2)
* \x00 \x04 * \x00 \x04
* *
* - Header * - Header
* \x00 -> hello request (3) * \x00 -> hello request (3)
* \x05 -> length (4) * \x05 -> length (4)
*/ */
//this->writeRawMessage(string("\x16\x03\x01\x00\x05\x00\x00\x00\x00\x00", 10)); //this->writeRawMessage(string("\x16\x03\x01\x00\x05\x00\x00\x00\x00\x00", 10));
goto ssl_buffer_insert; goto ssl_buffer_insert;
} else { } else {
this->connection_type_ = ConnectionType::PLAIN_TEXT; this->connection_type_ = ConnectionType::PLAIN_TEXT;
this->client_handle->handle_connection_initialized(); this->client_handle->handle_connection_initialized();
goto plain_text_buffer_insert; goto plain_text_buffer_insert;
} }
} }
} }
void QueryClientConnection::handle_event_write(int fd, short events) { void QueryClientConnection::handle_event_write(int fd, short events) {
bool readd_write{false}; bool readd_write{false};
if(events & EV_WRITE) { if(events & EV_WRITE) {
/* Safe to access, because we're only reading the queue and the head could never change. Only within the IO loop itself. */ /* Safe to access, because we're only reading the queue and the head could never change. Only within the IO loop itself. */
WriteBuffer* wbuffer; WriteBuffer* wbuffer;
while((wbuffer = TAILQ_FIRST(&this->write_queue))) { while((wbuffer = TAILQ_FIRST(&this->write_queue))) {
auto written = send(fd, wbuffer->ptr, wbuffer->length, 0); auto written = send(fd, wbuffer->ptr, wbuffer->length, 0);
if(written <= 0) { if(written <= 0) {
if(errno == EAGAIN) { if(errno == EAGAIN) {
readd_write = true; readd_write = true;
break; break;
} }
if(written == 0) { if(written == 0) {
logMessage(LOG_QUERY, "{} Connection closed (w). Client disconnected.", CLIENT_STR_LOG_PREFIX_(this->client_handle)); logMessage(LOG_QUERY, "{} Connection closed (w). Client disconnected.", CLIENT_STR_LOG_PREFIX_(this->client_handle));
} else { } else {
logError(LOG_QUERY, "{} Failed to write! Code: {} errno: {} message: {}", CLIENT_STR_LOG_PREFIX_(this->client_handle), written, errno, strerror(errno)); logError(LOG_QUERY, "{} Failed to write! Code: {} errno: {} message: {}", CLIENT_STR_LOG_PREFIX_(this->client_handle), written, errno, strerror(errno));
} }
event_del_noblock(this->readEvent); event_del_noblock(this->readEvent);
this->close_connection(std::chrono::system_clock::time_point{}); this->close_connection(std::chrono::system_clock::time_point{});
return; return;
} }
wbuffer->length -= written; wbuffer->length -= written;
if(wbuffer->length == 0) { if(wbuffer->length == 0) {
std::lock_guard block{this->buffer_lock}; std::lock_guard block{this->buffer_lock};
TAILQ_REMOVE(&this->write_queue, wbuffer, tq); TAILQ_REMOVE(&this->write_queue, wbuffer, tq);
::free(wbuffer->original_ptr); ::free(wbuffer->original_ptr);
delete wbuffer; delete wbuffer;
} else { } else {
wbuffer->ptr += written; wbuffer->ptr += written;
} }
} }
} }
if(this->connection_state == ConnectionState::DISCONNECTING) { if(this->connection_state == ConnectionState::DISCONNECTING) {
if(!readd_write || (events & EV_TIMEOUT)) { if(!readd_write || (events & EV_TIMEOUT)) {
/* disconnect timeouted or nothing more to write */ /* disconnect timeouted or nothing more to write */
this->finalize(false); this->finalize(false);
return; return;
} else /* if(readd_write) */ { /* check not needed because tested before already */ } else /* if(readd_write) */ { /* check not needed because tested before already */
auto time_left = this->disconnect_timeout - std::chrono::system_clock::now(); auto time_left = this->disconnect_timeout - std::chrono::system_clock::now();
timeval timeout{0, 1}; timeval timeout{0, 1};
if(time_left.count() > 0) { if(time_left.count() > 0) {
timeout.tv_sec = std::chrono::floor<std::chrono::seconds>(time_left).count(); timeout.tv_sec = std::chrono::floor<std::chrono::seconds>(time_left).count();
timeout.tv_usec = std::chrono::floor<std::chrono::microseconds>(time_left).count() % 1000000ULL; timeout.tv_usec = std::chrono::floor<std::chrono::microseconds>(time_left).count() % 1000000ULL;
} }
event_add(this->writeEvent, &timeout); event_add(this->writeEvent, &timeout);
} }
} else if(readd_write) { } else if(readd_write) {
event_add(this->writeEvent, nullptr); event_add(this->writeEvent, nullptr);
} }
} }
bool QueryClientConnection::initialize_ssl() { bool QueryClientConnection::initialize_ssl() {
this->connection_type_ = ConnectionType::SSL_ENCRYPTED; this->connection_type_ = ConnectionType::SSL_ENCRYPTED;
this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_OUT, true); this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_OUT, true);
this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_IN, true); this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_IN, true);
this->ssl_handler.callback_data([&](const pipes::buffer_view &buffer) { this->ssl_handler.callback_data([&](const pipes::buffer_view &buffer) {
this->handle_decoded_message(buffer.data_ptr<void>(), buffer.length()); this->handle_decoded_message(buffer.data_ptr<void>(), buffer.length());
}); });
this->ssl_handler.callback_write([&](const pipes::buffer_view &buffer) { this->ssl_handler.callback_write([&](const pipes::buffer_view &buffer) {
this->send_data_raw({buffer.data_ptr<char>(), buffer.length()}); this->send_data_raw({buffer.data_ptr<char>(), buffer.length()});
}); });
this->ssl_handler.callback_initialized = [&] { this->ssl_handler.callback_initialized = [&] {
this->client_handle->handle_connection_initialized(); this->client_handle->handle_connection_initialized();
}; };
this->ssl_handler.callback_error([&](int code, const std::string& message) { this->ssl_handler.callback_error([&](int code, const std::string& message) {
if(code == PERROR_SSL_ACCEPT) { if(code == PERROR_SSL_ACCEPT) {
logError(LOG_QUERY, "{} Failed to initialize query ssl session ({})", CLIENT_STR_LOG_PREFIX_(this->client_handle), message); logError(LOG_QUERY, "{} Failed to initialize query ssl session ({})", CLIENT_STR_LOG_PREFIX_(this->client_handle), message);
this->close_connection(std::chrono::system_clock::time_point{}); this->close_connection(std::chrono::system_clock::time_point{});
} else if(code == PERROR_SSL_TIMEOUT) { } else if(code == PERROR_SSL_TIMEOUT) {
logError(LOG_QUERY, "{} Failed to initialize query ssl session (timeout: {})", CLIENT_STR_LOG_PREFIX_(this->client_handle), message); logError(LOG_QUERY, "{} Failed to initialize query ssl session (timeout: {})", CLIENT_STR_LOG_PREFIX_(this->client_handle), message);
this->close_connection(std::chrono::system_clock::time_point{}); this->close_connection(std::chrono::system_clock::time_point{});
} else } else
logError(LOG_QUERY, "{} Received SSL error ({} | {})", CLIENT_STR_LOG_PREFIX_(this->client_handle), code, message); logError(LOG_QUERY, "{} Received SSL error ({} | {})", CLIENT_STR_LOG_PREFIX_(this->client_handle), code, message);
}); });
{ {
auto context = serverInstance->sslManager()->getQueryContext(); auto context = serverInstance->sslManager()->getQueryContext();
auto options = std::make_shared<pipes::SSL::Options>(); auto options = std::make_shared<pipes::SSL::Options>();
options->type = pipes::SSL::SERVER; options->type = pipes::SSL::SERVER;
options->context_method = TLS_method(); options->context_method = TLS_method();
options->default_keypair({context->privateKey, context->certificate}); options->default_keypair({context->privateKey, context->certificate});
if(!this->ssl_handler.initialize(options)) { if(!this->ssl_handler.initialize(options)) {
logError(LOG_QUERY, "[{}] Failed to setup ssl!", CLIENT_STR_LOG_PREFIX_(this->client_handle)); logError(LOG_QUERY, "[{}] Failed to setup ssl!", CLIENT_STR_LOG_PREFIX_(this->client_handle));
this->close_connection(std::chrono::system_clock::time_point{}); this->close_connection(std::chrono::system_clock::time_point{});
return false; return false;
} }
} }
return true; return true;
} }
void QueryClientConnection::handle_decoded_message(const void *buffer, size_t size) { void QueryClientConnection::handle_decoded_message(const void *buffer, size_t size) {
{ {
std::lock_guard block{this->buffer_lock}; std::lock_guard block{this->buffer_lock};
if((this->read_buffer.length - this->read_buffer.fill_count) < size) { /* !this->read_buffer.buffer is already implicitly implemented because by default read_buffer.length will be zero */ if((this->read_buffer.length - this->read_buffer.fill_count) < size) { /* !this->read_buffer.buffer is already implicitly implemented because by default read_buffer.length will be zero */
const auto new_size{this->read_buffer.length + size + 128}; const auto new_size{this->read_buffer.length + size + 128};
auto new_buffer = ::malloc(new_size); auto new_buffer = ::malloc(new_size);
assert(new_buffer); assert(new_buffer);
if(this->read_buffer.fill_count) memcpy(new_buffer, this->read_buffer.buffer, this->read_buffer.fill_count); if(this->read_buffer.fill_count) memcpy(new_buffer, this->read_buffer.buffer, this->read_buffer.fill_count);
::free(this->read_buffer.buffer); ::free(this->read_buffer.buffer);
this->read_buffer.buffer = new_buffer; this->read_buffer.buffer = new_buffer;
this->read_buffer.length = new_size; this->read_buffer.length = new_size;
} }
assert(this->read_buffer.buffer); assert(this->read_buffer.buffer);
assert(this->read_buffer.length - this->read_buffer.fill_count >= size); assert(this->read_buffer.length - this->read_buffer.fill_count >= size);
memcpy((char*) this->read_buffer.buffer + this->read_buffer.fill_count, buffer, size); memcpy((char*) this->read_buffer.buffer + this->read_buffer.fill_count, buffer, size);
this->read_buffer.fill_count += size; this->read_buffer.fill_count += size;
} }
{ {
//TODO: Improve this command progress //TODO: Improve this command progress
auto qserver{this->client_handle->handle}; auto qserver{this->client_handle->handle};
if(qserver) { if(qserver) {
auto wlock{this->client_handle->_this}; auto wlock{this->client_handle->_this};
qserver->executePool()->execute([wlock]() { qserver->executePool()->execute([wlock]() {
auto client{std::dynamic_pointer_cast<QueryClient>(wlock.lock())}; auto client{std::dynamic_pointer_cast<QueryClient>(wlock.lock())};
if(!client) return; if(!client) return;
int counter = 0; int counter = 0;
while(client->process_next_command() && counter++ < 15); while(client->process_next_command() && counter++ < 15);
}); });
} }
} }
} }
void QueryClientConnection::send_data(const std::string_view &buffer) { void QueryClientConnection::send_data(const std::string_view &buffer) {
if(this->connection_type_ == ConnectionType::PLAIN_TEXT) if(this->connection_type_ == ConnectionType::PLAIN_TEXT)
this->send_data_raw(buffer); this->send_data_raw(buffer);
else if(this->connection_type_ == ConnectionType::SSL_ENCRYPTED) else if(this->connection_type_ == ConnectionType::SSL_ENCRYPTED)
this->ssl_handler.send(pipes::buffer_view{buffer.data(), buffer.length()}); this->ssl_handler.send(pipes::buffer_view{buffer.data(), buffer.length()});
} }
void QueryClientConnection::send_data_raw(const std::string_view &buffer) { void QueryClientConnection::send_data_raw(const std::string_view &buffer) {
auto wbuf = new WriteBuffer{}; auto wbuf = new WriteBuffer{};
wbuf->original_ptr = (char*) malloc(buffer.length()); wbuf->original_ptr = (char*) malloc(buffer.length());
wbuf->ptr = wbuf->original_ptr; wbuf->ptr = wbuf->original_ptr;
memcpy(wbuf->ptr, buffer.data(), buffer.length()); memcpy(wbuf->ptr, buffer.data(), buffer.length());
wbuf->length = buffer.length(); wbuf->length = buffer.length();
{ {
std::lock_guard wlock{this->buffer_lock}; std::lock_guard wlock{this->buffer_lock};
TAILQ_INSERT_TAIL(&this->write_queue, wbuf, tq); TAILQ_INSERT_TAIL(&this->write_queue, wbuf, tq);
} }
{ {
std::lock_guard elock{this->event_mutex}; std::lock_guard elock{this->event_mutex};
if(this->writeEvent) if(this->writeEvent)
event_add(this->writeEvent, nullptr); event_add(this->writeEvent, nullptr);
} }
} }
void QueryClientConnection::close_connection(const std::chrono::system_clock::time_point &timeout) { void QueryClientConnection::close_connection(const std::chrono::system_clock::time_point &timeout) {
if(timeout.time_since_epoch().count() > 0) { if(timeout.time_since_epoch().count() > 0) {
this->connection_state = ConnectionState::DISCONNECTING; this->connection_state = ConnectionState::DISCONNECTING;
this->disconnect_timeout = timeout; this->disconnect_timeout = timeout;
std::lock_guard elock{this->event_mutex}; std::lock_guard elock{this->event_mutex};
if(this->writeEvent) { if(this->writeEvent) {
event_add(this->writeEvent, nullptr); event_add(this->writeEvent, nullptr);
return; return;
} }
/* failed to add the write event, so call disconnect */ /* failed to add the write event, so call disconnect */
} }
if(this->connection_state == ConnectionState::DISCONNECTED) return; if(this->connection_state == ConnectionState::DISCONNECTED) return;
this->finalize(false); this->finalize(false);
} }
void QueryClientConnection::enforce_text_connection() { void QueryClientConnection::enforce_text_connection() {
if(this->connection_state != ConnectionState::INITIALIZING) return; if(this->connection_state != ConnectionState::INITIALIZING) return;
this->connection_state = ConnectionState::CONNECTED; this->connection_state = ConnectionState::CONNECTED;
this->connection_type_ = ConnectionType::PLAIN_TEXT; this->connection_type_ = ConnectionType::PLAIN_TEXT;
this->client_handle->handle_connection_initialized(); this->client_handle->handle_connection_initialized();
} }
CommandAssembleState QueryClientConnection::next_command(std::string &result) { CommandAssembleState QueryClientConnection::next_command(std::string &result) {
std::lock_guard block{this->buffer_lock}; std::lock_guard block{this->buffer_lock};
auto new_line_idx = (char*) memchr(this->read_buffer.buffer, '\n', this->read_buffer.fill_count); auto new_line_idx = (char*) memchr(this->read_buffer.buffer, '\n', this->read_buffer.fill_count);
if(!new_line_idx) return CommandAssembleState::NO_COMMAND_PENDING; if(!new_line_idx) return CommandAssembleState::NO_COMMAND_PENDING;
const auto length = ((char*) this->read_buffer.buffer - new_line_idx) * sizeof(*new_line_idx); const auto length = ((char*) this->read_buffer.buffer - new_line_idx) * sizeof(*new_line_idx);
auto line_length{length}; auto line_length{length};
if(length > 0 && *(new_line_idx - 1) == '\r') if(length > 0 && *(new_line_idx - 1) == '\r')
line_length--; line_length--;
result.assign((char*) this->read_buffer.buffer, line_length); result.assign((char*) this->read_buffer.buffer, line_length);
//Do not copy the \r character //Do not copy the \r character
auto copy_bytes{this->read_buffer.fill_count - length}; auto copy_bytes{this->read_buffer.fill_count - length};
if(copy_bytes > 0 && *(new_line_idx + 1) == '\r') { if(copy_bytes > 0 && *(new_line_idx + 1) == '\r') {
copy_bytes--; copy_bytes--;
new_line_idx++; new_line_idx++;
} }
memcpy(this->read_buffer.buffer, new_line_idx + 1, copy_bytes); memcpy(this->read_buffer.buffer, new_line_idx + 1, copy_bytes);
this->read_buffer.fill_count = copy_bytes; this->read_buffer.fill_count = copy_bytes;
return copy_bytes == 0 ? CommandAssembleState::SUCCESS : CommandAssembleState::MORE_COMMANDS_PENDING; return copy_bytes == 0 ? CommandAssembleState::SUCCESS : CommandAssembleState::MORE_COMMANDS_PENDING;
} }

View File

@ -1,98 +1,98 @@
#pragma once #pragma once
#include <chrono> #include <chrono>
#include <event.h> #include <event.h>
#include <deque> #include <deque>
#include <string> #include <string>
#include <pipes/ssl.h> #include <pipes/ssl.h>
#include <sys/queue.h> #include <sys/queue.h>
namespace ts::server { namespace ts::server {
class QueryClient; class QueryClient;
} }
namespace ts::server::server::query { namespace ts::server::server::query {
enum struct ConnectionType { enum struct ConnectionType {
UNKNOWN, UNKNOWN,
PLAIN_TEXT, PLAIN_TEXT,
SSL_ENCRYPTED, SSL_ENCRYPTED,
/* SSH */ /* SSH */
}; };
enum struct ConnectionState { enum struct ConnectionState {
INITIALIZING, INITIALIZING,
CONNECTED, CONNECTED,
DISCONNECTING, DISCONNECTING,
DISCONNECTED DISCONNECTED
}; };
enum struct CommandAssembleState { enum struct CommandAssembleState {
SUCCESS, SUCCESS,
MORE_COMMANDS_PENDING, MORE_COMMANDS_PENDING,
NO_COMMAND_PENDING NO_COMMAND_PENDING
}; };
class QueryClientConnection { class QueryClientConnection {
public: public:
explicit QueryClientConnection(QueryClient* /* client */, int /* file descriptor */); explicit QueryClientConnection(QueryClient* /* client */, int /* file descriptor */);
~QueryClientConnection(); ~QueryClientConnection();
[[nodiscard]] inline ConnectionType connection_type() const { return this->connection_type_; } [[nodiscard]] inline ConnectionType connection_type() const { return this->connection_type_; }
bool initialize(std::string& /* error */); bool initialize(std::string& /* error */);
void add_read_event(); void add_read_event();
void finalize(bool /* is destructor call */); void finalize(bool /* is destructor call */);
void send_data(const std::string_view& /* payload */); void send_data(const std::string_view& /* payload */);
void send_data_raw(const std::string_view& /* payload */); void send_data_raw(const std::string_view& /* payload */);
void enforce_text_connection(); void enforce_text_connection();
[[nodiscard]] CommandAssembleState next_command(std::string& /* command */); [[nodiscard]] CommandAssembleState next_command(std::string& /* command */);
/* could be called from every thread (event IO thread) */ /* could be called from every thread (event IO thread) */
void close_connection(const std::chrono::system_clock::time_point& /* disconnect timeout */); void close_connection(const std::chrono::system_clock::time_point& /* disconnect timeout */);
private: private:
struct WriteBuffer { struct WriteBuffer {
char* original_ptr; char* original_ptr;
char* ptr; char* ptr;
size_t length; size_t length;
TAILQ_ENTRY(WriteBuffer) tq; TAILQ_ENTRY(WriteBuffer) tq;
}; };
QueryClient* client_handle{nullptr}; QueryClient* client_handle{nullptr};
ConnectionState connection_state{ConnectionState::INITIALIZING}; ConnectionState connection_state{ConnectionState::INITIALIZING};
std::chrono::system_clock::time_point disconnect_timeout{}; std::chrono::system_clock::time_point disconnect_timeout{};
ConnectionType connection_type_{ConnectionType::UNKNOWN}; ConnectionType connection_type_{ConnectionType::UNKNOWN};
int file_descriptor_{-1}; int file_descriptor_{-1};
/* only delete the events within the event loop! */ /* only delete the events within the event loop! */
std::mutex event_mutex{}; std::mutex event_mutex{};
::event* readEvent{nullptr}; ::event* readEvent{nullptr};
::event* writeEvent{nullptr}; ::event* writeEvent{nullptr};
pipes::SSL ssl_handler{}; pipes::SSL ssl_handler{};
std::mutex buffer_lock{}; std::mutex buffer_lock{};
struct { struct {
void* buffer{nullptr}; void* buffer{nullptr};
size_t length{0}; size_t length{0};
size_t fill_count{0}; size_t fill_count{0};
std::chrono::system_clock::time_point last_shrink{}; std::chrono::system_clock::time_point last_shrink{};
} read_buffer; } read_buffer;
TAILQ_HEAD(, WriteBuffer) write_queue{}; TAILQ_HEAD(, WriteBuffer) write_queue{};
void handle_event_write(int, short); void handle_event_write(int, short);
void handle_event_read(int, short); void handle_event_read(int, short);
bool initialize_ssl(); bool initialize_ssl();
void handle_decoded_message(const void* /* message */, size_t /* length */); void handle_decoded_message(const void* /* message */, size_t /* length */);
}; };
} }