Teaspeak-Server/license/shared/LicenseRequest.cpp

302 lines
9.9 KiB
C++

#include <netinet/tcp.h>
#include <log/LogUtils.h>
#include <misc/memtracker.h>
#include "crypt.h"
#define DEFINE_HELPER
#include "LicenseRequest.h"
#include "License.h"
#include <csignal>
using namespace std;
using namespace std::chrono;
using namespace ts;
using namespace license;
// Compile-time switch for this file: when defined, LicenceRequest instances are
// registered with memtrack on construction/destruction and extra debug messages
// are emitted while closing the connection (only if `verbose` is set).
#define DEBUG_LICENSE_CLIENT
// Shorthand for LICENSE_FERR on the current instance with CouldNotConnectException.
// NOTE(review): the control flow of LICENSE_FERR (does it return?) is defined elsewhere - confirm before relying on it.
#define CERR(message) LICENSE_FERR(this, CouldNotConnectException, message)
/**
 * Creates a request for the given license data which will talk to `remoteAddr`.
 *
 * @param license    request data; stored in `data`. Its `info` member must be set (asserted).
 * @param remoteAddr IPv4 address of the license server; copied into `remote_address`.
 */
LicenceRequest::LicenceRequest(const std::shared_ptr<LicenseRequestData> & license, const sockaddr_in& remoteAddr) : data(license) {
#ifdef DEBUG_LICENSE_CLIENT
    /* register this instance with the memory tracker */
    memtrack::allocated<LicenceRequest>(this);
#endif

    /* keep a private copy of the target address */
    memcpy(&this->remote_address, &remoteAddr, sizeof(sockaddr_in));

    assert(license->info);
}
/**
 * Aborts a possibly running request, waits for the helper close thread
 * (if one was spawned) and releases the future owned by this instance.
 */
LicenceRequest::~LicenceRequest() {
#ifdef DEBUG_LICENSE_CLIENT
    /* unregister this instance from the memory tracker */
    memtrack::freed<LicenceRequest>(this);
#endif

    this->abortRequest();

    /* closeConnection() may have delegated the shutdown to a helper thread; wait for it */
    if(this->closeThread != nullptr) {
        this->closeThread->join();
        delete this->closeThread;
        this->closeThread = nullptr;
    }

    /* deleting a null pointer is a no-op, so no check is required */
    delete this->currentFuture;
    this->currentFuture = nullptr;
}
/**
 * Returns the future of the license request and starts the request on the first call.
 *
 * If a future already exists it is returned as is and no new request is started.
 *
 * @return a copy of the future which receives the license response.
 */
threads::Future<std::shared_ptr<LicenseRequestResponse>> LicenceRequest::requestInfo() {
    using ResponseFuture = threads::Future<std::shared_ptr<LicenseRequestResponse>>;

    {
        lock_guard guard(this->lock);
        if(this->currentFuture != nullptr)
            return *this->currentFuture;

        this->currentFuture = new ResponseFuture();
    }

    /* the future has just been created, so the connection attempt has to be started */
    this->beginRequest();
    return *this->currentFuture;
}
//Basic IO
void LicenceRequest::handleEventWrite(int fd, short event, void* ptrClient) {
auto* client = static_cast<LicenceRequest *>(ptrClient);
buffer::RawBuffer* buffer = nullptr;
{
lock_guard lock(client->lock);
if((event & EV_TIMEOUT) > 0) { //Connect timeout
LICENSE_FERR(client, ConnectionException, "Connect timeout");
return;
}
if(client->state == protocol::CONNECTING){
client->handleConnected();
}
if(client->state == protocol::UNCONNECTED || !client->event_write)
return;
buffer = TAILQ_FIRST(&client->writeQueue);
if(!buffer) return;
auto writtenBytes = send(fd, &buffer->buffer[buffer->index], buffer->length - buffer->index, MSG_NOSIGNAL | MSG_DONTWAIT);
buffer->index += writtenBytes;
if(buffer->index >= buffer->length) {
TAILQ_REMOVE(&client->writeQueue, buffer, tail);
delete buffer;
}
if(!TAILQ_EMPTY(&client->writeQueue))
event_add(client->event_write, nullptr);
}
}
/**
 * Serialises a packet (header followed by payload) and enqueues it for sending.
 *
 * The payload is XOR-ed with `cryptKey` if a key has been set; the header is sent as is.
 * Nothing is sent while the state is UNCONNECTED or DISCONNECTING.
 *
 * @param packet the packet to send; `prepare()` is invoked on it before serialising.
 */
void LicenceRequest::sendPacket(const protocol::packet& packet) {
    const bool connectionUsable = this->state != protocol::UNCONNECTED && this->state != protocol::DISCONNECTING;
    if(!connectionUsable) {
        if(this->verbose)
            logError("Tried to send a packet to an unconnected remote!");
        return;
    }

    packet.prepare();

    const auto headerSize = sizeof(packet.header);
    const auto payloadSize = packet.data.length();

    auto queued = new buffer::RawBuffer(payloadSize + headerSize);
    auto payload = &queued->buffer[headerSize];

    memcpy(queued->buffer, &packet.header, headerSize);
    memcpy(payload, packet.data.data(), payloadSize);

    /* only the payload gets obfuscated */
    if(!this->cryptKey.empty())
        xorBuffer(payload, payloadSize, this->cryptKey.data(), this->cryptKey.length());

    /* hand the buffer over to the write queue and wake up the write event */
    lock_guard lock(this->lock);
    TAILQ_INSERT_TAIL(&this->writeQueue, queued, tail);
    if(this->event_write)
        event_add(this->event_write, nullptr);
}
void LicenceRequest::handleEventRead(int fd, short, void* ptrClient) {
auto* client = static_cast<LicenceRequest *>(ptrClient);
auto buffer = std::unique_ptr<void, decltype(free)*>{malloc(1024), free};
sockaddr_in remoteAddr{};
socklen_t remoteAddrSize = sizeof(remoteAddr);
auto read = recvfrom(fd, buffer.get(), 1024, MSG_NOSIGNAL | MSG_DONTWAIT, reinterpret_cast<sockaddr *>(&remoteAddr), &remoteAddrSize);
if(read < 0){
if(errno == EWOULDBLOCK) return;
if(client->event_read)
event_del_noblock(client->event_read);
LICENSE_FERR(client, ConnectionException, "Invalid read: " + string(strerror(errno)) + "/" + to_string(errno));
return;
} else if(read == 0) {
if(client->event_read)
event_del_noblock(client->event_read);
LICENSE_FERR(client, ConnectionException, "IO error (" + to_string(errno) + "): " + string(strerror(errno)));
return;
}
client->handleMessage(string((char*) buffer.get(), read));
}
/* Option values which are passed by address to setsockopt(). */
static int enabled = 1;
static int disabled = 0;

/**
 * Creates the non blocking TCP socket, starts connecting to `remote_address` and
 * spawns the event dispatch thread which drives the read/write callbacks.
 *
 * Failures are reported via CERR (CouldNotConnectException).
 *
 * NOTE(review): whether CERR leaves this function is decided by LICENSE_FERR, which is defined
 * elsewhere. If it does return, the socket created here is not closed on the error paths - confirm.
 */
void LicenceRequest::beginRequest() {
    lock_guard lock(this->lock);
    TAILQ_INIT(&this->writeQueue);

    this->file_descriptor = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
    if(this->file_descriptor < 0) { CERR("Socket setup failed"); }

    signal(SIGPIPE, SIG_IGN);

    /* the socket is non blocking, so EINPROGRESS is the expected result of connect() */
    const auto connectResult = ::connect(this->file_descriptor, reinterpret_cast<const sockaddr *>(&this->remote_address), sizeof(this->remote_address));
    if(connectResult < 0 && errno != EINPROGRESS) { CERR("connect() failed (" + string(strerror(errno)) + ")"); }

    if(setsockopt(this->file_descriptor, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0) { CERR("could not set reuse addr"); }
    if(setsockopt(this->file_descriptor, IPPROTO_TCP, TCP_CORK, &disabled, sizeof(disabled)) < 0) { CERR("could not set no push"); }

    /*
     * FD_CLOEXEC is a file descriptor flag (F_GETFD/F_SETFD) while O_NONBLOCK is a file status flag
     * (F_GETFL/F_SETFL). They live in different flag sets and must not be mixed within one fcntl() call.
     */
    const int descriptorFlags = fcntl(this->file_descriptor, F_GETFD, 0);
    if(descriptorFlags < 0 || fcntl(this->file_descriptor, F_SETFD, descriptorFlags | FD_CLOEXEC) < 0) { CERR("Failed to set FD_CLOEXEC and O_NONBLOCK"); }

    const int statusFlags = fcntl(this->file_descriptor, F_GETFL, 0);
    if(statusFlags < 0 || fcntl(this->file_descriptor, F_SETFL, statusFlags | O_NONBLOCK) < 0) { CERR("Failed to set FD_CLOEXEC and O_NONBLOCK"); }

    this->event_base = event_base_new();
    this->event_read = event_new(this->event_base, this->file_descriptor, EV_READ | EV_PERSIST, LicenceRequest::handleEventRead, this);
    this->event_write = event_new(this->event_base, this->file_descriptor, EV_WRITE, LicenceRequest::handleEventWrite, this);

    this->state = protocol::CONNECTING; //First set connected, then we could enable the event loop

    event_dispatch = std::thread([this]() {
        signal(SIGPIPE, SIG_IGN);

        { /* now we could start listening */
            lock_guard _lock(this->lock);
            /* the events are gone if the connection has been closed in the meantime */
            if(!this->event_read || !this->event_write) return;

            event_add(this->event_read, nullptr);

            /* the write event doubles as connect timeout (see EV_TIMEOUT handling in handleEventWrite) */
            timeval connect_timeout{5, 0};
            event_add(this->event_write, &connect_timeout);
        }

        event_base_dispatch(this->event_base);
    });
}
void LicenceRequest::handleConnected() {
this->state = protocol::HANDSCAKE;
uint8_t handshakeBuffer[4];
handshakeBuffer[0] = 0xC0;
handshakeBuffer[1] = 0xFF;
handshakeBuffer[2] = 0xEE;
handshakeBuffer[3] = LICENSE_PROT_VERSION;
this->sendPacket(protocol::packet{protocol::PACKET_CLIENT_HANDSHAKE, string((const char*) handshakeBuffer, 4)}); //Initialise packet
}
/**
 * Parses one received message into a packet and dispatches it to the matching handler.
 *
 * The message has to contain at least a full packet header; everything behind the header
 * is the payload, which gets XOR-decoded with `cryptKey` if a key has been set.
 *
 * NOTE(review): the message is treated as exactly one packet; confirm that the caller
 * never delivers partial or multiple packets at once.
 *
 * @param message raw bytes as received from the socket.
 */
void LicenceRequest::handleMessage(const std::string& message) {
    if(message.length() < sizeof(protocol::packet::header)) {
        LICENSE_FERR(this, ConnectionException, "Invalid packet size");
        /* Never continue with a truncated header: the memcpy below would read past the message and substr() would throw. */
        return;
    }

    protocol::packet packet{protocol::PACKET_DISCONNECT, ""};
    memcpy(&packet.header, message.data(), sizeof(protocol::packet::header));
    packet.data = message.substr(sizeof(protocol::packet::header));

    if(!this->cryptKey.empty()) {
        /* decode the payload in place */
        xorBuffer(const_cast<char*>(packet.data.data()), packet.data.length(), this->cryptKey.data(), this->cryptKey.length());
    }

    const auto packetId = packet.header.packetId;
    if(packetId == protocol::PACKET_SERVER_HANDSHAKE) {
        this->handlePacketHandshake(packet.data);
    } else if(packetId == protocol::PACKET_DISCONNECT) {
        this->handlePacketDisconnect(packet.data);
    } else if(packetId == protocol::PACKET_SERVER_VALIDATION_RESPONSE) {
        this->handlePacketLicenseInfo(packet.data);
    } else if(packetId == protocol::PACKET_SERVER_PROPERTY_ADJUSTMENT) {
        this->handlePacketInfoAdjustment(packet.data);
    } else {
        LICENSE_FERR(this, ConnectionException, "Invalid packet id (" + to_string(packetId) + ")");
    }
}
/**
 * Queues a disconnect packet carrying `message` (if still connected) and closes the connection.
 *
 * NOTE(review): the connection is closed right after the packet has been queued and
 * closeConnection() drops the write queue, so the packet is presumably not always sent
 * (see the TODO below) - confirm.
 *
 * @param message reason which is sent to the remote.
 */
void LicenceRequest::disconnect(const std::string& message) {
    const bool alreadyDown = this->state == protocol::UNCONNECTED || this->state == protocol::DISCONNECTING;
    if(!alreadyDown)
        this->sendPacket({protocol::PACKET_DISCONNECT, message});

    this->closeConnection();
    //TODO flush?
}
/*
 * Tears the connection down: removes and frees both libevent events, shuts down and closes
 * the socket, drops all queued outgoing buffers, stops the event loop, joins the dispatch
 * thread and frees the event base.
 *
 * Does nothing if the state is already UNCONNECTED. When invoked from the event dispatch
 * thread itself, the state is set to DISCONNECTING and the actual shutdown is re-run on a
 * newly created helper thread (closeThread), which the destructor joins.
 */
void LicenceRequest::closeConnection() {
/* only assigned (and used) on the path which reaches the end of the locked block below */
event *event_read, *event_write;
{
lock_guard lock(this->lock);
if(this->state == protocol::UNCONNECTED) return;
if(this->event_dispatch.get_id() == this_thread::get_id()) { //We could not close in the same thread as we read/write (we're joining it later)
/* a helper thread has already been scheduled by an earlier call */
if(this->state == protocol::DISCONNECTING) return;
this->state = protocol::DISCONNECTING;
/* the helper thread runs this method again; there the thread id check above no longer matches */
this->closeThread = new threads::Thread(THREAD_SAVE_OPERATIONS, [&]() { this->closeConnection(); });
#ifdef DEBUG_LICENSE_CLIENT
if(this->verbose) {
debugMessage("Running close in a new thread");
this->closeThread->name("License request close");
}
#endif
return;
}
/* take the events out of the instance while holding the lock, so the callbacks see null pointers from now on */
this->state = protocol::UNCONNECTED;
event_read = this->event_read;
event_write = this->event_write;
this->event_write = nullptr;
this->event_read = nullptr;
}
/* the events are removed outside of the lock; the callbacks acquire the same lock */
/* NOTE(review): event_del_block presumably waits for a currently running callback to finish - see libevent docs */
if(event_read) {
event_del_block(event_read);
event_free(event_read);
}
if(event_write) {
event_del_block(event_write);
event_free(event_write);
}
/* close before base shutdown (else epoll hangup) */
/* a descriptor value of 0 is used as "no socket" marker here */
if(this->file_descriptor > 0) {
shutdown(this->file_descriptor, SHUT_RDWR);
close(this->file_descriptor);
}
this->file_descriptor = 0;
{
/* free every buffer which is still waiting in the write queue */
lock_guard lock(this->lock);
ts::buffer::RawBuffer* buffer;
while ((buffer = TAILQ_FIRST(&this->writeQueue))) {
TAILQ_REMOVE(&this->writeQueue, buffer, tail);
delete buffer;
}
}
{
if(this->event_base) {
/* NOTE(review): exit is requested twice, delayed by one second and immediately; the delayed request looks redundant - confirm */
timeval seconds{1, 0};
event_base_loopexit(this->event_base, &seconds);
event_base_loopexit(this->event_base, nullptr);
}
/* wait until the dispatch thread has left the event loop before the base gets freed */
if(this->event_dispatch.joinable()) {
this->event_dispatch.join();
}
if(this->event_base) {
event_base_free(this->event_base);
this->event_base = nullptr;
}
}
#ifdef DEBUG_LICENSE_CLIENT
if(this->verbose)
debugMessage("Executing close done");
#endif
}
/**
 * Aborts the current request by closing the connection.
 *
 * @param timeout currently not evaluated; the connection is closed right away.
 */
void LicenceRequest::abortRequest(const std::chrono::system_clock::time_point &timeout) {
    (void) timeout; /* intentionally unused */

    this->closeConnection();
}