302 lines
9.9 KiB
C++
302 lines
9.9 KiB
C++
|
#include <netinet/tcp.h>
|
||
|
#include <log/LogUtils.h>
|
||
|
#include <misc/memtracker.h>
|
||
|
#include "crypt.h"
|
||
|
#define DEFINE_HELPER
|
||
|
#include "LicenseRequest.h"
|
||
|
#include "License.h"
|
||
|
#include <csignal>
|
||
|
|
||
|
using namespace std;
|
||
|
using namespace std::chrono;
|
||
|
using namespace ts;
|
||
|
using namespace license;
|
||
|
|
||
|
#define DEBUG_LICENSE_CLIENT
|
||
|
#define CERR(message) LICENSE_FERR(this, CouldNotConnectException, message)
|
||
|
|
||
|
|
||
|
/**
 * Creates a license request for the given request payload and license server address.
 *
 * The constructor only records its arguments; the connection itself is started
 * later by requestInfo().
 *
 * @param license    request payload; its `info` member must be set (checked by assert)
 * @param remoteAddr IPv4 address of the license server, copied into remote_address
 */
LicenceRequest::LicenceRequest(const std::shared_ptr<LicenseRequestData> & license, const sockaddr_in& remoteAddr) : data(license) {
#ifdef DEBUG_LICENSE_CLIENT
    memtrack::allocated<LicenceRequest>(this);
#endif
    assert(license->info); /* a request without license info is a programming error */
    memcpy(&this->remote_address, &remoteAddr, sizeof(remoteAddr));
}
|
||
|
|
||
|
/**
 * Destroys the request.
 *
 * First tears the connection down via abortRequest(). Then it waits for a
 * deferred close thread (created by closeConnection() when it was invoked from
 * the event loop thread) and releases the response future.
 */
LicenceRequest::~LicenceRequest() {
#ifdef DEBUG_LICENSE_CLIENT
    memtrack::freed<LicenceRequest>(this);
#endif
    this->abortRequest();

    if(auto* pending_close = this->closeThread; pending_close != nullptr) {
        pending_close->join();
        delete pending_close;
        this->closeThread = nullptr;
    }

    auto* future = this->currentFuture;
    this->currentFuture = nullptr;
    delete future;
}
|
||
|
|
||
|
/**
 * Starts the license request (only once) and returns the future for its response.
 *
 * The first call allocates the future and starts the connection through
 * beginRequest(). Later calls return a copy of that same future without
 * starting another connection.
 *
 * @return future for the request's response (fulfilment happens outside this method)
 */
threads::Future<std::shared_ptr<LicenseRequestResponse>> LicenceRequest::requestInfo() {
    using response_future = threads::Future<std::shared_ptr<LicenseRequestResponse>>;

    {
        std::lock_guard guard(this->lock);
        if(this->currentFuture != nullptr)
            return *this->currentFuture; /* a request was already started */

        this->currentFuture = new response_future();
    }

    /* beginRequest() acquires the lock on its own */
    this->beginRequest();
    return *this->currentFuture;
}
|
||
|
|
||
|
|
||
|
//Basic IO
|
||
|
void LicenceRequest::handleEventWrite(int fd, short event, void* ptrClient) {
|
||
|
auto* client = static_cast<LicenceRequest *>(ptrClient);
|
||
|
|
||
|
buffer::RawBuffer* buffer = nullptr;
|
||
|
{
|
||
|
lock_guard lock(client->lock);
|
||
|
if((event & EV_TIMEOUT) > 0) { //Connect timeout
|
||
|
LICENSE_FERR(client, ConnectionException, "Connect timeout");
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
if(client->state == protocol::CONNECTING){
|
||
|
client->handleConnected();
|
||
|
}
|
||
|
|
||
|
if(client->state == protocol::UNCONNECTED || !client->event_write)
|
||
|
return;
|
||
|
|
||
|
buffer = TAILQ_FIRST(&client->writeQueue);
|
||
|
if(!buffer) return;
|
||
|
|
||
|
auto writtenBytes = send(fd, &buffer->buffer[buffer->index], buffer->length - buffer->index, MSG_NOSIGNAL | MSG_DONTWAIT);
|
||
|
buffer->index += writtenBytes;
|
||
|
|
||
|
if(buffer->index >= buffer->length) {
|
||
|
TAILQ_REMOVE(&client->writeQueue, buffer, tail);
|
||
|
delete buffer;
|
||
|
}
|
||
|
if(!TAILQ_EMPTY(&client->writeQueue))
|
||
|
event_add(client->event_write, nullptr);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/**
 * Serializes a packet and queues it for sending by the write event.
 *
 * The buffer layout is the raw packet header followed by the payload. When
 * cryptKey is set, only the payload is XOR-encrypted. Packets sent while the
 * connection is UNCONNECTED or DISCONNECTING are dropped; an error is logged
 * when verbose is set.
 *
 * @param packet packet to send; prepare() is invoked on it before serialization
 */
void LicenceRequest::sendPacket(const protocol::packet& packet) {
    const bool connection_closed = this->state == protocol::UNCONNECTED || this->state == protocol::DISCONNECTING;
    if(connection_closed) {
        if(this->verbose)
            logError("Tried to send a packet to an unconnected remote!");
        return;
    }

    packet.prepare();

    const auto header_size = sizeof(packet.header);
    const auto payload_size = packet.data.length();

    auto* raw = new buffer::RawBuffer(payload_size + header_size);
    auto* payload = &raw->buffer[header_size];
    memcpy(raw->buffer, &packet.header, header_size);
    memcpy(payload, packet.data.data(), payload_size);

    if(!this->cryptKey.empty())
        xorBuffer(payload, payload_size, this->cryptKey.data(), this->cryptKey.length());

    std::lock_guard guard(this->lock);
    TAILQ_INSERT_TAIL(&this->writeQueue, raw, tail);
    if(this->event_write)
        event_add(this->event_write, nullptr);
}
|
||
|
|
||
|
void LicenceRequest::handleEventRead(int fd, short, void* ptrClient) {
|
||
|
auto* client = static_cast<LicenceRequest *>(ptrClient);
|
||
|
|
||
|
auto buffer = std::unique_ptr<void, decltype(free)*>{malloc(1024), free};
|
||
|
sockaddr_in remoteAddr{};
|
||
|
socklen_t remoteAddrSize = sizeof(remoteAddr);
|
||
|
|
||
|
auto read = recvfrom(fd, buffer.get(), 1024, MSG_NOSIGNAL | MSG_DONTWAIT, reinterpret_cast<sockaddr *>(&remoteAddr), &remoteAddrSize);
|
||
|
|
||
|
if(read < 0){
|
||
|
if(errno == EWOULDBLOCK) return;
|
||
|
if(client->event_read)
|
||
|
event_del_noblock(client->event_read);
|
||
|
LICENSE_FERR(client, ConnectionException, "Invalid read: " + string(strerror(errno)) + "/" + to_string(errno));
|
||
|
return;
|
||
|
} else if(read == 0) {
|
||
|
if(client->event_read)
|
||
|
event_del_noblock(client->event_read);
|
||
|
LICENSE_FERR(client, ConnectionException, "IO error (" + to_string(errno) + "): " + string(strerror(errno)));
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
client->handleMessage(string((char*) buffer.get(), read));
|
||
|
}
|
||
|
|
||
|
|
||
|
static int enabled = 1;
|
||
|
static int disabled = 0;
|
||
|
void LicenceRequest::beginRequest() {
|
||
|
lock_guard lock(this->lock);
|
||
|
TAILQ_INIT(&this->writeQueue);
|
||
|
|
||
|
this->file_descriptor = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
|
||
|
if(this->file_descriptor < 0) CERR("Socket setup failed");
|
||
|
|
||
|
signal(SIGPIPE, SIG_IGN);
|
||
|
|
||
|
auto state = ::connect(this->file_descriptor, reinterpret_cast<const sockaddr *>(&this->remote_address), sizeof(this->remote_address));
|
||
|
if(state < 0 && errno != EINPROGRESS) CERR("connect() failed (" + string(strerror(errno)) + ")");
|
||
|
|
||
|
if(setsockopt(this->file_descriptor, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0) CERR("could not set reuse addr");
|
||
|
if(setsockopt(this->file_descriptor, IPPROTO_TCP, TCP_CORK, &disabled, sizeof(disabled)) < 0) CERR("could not set no push");
|
||
|
|
||
|
if(fcntl(this->file_descriptor, F_SETFD, fcntl(this->file_descriptor, F_GETFL, 0) | FD_CLOEXEC | O_NONBLOCK) < 0) CERR("Failed to set FD_CLOEXEC and O_NONBLOCK");
|
||
|
|
||
|
this->event_base = event_base_new();
|
||
|
this->event_read = event_new(this->event_base, this->file_descriptor, EV_READ | EV_PERSIST, LicenceRequest::handleEventRead, this);
|
||
|
this->event_write = event_new(this->event_base, this->file_descriptor, EV_WRITE, LicenceRequest::handleEventWrite, this);
|
||
|
|
||
|
this->state = protocol::CONNECTING; //First set connected, then we could enable the event loop
|
||
|
|
||
|
event_dispatch = std::thread([&]() {
|
||
|
signal(SIGPIPE, SIG_IGN);
|
||
|
|
||
|
{ /* now we could start listening */
|
||
|
lock_guard _lock(this->lock);
|
||
|
if(!this->event_read || !this->event_write) return;
|
||
|
|
||
|
event_add(this->event_read, nullptr);
|
||
|
timeval connect_timeout{5, 0};
|
||
|
event_add(this->event_write, &connect_timeout);
|
||
|
}
|
||
|
|
||
|
event_base_dispatch(this->event_base);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
/**
 * Called once the TCP connection is established: moves to the handshake state
 * and queues the client handshake packet.
 *
 * The handshake payload is the 3 magic bytes C0 FF EE followed by the protocol
 * version (LICENSE_PROT_VERSION truncated to one byte).
 */
void LicenceRequest::handleConnected() {
    this->state = protocol::HANDSCAKE;

    const uint8_t version = LICENSE_PROT_VERSION;
    const std::string handshake{'\xC0', '\xFF', '\xEE', static_cast<char>(version)};
    this->sendPacket(protocol::packet{protocol::PACKET_CLIENT_HANDSHAKE, handshake}); //Initialise packet
}
|
||
|
|
||
|
void LicenceRequest::handleMessage(const std::string& message) {
|
||
|
if(message.length() < sizeof(protocol::packet::header)) LICENSE_FERR(this, ConnectionException, "Invalid packet size");
|
||
|
protocol::packet packet{protocol::PACKET_DISCONNECT, ""};
|
||
|
memcpy(&packet.header, message.data(), sizeof(protocol::packet::header));
|
||
|
packet.data = message.substr(sizeof(protocol::packet::header));
|
||
|
|
||
|
if(!this->cryptKey.empty()) {
|
||
|
xorBuffer((char*) packet.data.data(), packet.data.length(), this->cryptKey.data(), this->cryptKey.length());
|
||
|
}
|
||
|
|
||
|
if(packet.header.packetId == protocol::PACKET_SERVER_HANDSHAKE) {
|
||
|
this->handlePacketHandshake(packet.data);
|
||
|
} else if(packet.header.packetId == protocol::PACKET_DISCONNECT) {
|
||
|
this->handlePacketDisconnect(packet.data);
|
||
|
} else if(packet.header.packetId == protocol::PACKET_SERVER_VALIDATION_RESPONSE) {
|
||
|
this->handlePacketLicenseInfo(packet.data);
|
||
|
} else if(packet.header.packetId == protocol::PACKET_SERVER_PROPERTY_ADJUSTMENT) {
|
||
|
this->handlePacketInfoAdjustment(packet.data);
|
||
|
} else
|
||
|
LICENSE_FERR(this, ConnectionException, "Invalid packet id (" + to_string(packet.header.packetId) + ")");
|
||
|
}
|
||
|
|
||
|
void LicenceRequest::disconnect(const std::string& message) {
|
||
|
if(this->state != protocol::UNCONNECTED && this->state != protocol::DISCONNECTING)
|
||
|
this->sendPacket({protocol::PACKET_DISCONNECT, message});
|
||
|
this->closeConnection();
|
||
|
//TODO flush?
|
||
|
}
|
||
|
|
||
|
void LicenceRequest::closeConnection() {
|
||
|
event *event_read, *event_write;
|
||
|
{
|
||
|
lock_guard lock(this->lock);
|
||
|
if(this->state == protocol::UNCONNECTED) return;
|
||
|
|
||
|
if(this->event_dispatch.get_id() == this_thread::get_id()) { //We could not close in the same thread as we read/write (we're joining it later)
|
||
|
if(this->state == protocol::DISCONNECTING) return;
|
||
|
|
||
|
this->state = protocol::DISCONNECTING;
|
||
|
this->closeThread = new threads::Thread(THREAD_SAVE_OPERATIONS, [&]() { this->closeConnection(); });
|
||
|
|
||
|
#ifdef DEBUG_LICENSE_CLIENT
|
||
|
if(this->verbose) {
|
||
|
debugMessage("Running close in a new thread");
|
||
|
this->closeThread->name("License request close");
|
||
|
}
|
||
|
#endif
|
||
|
return;
|
||
|
}
|
||
|
this->state = protocol::UNCONNECTED;
|
||
|
|
||
|
event_read = this->event_read;
|
||
|
event_write = this->event_write;
|
||
|
|
||
|
this->event_write = nullptr;
|
||
|
this->event_read = nullptr;
|
||
|
}
|
||
|
|
||
|
if(event_read) {
|
||
|
event_del_block(event_read);
|
||
|
event_free(event_read);
|
||
|
}
|
||
|
|
||
|
if(event_write) {
|
||
|
event_del_block(event_write);
|
||
|
event_free(event_write);
|
||
|
}
|
||
|
|
||
|
/* close before base shutdown (else epoll hangup) */
|
||
|
if(this->file_descriptor > 0) {
|
||
|
shutdown(this->file_descriptor, SHUT_RDWR);
|
||
|
close(this->file_descriptor);
|
||
|
}
|
||
|
this->file_descriptor = 0;
|
||
|
|
||
|
{
|
||
|
lock_guard lock(this->lock);
|
||
|
ts::buffer::RawBuffer* buffer;
|
||
|
while ((buffer = TAILQ_FIRST(&this->writeQueue))) {
|
||
|
TAILQ_REMOVE(&this->writeQueue, buffer, tail);
|
||
|
delete buffer;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
{
|
||
|
if(this->event_base) {
|
||
|
timeval seconds{1, 0};
|
||
|
event_base_loopexit(this->event_base, &seconds);
|
||
|
event_base_loopexit(this->event_base, nullptr);
|
||
|
}
|
||
|
|
||
|
if(this->event_dispatch.joinable()) {
|
||
|
this->event_dispatch.join();
|
||
|
}
|
||
|
|
||
|
if(this->event_base) {
|
||
|
event_base_free(this->event_base);
|
||
|
this->event_base = nullptr;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#ifdef DEBUG_LICENSE_CLIENT
|
||
|
if(this->verbose)
|
||
|
debugMessage("Executing close done");
|
||
|
#endif
|
||
|
}
|
||
|
|
||
|
/**
 * Aborts the request by closing the connection.
 *
 * @param timeout currently ignored by this implementation
 */
void LicenceRequest::abortRequest(const std::chrono::system_clock::time_point &timeout) {
    (void) timeout;
    this->closeConnection();
}
|