Teaspeak-Server/license/shared/src/client.cpp

527 lines
19 KiB
C++

//
// Created by WolverinDEV on 23/02/2020.
//
#include <csignal>
#include <netinet/tcp.h>
#include <event.h>
#include <ThreadPool/ThreadHelper.h>
#include <misc/endianness.h>
#include "shared/include/license/client.h"
#include "crypt.h"
using namespace license::client;
/**
 * Allocate a buffer whose payload storage lives directly behind the Buffer
 * header inside one single malloc() block.
 *
 * @param capacity number of payload bytes to reserve behind the header.
 * @return the empty buffer (fill and offset are zero), or nullptr if malloc() failed.
 *         The result has to be released via Buffer::free().
 */
LicenseServerClient::Buffer* LicenseServerClient::Buffer::allocate(size_t capacity) {
    static_assert(std::is_trivially_constructible<Buffer>::value);

    void* const memory = malloc(sizeof(LicenseServerClient::Buffer) + capacity);
    if(memory == nullptr)
        return nullptr;

    auto* const instance = reinterpret_cast<LicenseServerClient::Buffer*>(memory);
    /* the payload starts right after the header */
    instance->data = (char*) memory + sizeof(LicenseServerClient::Buffer);
    instance->capacity = capacity;
    instance->offset = 0;
    instance->fill = 0;
    return instance;
}
/**
 * Release a buffer with the C allocator.
 *
 * @param ptr buffer to release; nullptr is accepted because ::free() ignores it.
 */
void LicenseServerClient::Buffer::free(Buffer *ptr) {
    /* no destructor has to run, handing the memory back is all that is needed */
    static_assert(std::is_trivially_destructible<Buffer>::value);

    ::free(static_cast<void*>(ptr));
}
/**
 * Create an unconnected client.
 *
 * @param address  IPv4 address of the license server; copied into the client.
 * @param pversion protocol version sent in and expected from the handshake.
 */
LicenseServerClient::LicenseServerClient(const sockaddr_in &address, int pversion) : protocol_version{pversion} {
    TAILQ_INIT(&this->buffers.write);
    memcpy(&this->network.address, &address, sizeof(sockaddr_in));

    /* receive buffer used to reassemble packets; 8 KiB to start with */
    if(this->buffers.read == nullptr)
        this->buffers.read = Buffer::allocate(8 * 1024);
}
/**
 * Tear the client down: close the connection, wait for the event dispatch
 * thread and only then release the receive buffer.
 */
LicenseServerClient::~LicenseServerClient() {
    this->close_connection();

    /*
     * Join before freeing the read buffer: close_connection() returns right away
     * when the state is already UNCONNECTED, yet the dispatch thread may still be
     * inside a read callback which works on buffers.read.
     * NOTE(review): the exact semantics of save_join's second argument are defined
     * in ThreadHelper.h and not visible here; the call itself is unchanged.
     */
    threads::save_join(this->network.event_dispatch, false);

    if(this->buffers.read) {
        Buffer::free(this->buffers.read);
        this->buffers.read = nullptr;
    }
}
bool LicenseServerClient::start_connection(std::string &error) {
bool event_dispatch_spawned{false};
std::unique_lock slock{this->connection_lock};
if(this->connection_state != ConnectionState::UNCONNECTED) {
error = "invalid connection state";
return false;
}
this->connection_state = ConnectionState::CONNECTING;
this->communication.initialized = false;
this->network.file_descriptor = socket(this->network.address.sin_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
if(this->network.file_descriptor < 0) {
error = "failed to allocate socket";
goto error_cleanup;
}
signal(SIGPIPE, SIG_IGN);
{
auto connect_state = ::connect(this->network.file_descriptor, reinterpret_cast<const sockaddr *>(&this->network.address), sizeof(this->network.address));
if(connect_state < 0 && errno != EINPROGRESS) {
error = "connect() failed (" + std::string{strerror(errno)} + ")";
goto error_cleanup;
}
}
{
int enabled{1}, disabled{0};
if(setsockopt(this->network.file_descriptor, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0); //CERR("could not set reuse addr");
if(setsockopt(this->network.file_descriptor, IPPROTO_TCP, TCP_CORK, &disabled, sizeof(disabled)) < 0); // CERR("could not set no push");
if(fcntl(this->network.file_descriptor, F_SETFD, fcntl(this->network.file_descriptor, F_GETFL, 0) | FD_CLOEXEC | O_NONBLOCK) < 0); // CERR("Failed to set FD_CLOEXEC and O_NONBLOCK (" + std::to_string(errno) + ")");
}
this->network.event_base = event_base_new();
this->network.event_read = event_new(this->network.event_base, this->network.file_descriptor, EV_READ | EV_PERSIST, [](int, short e, void* _this) {
auto client = reinterpret_cast<LicenseServerClient*>(_this);
client->callback_read(e);
}, this);
this->network.event_write = event_new(this->network.event_base, this->network.file_descriptor, EV_WRITE, [](int, short e, void* _this) {
auto client = reinterpret_cast<LicenseServerClient*>(_this);
client->callback_write(e);
}, this);
event_dispatch_spawned = true;
this->network.event_dispatch = std::thread([&] {
signal(SIGPIPE, SIG_IGN);
event_add(this->network.event_read, nullptr);
timeval connect_timeout{5, 0};
event_add(this->network.event_write, &connect_timeout);
auto event_base{this->network.event_base};
event_base_loop(event_base, EVLOOP_NO_EXIT_ON_EMPTY);
event_base_free(event_base);
//this ptr might be dangling
});
return true;
error_cleanup:
this->cleanup_network_resources();
if(!event_dispatch_spawned) {
event_base_free(this->network.event_base);
this->network.event_base = nullptr;
}
this->connection_state = ConnectionState::UNCONNECTED;
return false;
}
/**
 * Close the connection immediately without notifying the remote side.
 * Does nothing when the client is already UNCONNECTED.
 */
void LicenseServerClient::close_connection() {
    std::lock_guard state_guard{this->connection_lock};

    if(this->connection_state != ConnectionState::UNCONNECTED) {
        /* mark the client as closed first, then release socket, events and queued buffers */
        this->connection_state = ConnectionState::UNCONNECTED;
        this->cleanup_network_resources();
    }
}
void LicenseServerClient::cleanup_network_resources() {
const auto is_event_loop = this->network.event_dispatch.get_id() == std::this_thread::get_id();
if(this->network.event_read) {
if(is_event_loop) event_del_noblock(this->network.event_read);
else event_del_block(this->network.event_read);
event_free(this->network.event_read);
this->network.event_read = nullptr;
}
if(this->network.event_write) {
if(is_event_loop) event_del_noblock(this->network.event_write);
else event_del_block(this->network.event_write);
event_free(this->network.event_write);
this->network.event_write = nullptr;
}
if(this->network.event_base) {
event_base_loopexit(this->network.event_base, nullptr);
if(!is_event_loop)
threads::save_join(this->network.event_dispatch, false);
this->network.event_base = nullptr; /* event base has been saved by the event dispatcher and will be freed there */
}
if(this->network.file_descriptor) {
::close(this->network.file_descriptor);
this->network.file_descriptor = 0;
}
{
std::lock_guard block{this->buffers.lock};
auto buffer = TAILQ_FIRST(&this->buffers.write);
while(buffer) {
auto next = TAILQ_NEXT(buffer, tail);
Buffer::free(next);
buffer = next;
}
TAILQ_INIT(&this->buffers.write);
this->buffers.notify_empty.notify_all();
}
}
/**
 * Read event handler, executed on the event loop thread.
 *
 * Receives up to 1024 bytes and forwards them to handle_data(). A closed or failed
 * socket is reported through callback_disconnected (invoked without holding the
 * connection lock) and the network resources are released afterwards.
 *
 * @param events libevent event flags; not evaluated.
 */
void LicenseServerClient::callback_read(short events) {
    (void) events;

    constexpr static auto buffer_size{1024};
    char buffer[buffer_size];

    const ssize_t read_bytes = recv(this->network.file_descriptor, buffer, buffer_size, MSG_DONTWAIT);
    if(read_bytes > 0) {
        this->handle_data(buffer, (size_t) read_bytes);
        return;
    }

    /*
     * errno is only meaningful when recv() failed. A return value of zero means the peer closed
     * the connection; errno may then still hold a stale value (even EAGAIN) from an earlier call.
     * Capture it right away, before any other call gets the chance to overwrite it.
     */
    const int recv_errno = errno;
    if(read_bytes < 0 && recv_errno == EAGAIN) return;
    const std::string failure_text{read_bytes == 0 ? "connection closed by peer" : strerror(recv_errno)};

    std::unique_lock slock{this->connection_lock};
    std::string disconnect_reason{};
    bool disconnect_expected{false};
    switch (this->connection_state) {
        case ConnectionState::CONNECTING:
            disconnect_reason = "connect error (" + failure_text + ")";
            disconnect_expected = false;
            break;
        case ConnectionState::INITIALIZING:
        case ConnectionState::CONNECTED:
            disconnect_reason = "read error (" + failure_text + ")";
            disconnect_expected = false;
            break;
        case ConnectionState::DISCONNECTING:
            disconnect_expected = true;
            break;
        case ConnectionState::UNCONNECTED:
            return; /* we're obsolete */
    }

    if(auto callback{this->callback_disconnected}; callback) {
        /* never call user code while holding the connection lock */
        slock.unlock();
        callback(disconnect_expected, disconnect_reason);
        slock.lock();
    }

    /* the callback might already have closed the connection */
    if(this->connection_state != ConnectionState::UNCONNECTED) {
        this->cleanup_network_resources();
        this->connection_state = ConnectionState::UNCONNECTED;
    }
}
void LicenseServerClient::callback_write(short events) {
bool add_write_event{this->connection_state == ConnectionState::DISCONNECTING};
if(events & EV_TIMEOUT) {
std::unique_lock slock{this->connection_lock};
if(this->connection_state == ConnectionState::CONNECTING || this->connection_state == ConnectionState::INITIALIZING) {
/* connect timeout */
if(auto callback{this->callback_disconnected}; callback) {
slock.unlock();
callback(false, "connect timeout");
slock.lock();
}
if(this->connection_state != ConnectionState::UNCONNECTED) {
this->cleanup_network_resources();
this->connection_state = ConnectionState::UNCONNECTED;
}
} else if(this->connection_state == ConnectionState::DISCONNECTING) {
/* disconnect timeout */
this->cleanup_network_resources();
this->connection_state = ConnectionState::UNCONNECTED;
}
return;
}
if(events & EV_WRITE) {
if(this->connection_state == ConnectionState::CONNECTING)
this->callback_socket_connected();
ssize_t written_bytes{0};
std::unique_lock block{this->buffers.lock};
auto buffer = TAILQ_FIRST(&this->buffers.write);
if(!buffer) {
this->buffers.notify_empty.notify_all();
return;
}
block.unlock();
written_bytes = send(this->network.file_descriptor, (char*) buffer->data + buffer->offset, buffer->fill - buffer->offset, MSG_DONTWAIT);
if(written_bytes <= 0) {
if(errno == EAGAIN) goto readd_event;
std::unique_lock slock{this->connection_lock};
std::string disconnect_reason{};
bool disconnect_expected{false};
switch (this->connection_state) {
case ConnectionState::CONNECTING:
case ConnectionState::INITIALIZING:
case ConnectionState::CONNECTED:
disconnect_reason = "write error (" + std::string{strerror(errno)} + ")";
disconnect_expected = false;
break;
case ConnectionState::DISCONNECTING:
disconnect_expected = true;
break;
case ConnectionState::UNCONNECTED:
return; /* we're obsolete */
}
if(auto callback{this->callback_disconnected}; callback) {
slock.unlock();
callback(disconnect_expected, disconnect_reason);
slock.lock();
}
if(this->connection_state != ConnectionState::UNCONNECTED) {
this->cleanup_network_resources();
this->connection_state = ConnectionState::UNCONNECTED;
}
return;
}
buffer->offset += (size_t) written_bytes;
if(buffer->offset >= buffer->fill) {
assert(buffer->offset == buffer->fill);
block.lock();
TAILQ_REMOVE(&this->buffers.write, buffer, tail);
if(!TAILQ_FIRST(&this->buffers.write)) {
this->buffers.notify_empty.notify_all();
} else {
add_write_event = true;
}
block.unlock();
Buffer::free(buffer);
}
}
if(this->network.event_write && add_write_event) {
readd_event:
auto timeout = this->disconnect_timeout;
if(timeout.time_since_epoch().count() == 0)
event_add(this->network.event_write, nullptr);
else {
auto now = std::chrono::system_clock::now();
struct timeval t{0, 1};
if(now > timeout) {
this->callback_write(EV_TIMEOUT);
return;
} else {
auto microseconds = std::chrono::duration_cast<std::chrono::microseconds>(timeout - now);
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(microseconds);
microseconds -= seconds;
t.tv_usec = microseconds.count();
t.tv_sec = seconds.count();
}
event_add(this->network.event_write, &t);
}
}
}
/**
 * Append freshly received bytes to the read buffer and dispatch every complete
 * packet (header followed by header->length payload bytes) to handle_raw_packet().
 *
 * Incomplete packets stay buffered until more data arrives. A header announcing
 * more than 8 KiB of payload is treated as protocol violation and leads to a
 * disconnect.
 *
 * @param recv_buffer the received bytes.
 * @param length      number of bytes available at recv_buffer.
 */
void LicenseServerClient::handle_data(void *recv_buffer, size_t length) {
    auto& buffer = this->buffers.read;
    assert(buffer);

    if(buffer->capacity - buffer->offset - buffer->fill < length) {
        if(buffer->capacity - buffer->fill > length) {
            /* enough room once the consumed prefix is dropped; source and target may overlap, hence memmove */
            memmove(buffer->data, (char*) buffer->data + buffer->offset, buffer->fill);
            buffer->offset = 0;
        } else {
            auto new_buffer = Buffer::allocate(buffer->fill + length);
            if(!new_buffer) {
                /* out of memory: there is no way to keep the stream consistent, drop the connection */
                if(auto callback{this->callback_disconnected}; callback)
                    callback(false, "failed to allocate read buffer");
                this->close_connection();
                return;
            }

            memcpy(new_buffer->data, (char*) buffer->data + buffer->offset, buffer->fill);
            new_buffer->fill = buffer->fill;
            Buffer::free(buffer);
            buffer = new_buffer;
        }
    }

    auto buffer_ptr = (char*) buffer->data;
    auto& buffer_offset = buffer->offset;
    auto& buffer_length = buffer->fill;

    memcpy((char*) buffer_ptr + buffer_offset + buffer_length, recv_buffer, length);
    buffer_length += length;

    while(true) {
        if(buffer_length < sizeof(protocol::packet_header)) {
            /* everything consumed: rewind so the next read starts at the front again */
            if(buffer_length == 0)
                buffer_offset = 0;
            return;
        }

        auto header = reinterpret_cast<protocol::packet_header*>(buffer_ptr + buffer_offset);
        if(header->length > 1024 * 8) {
            if(auto callback{this->callback_disconnected}; callback)
                callback(false, "received a too large message");
            this->disconnect("received too large message", std::chrono::system_clock::time_point{});
            return;
        }

        /* wait for the rest of the payload */
        if(buffer_length < header->length + sizeof(protocol::packet_header)) return;

        this->handle_raw_packet(header->packetId, buffer_ptr + buffer_offset + sizeof(protocol::packet_header), header->length);
        buffer_offset += header->length + sizeof(protocol::packet_header);
        buffer_length -= header->length + sizeof(protocol::packet_header);
    }
}
/**
 * Queue a packet for sending and arm the write event.
 *
 * The payload is copied; once the handshake has been completed it is XOR'ed with
 * the negotiated key. The packet is silently dropped when the client is
 * unconnected, has no write event, or when no memory is available.
 *
 * @param type    packet id written into the header.
 * @param payload payload bytes, may be nullptr if size is zero.
 * @param size    number of payload bytes.
 */
void LicenseServerClient::send_message(protocol::PacketType type, const void *payload, size_t size) {
    const auto packet_size = size + sizeof(protocol::packet_header);
    auto buffer = Buffer::allocate(packet_size);
    if(!buffer) return; /* allocation failed, nothing we could queue */
    buffer->fill = packet_size;

    auto header = (protocol::packet_header*) buffer->data;
    /*
     * NOTE(review): the length written here includes the header itself, while handle_data()
     * interprets a received header->length as payload-only. Left untouched since the server
     * side of the protocol is not visible from here - confirm against the server.
     */
    header->length = packet_size;
    header->packetId = type;

    /* memcpy must not be called with a null source, even for zero bytes */
    if(size > 0)
        memcpy((char*) buffer->data + sizeof(protocol::packet_header), payload, size);

    if(this->communication.initialized)
        xorBuffer((char*) buffer->data + sizeof(protocol::packet_header), size, this->communication.crypt_key.data(), this->communication.crypt_key.length());

    std::lock_guard clock{this->connection_lock};
    if(this->connection_state == ConnectionState::UNCONNECTED || !this->network.event_write) {
        Buffer::free(buffer);
        return;
    }

    {
        std::lock_guard block{this->buffers.lock};
        TAILQ_INSERT_TAIL(&this->buffers.write, buffer, tail);
    }
    /* the event loop picks the buffer up in callback_write() */
    event_add(this->network.event_write, nullptr);
}
/**
 * Disconnect from the server, telling it why if the connection got far enough.
 *
 * While INITIALIZING or CONNECTED a PACKET_DISCONNECT carrying `message` is
 * queued and the client switches to DISCONNECTING; in every other state the
 * connection is simply closed. Calling it again while DISCONNECTING can only
 * shorten the pending deadline.
 *
 * @param message reason sent to the server.
 * @param timeout deadline for flushing the disconnect. A deadline in the past is
 *                replaced by "now" (default constructed time point) or by
 *                "now + 1 second" (any other past time point).
 */
void LicenseServerClient::disconnect(const std::string &message, std::chrono::system_clock::time_point timeout) {
    const auto current_time = std::chrono::system_clock::now();
    if(timeout < current_time) {
        const bool deadline_given = timeout.time_since_epoch().count() != 0;
        timeout = current_time + std::chrono::seconds{deadline_given ? 1 : 0};
    }

    std::unique_lock state_guard{this->connection_lock};

    if(this->connection_state == ConnectionState::DISCONNECTING) {
        /* already on the way out, only tighten the deadline */
        this->disconnect_timeout = std::min(this->disconnect_timeout, timeout);
        if(this->network.event_write)
            event_add(this->network.event_write, nullptr); /* let the write update the timeout */
        return;
    }

    this->disconnect_timeout = timeout;

    const bool can_notify_server = this->connection_state == ConnectionState::INITIALIZING
                                   || this->connection_state == ConnectionState::CONNECTED;
    if(can_notify_server) {
        this->connection_state = ConnectionState::DISCONNECTING;

        /* nothing the server sends from now on is of interest */
        if(this->network.event_read)
            event_del_noblock(this->network.event_read);

        state_guard.unlock();
        this->send_message(protocol::PACKET_DISCONNECT, message.data(), message.length());
    } else {
        /* nobody to notify, just drop the connection */
        state_guard.unlock();
        this->close_connection();
    }
}
bool LicenseServerClient::await_disconnect() {
{
std::lock_guard clock{this->connection_lock};
if(this->connection_state != ConnectionState::DISCONNECTING)
return this->connection_state == ConnectionState::UNCONNECTED;
}
/* state might change here, but when we're disconnected the write buffer will be empty */
std::unique_lock block{this->buffers.lock};
while(TAILQ_FIRST(&this->buffers.write))
this->buffers.notify_empty.wait(block);
return std::chrono::system_clock::now() <= this->disconnect_timeout;
}
/**
 * Called once the socket has become writeable for the first time: switches from
 * CONNECTING to INITIALIZING and queues the client handshake. Does nothing in
 * any other state.
 */
void LicenseServerClient::callback_socket_connected() {
    {
        std::lock_guard state_guard{this->connection_lock};
        if(this->connection_state != ConnectionState::CONNECTING) return;
        this->connection_state = ConnectionState::INITIALIZING;
    }

    /* three magic bytes followed by the protocol version (truncated to one byte) */
    const uint8_t handshake[4]{0xC0, 0xFF, 0xEE, static_cast<uint8_t>(this->protocol_version)};
    this->send_message(protocol::PACKET_CLIENT_HANDSHAKE, handshake, sizeof(handshake));
}
/**
 * Process one complete packet taken from the read buffer.
 *
 * After the handshake the payload is decrypted in place. PACKET_DISCONNECT
 * closes the connection; before the handshake has completed only
 * PACKET_SERVER_HANDSHAKE is accepted; everything else is handed to
 * callback_message.
 *
 * @param type   packet id taken from the packet header.
 * @param buffer payload bytes; modified in place when decryption is active.
 * @param length number of payload bytes.
 */
void LicenseServerClient::handle_raw_packet(license::protocol::PacketType type, void * buffer, size_t length) {
    /* decrypt packet */
    if(this->communication.initialized)
        xorBuffer((char*) buffer, length, this->communication.crypt_key.data(), this->communication.crypt_key.length());

    if(type == protocol::PACKET_DISCONNECT) {
        /* the payload carries the reason given by the server */
        if(auto callback{this->callback_disconnected}; callback)
            callback(false, std::string{(const char*) buffer, length});
        this->close_connection();
        return;
    }

    if(!this->communication.initialized) {
        if(type != protocol::PACKET_SERVER_HANDSHAKE) {
            if(auto callback{this->callback_disconnected}; callback)
                callback(false, "expected handshake packet");
            this->disconnect("expected handshake packet", std::chrono::system_clock::time_point{});
            return;
        }

        /*
         * handle_handshake_packet() marks the communication as initialized itself once the
         * handshake has been accepted. It must not be flagged here unconditionally: after a
         * rejected handshake any further packet of the same read would otherwise be
         * "decrypted" with a key which has never been set.
         */
        this->handle_handshake_packet(buffer, length);
        return;
    }

    if(auto callback{this->callback_message}; callback)
        callback(type, buffer, length);
    //TODO: Print error when no message callback has been registered?
}
void LicenseServerClient::handle_handshake_packet(void *buffer, size_t length) {
const auto data_ptr = (const char*) buffer;
std::string error{};
if(this->connection_state != ConnectionState::INITIALIZING) {
error = "invalid protocol state";
goto handle_error;
}
if(length < 5) {
error = "invalid packet size";
goto handle_error;
}
if((uint8_t) data_ptr[0] != 0xAF || (uint8_t) data_ptr[1] != 0xFE) {
error = "invalid handshake signature";
goto handle_error;
}
if((uint8_t) data_ptr[2] != this->protocol_version) {
error = "Invalid license protocol version. Please update TeaSpeak!";
goto handle_error;
}
{
auto key_length = be2le16(data_ptr, 3);
if(length < key_length + 5) {
error = "invalid packet size";
goto handle_error;
}
this->communication.crypt_key = std::string(data_ptr + 5, key_length);
this->communication.initialized = true;
}
if(auto callback{this->callback_connected}; callback)
callback();
return;
handle_error:
if(auto callback{this->callback_disconnected}; callback)
callback(false, error);
this->disconnect(error, std::chrono::system_clock::time_point{});
}