#include #include #include #include "crypt.h" #define DEFINE_HELPER #include "LicenseRequest.h" #include "License.h" #include using namespace std; using namespace std::chrono; using namespace ts; using namespace license; //#define DEBUG_LICENSE_CLIENT #define CERR(message) LICENSE_FERR(this, CouldNotConnectException, message) LicenceRequest::LicenceRequest(const std::shared_ptr & license, const sockaddr_in& remoteAddr) : data(license) { #ifdef DEBUG_LICENSE_CLIENT memtrack::allocated(this); #endif memcpy(&this->remote_address, &remoteAddr, sizeof(remoteAddr)); assert(license->info); } LicenceRequest::~LicenceRequest() { #ifdef DEBUG_LICENSE_CLIENT memtrack::freed(this); #endif this->abortRequest(); if(this->closeThread) { this->closeThread->join(); delete this->closeThread; this->closeThread = nullptr; } delete this->currentFuture; this->currentFuture = nullptr; } threads::Future> LicenceRequest::requestInfo() { { lock_guard lock(this->lock); if(this->currentFuture) return *this->currentFuture; this->currentFuture = new threads::Future>(); } this->beginRequest(); return *this->currentFuture; } //Basic IO void LicenceRequest::handleEventWrite(int fd, short event, void* ptrClient) { auto* client = static_cast(ptrClient); buffer::RawBuffer* buffer = nullptr; { lock_guard lock(client->lock); if((event & EV_TIMEOUT) > 0) { //Connect timeout LICENSE_FERR(client, ConnectionException, "Connect timeout"); return; } if(client->state == protocol::CONNECTING){ client->handleConnected(); } if(client->state == protocol::UNCONNECTED || !client->event_write) return; buffer = TAILQ_FIRST(&client->writeQueue); if(!buffer) return; auto writtenBytes = send(fd, &buffer->buffer[buffer->index], buffer->length - buffer->index, MSG_NOSIGNAL | MSG_DONTWAIT); buffer->index += writtenBytes; if(buffer->index >= buffer->length) { TAILQ_REMOVE(&client->writeQueue, buffer, tail); delete buffer; } if(!TAILQ_EMPTY(&client->writeQueue)) event_add(client->event_write, nullptr); } } void LicenceRequest::sendPacket(const protocol::packet& packet) { if(this->state == protocol::UNCONNECTED || this->state == protocol::DISCONNECTING) { if(this->verbose) logError(LOG_GENERAL, "Tried to send a packet to an unconnected remote!"); return; } packet.prepare(); auto buffer = new buffer::RawBuffer(packet.data.length() + sizeof(packet.header)); memcpy(buffer->buffer, &packet.header, sizeof(packet.header)); memcpy(&buffer->buffer[sizeof(packet.header)], packet.data.data(), packet.data.length()); if(!this->cryptKey.empty()) xorBuffer(&buffer->buffer[sizeof(packet.header)], packet.data.length(), this->cryptKey.data(), this->cryptKey.length()); { lock_guard lock(this->lock); TAILQ_INSERT_TAIL(&this->writeQueue, buffer, tail); if(this->event_write) event_add(this->event_write, nullptr); } } void LicenceRequest::handleEventRead(int fd, short, void* ptrClient) { auto* client = static_cast(ptrClient); auto buffer = std::unique_ptr{malloc(1024), free}; sockaddr_in remoteAddr{}; socklen_t remoteAddrSize = sizeof(remoteAddr); auto read = recvfrom(fd, buffer.get(), 1024, MSG_NOSIGNAL | MSG_DONTWAIT, reinterpret_cast(&remoteAddr), &remoteAddrSize); if(read < 0){ if(errno == EWOULDBLOCK) return; if(client->event_read) event_del_noblock(client->event_read); LICENSE_FERR(client, ConnectionException, "Invalid read: " + string(strerror(errno)) + "/" + to_string(errno)); return; } else if(read == 0) { if(client->event_read) event_del_noblock(client->event_read); LICENSE_FERR(client, ConnectionException, "IO error (" + to_string(errno) + "): " + string(strerror(errno))); return; } client->handleMessage(string((char*) buffer.get(), read)); } static int enabled = 1; static int disabled = 0; void LicenceRequest::beginRequest() { lock_guard lock(this->lock); TAILQ_INIT(&this->writeQueue); this->file_descriptor = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); if(this->file_descriptor < 0) CERR("Socket setup failed"); signal(SIGPIPE, SIG_IGN); auto state = ::connect(this->file_descriptor, reinterpret_cast(&this->remote_address), sizeof(this->remote_address)); if(state < 0 && errno != EINPROGRESS) CERR("connect() failed (" + string(strerror(errno)) + ")"); if(setsockopt(this->file_descriptor, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0) CERR("could not set reuse addr"); if(setsockopt(this->file_descriptor, IPPROTO_TCP, TCP_CORK, &disabled, sizeof(disabled)) < 0) CERR("could not set no push"); if(fcntl(this->file_descriptor, F_SETFD, fcntl(this->file_descriptor, F_GETFL, 0) | FD_CLOEXEC | O_NONBLOCK) < 0) CERR("Failed to set FD_CLOEXEC and O_NONBLOCK"); this->event_base = event_base_new(); this->event_read = event_new(this->event_base, this->file_descriptor, EV_READ | EV_PERSIST, LicenceRequest::handleEventRead, this); this->event_write = event_new(this->event_base, this->file_descriptor, EV_WRITE, LicenceRequest::handleEventWrite, this); this->state = protocol::CONNECTING; //First set connected, then we could enable the event loop event_dispatch = std::thread([&]() { signal(SIGPIPE, SIG_IGN); { /* now we could start listening */ lock_guard _lock(this->lock); if(!this->event_read || !this->event_write) return; event_add(this->event_read, nullptr); timeval connect_timeout{5, 0}; event_add(this->event_write, &connect_timeout); } event_base_dispatch(this->event_base); }); } void LicenceRequest::handleConnected() { this->state = protocol::HANDSCAKE; uint8_t handshakeBuffer[4]; handshakeBuffer[0] = 0xC0; handshakeBuffer[1] = 0xFF; handshakeBuffer[2] = 0xEE; handshakeBuffer[3] = LICENSE_PROT_VERSION; this->sendPacket(protocol::packet{protocol::PACKET_CLIENT_HANDSHAKE, string((const char*) handshakeBuffer, 4)}); //Initialise packet } void LicenceRequest::handleMessage(const std::string& message) { this->buffer += message; if(this->buffer.length() < sizeof(protocol::packet::header)) return; protocol::packet packet{protocol::PACKET_DISCONNECT, ""}; memcpy(&packet.header, this->buffer.data(), sizeof(protocol::packet::header)); if(packet.header.length <= this->buffer.length() - sizeof(protocol::packet::header)) { packet.data = this->buffer.substr(sizeof(protocol::packet::header), packet.header.length); this->buffer = this->buffer.substr(sizeof(protocol::packet::header) + packet.header.length); } else { return; } if(!this->cryptKey.empty()) { xorBuffer((char*) packet.data.data(), packet.data.length(), this->cryptKey.data(), this->cryptKey.length()); } if(packet.header.packetId == protocol::PACKET_SERVER_HANDSHAKE) { this->handlePacketHandshake(packet.data); } else if(packet.header.packetId == protocol::PACKET_DISCONNECT) { this->handlePacketDisconnect(packet.data); } else if(packet.header.packetId == protocol::PACKET_SERVER_VALIDATION_RESPONSE) { this->handlePacketLicenseInfo(packet.data); } else if(packet.header.packetId == protocol::PACKET_SERVER_PROPERTY_ADJUSTMENT) { this->handlePacketInfoAdjustment(packet.data); } else LICENSE_FERR(this, ConnectionException, "Invalid packet id (" + to_string(packet.header.packetId) + ")"); if(!this->buffer.empty() && this->state != protocol::DISCONNECTING && this->state != protocol::UNCONNECTED) this->handleMessage(""); } void LicenceRequest::disconnect(const std::string& message) { if(this->state != protocol::UNCONNECTED && this->state != protocol::DISCONNECTING) this->sendPacket({protocol::PACKET_DISCONNECT, message}); this->closeConnection(); //TODO flush? } void LicenceRequest::closeConnection() { event *event_read, *event_write; { lock_guard lock(this->lock); if(this->state == protocol::UNCONNECTED) return; if(this->event_dispatch.get_id() == this_thread::get_id()) { //We could not close in the same thread as we read/write (we're joining it later) if(this->state == protocol::DISCONNECTING) return; this->state = protocol::DISCONNECTING; this->closeThread = new threads::Thread(THREAD_SAVE_OPERATIONS, [&]() { this->closeConnection(); }); #ifdef DEBUG_LICENSE_CLIENT if(this->verbose) { debugMessage(LOG_GENERAL,"Running close in a new thread"); this->closeThread->name("License request close"); } #endif return; } this->state = protocol::UNCONNECTED; event_read = this->event_read; event_write = this->event_write; this->event_write = nullptr; this->event_read = nullptr; } if(event_read) { event_del_block(event_read); event_free(event_read); } if(event_write) { event_del_block(event_write); event_free(event_write); } /* close before base shutdown (else epoll hangup) */ if(this->file_descriptor > 0) { shutdown(this->file_descriptor, SHUT_RDWR); close(this->file_descriptor); } this->file_descriptor = 0; { lock_guard lock(this->lock); ts::buffer::RawBuffer* buffer; while ((buffer = TAILQ_FIRST(&this->writeQueue))) { TAILQ_REMOVE(&this->writeQueue, buffer, tail); delete buffer; } } { if(this->event_base) { timeval seconds{1, 0}; event_base_loopexit(this->event_base, &seconds); event_base_loopexit(this->event_base, nullptr); } if(this->event_dispatch.joinable()) { this->event_dispatch.join(); } if(this->event_base) { event_base_free(this->event_base); this->event_base = nullptr; } } #ifdef DEBUG_LICENSE_CLIENT if(this->verbose) debugMessage("Executing close done"); #endif } void LicenceRequest::abortRequest(const std::chrono::system_clock::time_point &timeout) { this->closeConnection(); }