// Teaspeak-Server/client/src/protocol/Connection.cpp
//
// 454 lines
// 15 KiB
// C++

//
// Created by wolverindev on 07.10.17.
//
#include <log/LogUtils.h>
#include "Connection.h"
#include "misc/base64.h"
#include <misc/endianness.h>
#include <arpa/inet.h>
#include <poll.h>
#include <openssl/sha.h>
#include <bitset>
#include <protocol/Packet.h>
using namespace std;
using namespace ts;
using namespace ts::connection;
using namespace ts::protocol;
/**
 * Creates the crypto/compression handlers and one sorted read queue per
 * defined packet type id (0..15). Slots whose id maps to
 * PacketTypeInfo::Undefined stay nullptr.
 */
ServerConnection::ServerConnection() {
    this->cryptionHandler = new CryptionHandler();
    this->cryptionHandler->reset();
    this->compressionHandler = new CompressionHandler();

    using PacketQueue = buffer::SortedBufferQueue<ServerPacket>;
    this->readQueue = (PacketQueue **) malloc(16 * sizeof(void*));

    for(int typeId = 0; typeId < 16; typeId++) {
        auto packetType = ts::protocol::PacketTypeInfo::fromid(typeId);
        if(packetType == PacketTypeInfo::Undefined) {
            this->readQueue[typeId] = nullptr;
            continue;
        }

        // Second constructor argument is true for every type except Command
        // (original note: "Ignore command low").
        this->readQueue[typeId] = new PacketQueue(ts::protocol::PacketTypeInfo::fromid(typeId), PacketTypeInfo::Command != packetType);
    }
}
/**
 * Releases everything the constructor allocated.
 *
 * The read/write thread is joined FIRST: rwExecutor() dereferences readQueue
 * and the crypto/compression handlers, so freeing them while the thread is
 * still running would be a use-after-free.
 *
 * NOTE: join() only returns once the rw loop has ended, so callers are
 * expected to disconnect() before destroying the connection.
 */
ServerConnection::~ServerConnection() {
    // rwThread is only created by connect(); it may never have been started.
    // NOTE(review): assumes rwThread is null-initialised in the class declaration - confirm.
    if(this->rwThread)
        this->rwThread->join();

    if(this->readQueue) {
        for(int i = 0; i < 16; i++)
            delete this->readQueue[i]; // deleting nullptr slots is a no-op
        free(this->readQueue);
        this->readQueue = nullptr;
    }

    // Allocated with new in the constructor and never released before.
    delete this->compressionHandler;
    this->compressionHandler = nullptr;
    delete this->cryptionHandler;
    this->cryptionHandler = nullptr;
}
// Counter for local UDP source ports. Within the visible part of this file it
// is only referenced by the disabled raw-socket code in ServerConnection::connect().
static int sourcePort = 50000;
/**
 * Marks the connection as closed and closes the socket if one exists.
 * Clearing `connected` is what makes the rw loop stop.
 */
void ServerConnection::disconnect() {
    // Thread cancellation is currently disabled:
    //   this->rwThread->cancel();
    //   this->handleThread->cancel();
    this->connected = false;

    if(!this->socket)
        return;
    this->socket->close();
}
bool ServerConnection::connect(std::string host, std::string port, Identity *identity) {
this->clientIdentity = identity;
memset(&remoteAddress, 0, sizeof(remoteAddress));
remoteAddress.sin_family = AF_INET;
remoteAddress.sin_port = htons((uint16_t) std::stoi(port));
remoteAddress.sin_addr.s_addr = inet_addr(host.c_str());
#ifdef NoQt
/*
this->socketfd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
int allow = 1;
setsockopt(this->socketfd, SOL_SOCKET, SO_REUSEADDR, &allow, sizeof(int));
memset(&localAddress, 0, sizeof(localAddress));
localAddress.sin_family = AF_INET;
localAddress.sin_addr.s_addr = htonl (INADDR_ANY);
localAddress.sin_port = htons (sourcePort++);
::connect(this->socketfd, (const sockaddr *) &remoteHost, sizeof(this->remoteAddress));
//bind(this->socketfd, (struct sockaddr *) &localAddress, sizeof(localAddress));
*/
this->socket = new UdpSocket;
if(!this->socket->setup(&remoteAddress)){
cerr << "Invalid socket setup" << endl;
}
#else
#endif
this->rwThread = new threads::Thread(THREAD_SAVE_OPERATIONS, [&]() {
#ifndef NoQt
this->qtSocket = new QUdpSocket();
QObject::connect(qtSocket,SIGNAL(bytesWritten(qint64)),this,SLOT(bytesWritten(qint64)));
QObject::connect(this->qtSocket, SIGNAL(readyRead()), this, SLOT(attempDatagramRead()));
this->qtSocket->bind(QHostAddress::Any, 23111);
this->socketfd = qtSocket->socketDescriptor();
cout << "Sock fd: " << this->socketfd << endl;
#endif
/*
auto cthread = QThread::currentThread();
cout << "ex" << endl;
runOnThread(qtSocket->thread(), [&](){
cout << "try" << endl;
qtSocket->moveToThread(cthread);
cout << "Moved" << endl;
});
cout << "Start rw" << endl;
*/
this->rwExecutor();
});
/*
this->handleThread = new threads::Thread([&]() {
this->handleExecutor();
});
*/
return true;
}
#ifndef NoQt
/** Qt slot: logs the byte count passed in by the bytesWritten signal. */
void ServerConnection::bytesWritten(qint64 b) {
    cout << "written ";
    cout << b << endl;
}
/** Qt slot connected to readyRead(): only logs that data is available. */
void ServerConnection::attempDatagramRead() {
    cout << "Data ";
    cout << endl;
}
#endif
//this->socket->socketDescriptor()
/**
 * Body of the read/write thread. Each loop iteration polls the socket and
 *  - reads at most one datagram, pre-processes it, sorts it into the read
 *    queue of its packet type and then drains handleNextPacket();
 *  - writes at most one buffer from the write queue.
 *
 * The loop ends when the socket descriptor is no longer > 0, `connected` is
 * cleared, poll() fails (other than EINTR), the peer hangs up or a read fails.
 */
void ServerConnection::rwExecutor() {
    pollfd pollData = {this->socket->getSocketDescriptor(), POLLRDHUP | POLLIN | POLLOUT, 0};
    buffer::RawBuffer readBuffer(512);
    std::shared_ptr<protocol::ServerPacket> readedPacket;

    while (socket->getSocketDescriptor() > 0 && this->connected) {
        int rfds = poll(&pollData, 1, -1);
        bool select = false; // set once this iteration did any I/O work

        if(rfds == 0) {
            usleep(5 * 1000);
            continue;
        } else if(rfds < 0) {
            // A signal interrupting poll() is not a connection failure.
            if(errno == EINTR)
                continue;
            break;
        }

        if (pollData.revents & POLLRDHUP || pollData.revents & POLLHUP) {
            cerr << "Connection hang up!" << endl;
            return;
        }

        if (pollData.revents & POLLIN) {
            select = true;
            ssize_t readedBytes = -1;
#ifdef NoQt
            readedBytes = socket->read(readBuffer.buffer, readBuffer.length);
#ifdef DEBUG
            cout << "Read bytes (" << readedBytes << ")" << endl;
#endif
#else
            QHostAddress senderAddr;
            quint16 senderPort;
            readedBytes = this->qtSocket->readDatagram(readBuffer.buffer, readBuffer.length, &senderAddr, &senderPort);
#endif
            if (readedBytes < 0) {
                cout << "fatal read error: " << errno << "/" << strerror(errno) << endl;
                return;
            }

            readedPacket = std::make_shared<ServerPacket>(pipes::buffer_view(readBuffer.buffer, readedBytes));
            if(!preProgressPacket(readedPacket)){
                cerr << "Invalid packet preprocess!" << endl;
                readedPacket = nullptr;
            } else {
                const auto typeId = readedPacket->type().type();
                // readQueue has exactly 16 slots (ids 0..15) and the slots of
                // undefined packet types are nullptr; both must be rejected
                // before the queue is dereferenced.
                if(typeId < 0 || typeId >= 16 || !this->readQueue[typeId]) {
                    cerr << "Invalid packet id!" << endl;
                    readedPacket = nullptr;
                } else {
                    //Deserialize packet
                    this->bufferQueueLock.lock();
                    if(!this->readQueue[typeId]->push_pack(readedPacket)){
                        //TODO error handling
                        cout << "pkId: " << be2le16((char*) readedPacket->data().data_ptr()) << " -> " << readedPacket->type().name() << endl;
                    }
                    this->bufferQueueLock.unlock();

                    while(this->handleNextPacket());
                }
            }
        }

        if (pollData.revents & POLLOUT) {
            this->bufferQueueLock.lock();
            if (!this->writeQueue.empty()) {
                select = true;
                buffer::RawBuffer buffer = this->writeQueue.front();
#ifdef NoQt
                auto res = this->socket->write(buffer.buffer, buffer.length);
                if (res == -1) {
                    cout << "having write error: " << errno << "/" << strerror(errno) << " -> " << buffer.length << endl;
                }
                //cout << string() + "Write: " + PacketType::fromid(buffer.type()).name() << endl;
#else
                this->qtSocket->writeDatagram(buffer.buffer, buffer.length, QHostAddress("localhost"), htons(this->remoteAddress.sin_port));
#endif
                /*
                if(!PacketTypeInfo::fromid(buffer.type()).requireAcknowledge()){ //Than we need a ack!
                //Wait for acknowlage
                this->acknowlageQueueLock.lock();
                this->acknowlageQueue.push_back(buffer);
                this->acknowlageQueueLock.unlock();
                }
                */
                // The buffer is dropped from the queue even if the write failed.
                this->writeQueue.pop_front();
            }
            this->bufferQueueLock.unlock();
        }

        // POLLOUT is always requested, so poll() returns even when there is
        // nothing to do; sleep to avoid spinning.
        if (!select) {
            usleep(5 * 10000);
            continue;
        }
    }
    cerr << "rw loop broken!" << endl;
}
/**
 * First processing stage for a freshly received packet: mirrors the header
 * flags into the packet state, decrypts it if it is encrypted and sends an
 * acknowledge for Command / CommandLow packets.
 *
 * @param packet The received packet.
 * @return false if decryption failed (the packet should be dropped), true otherwise.
 */
bool ServerConnection::preProgressPacket(std::shared_ptr<protocol::ServerPacket> packet){
    const bool flaggedUnencrypted = packet->hasFlag(PacketFlag::Unencrypted);
    packet->setEncrypted(!flaggedUnencrypted);
    const bool flaggedCompressed = packet->hasFlag(PacketFlag::Compressed);
    packet->setCompressed(flaggedCompressed);
    const bool flaggedFragmented = packet->hasFlag(PacketFlag::Fragmented);
    packet->setFragmentedEntry(flaggedFragmented);

    if(packet->type() == PacketTypeInfo::Init1){
        // No special handling for Init1 packets is implemented here.
    }

    if (packet->isEncrypted()) {
        string error = "success";
        const bool decrypted = cryptionHandler->progressPacketIn(packet.get(), error, false);
        if (!decrypted) {
            cerr << "Cant decript packet! Message: " << error << endl;
            cerr << "Dropping it!" << endl;
            return false;
        }
    }
#ifdef DEBUG
    cout << "[IN] Packet type -> " << packet->type().name() << " flags " << packet->flags() << " Length: " << packet->data().length() << endl;
#endif
    //Command and CommandLow packets need an acknowledge
    const bool needsAcknowledge = packet->type() == PacketTypeInfo::Command || packet->type() == PacketTypeInfo::CommandLow;
    if(needsAcknowledge){
        const bool lowChannel = packet->type() == PacketTypeInfo::CommandLow;
        sendAcknowledge(packet->packetId(), lowChannel);
    }
    return true;
}
//TODO: right packet receive order!
/**
 * Handler loop: while the connection is up, drains handleNextPacket() until
 * it reports nothing left, then sleeps 10ms before trying again.
 */
void ServerConnection::handleExecutor() {
    for(; this->connected; usleep(10 * 1000)) {
        bool handledOne;
        do {
            handledOne = this->handleNextPacket();
        } while(handledOne);
    }
}
bool ServerConnection::handleNextPacket() {
shared_ptr<protocol::ServerPacket> packet = nullptr;
string error = "success";
if(this->autoHandle){
handleQueueLock.lock();
if(!this->handleQueue.empty()) {
packet = this->handleQueue.front();
this->handleQueue.pop_front();
}
handleQueueLock.unlock();
if(packet){
if(packet->type() == PacketTypeInfo::Ack || packet->type() == PacketTypeInfo::AckLow){
handlePacketAck(packet);
} else if(packet->type() == PacketTypeInfo::Command || packet->type() == PacketTypeInfo::CommandLow){
handlePacketCommand(packet);
} else if(packet->type() == PacketTypeInfo::Ping || packet->type() == PacketTypeInfo::Pong){
handlePacketPing(packet);
} else if(packet->type() == PacketTypeInfo::Voice || packet->type() == PacketTypeInfo::VoiceWhisper){
handlePacketVoice(packet);
}
return true;
}
}
for(int index = 0; index < 16; index++){
if(this->readQueue[index]) {
if(this->readQueue[index]->available() > 0){
auto npacket = this->readQueue[index]->peekNext(0);
packet = make_shared<ServerPacket>(npacket->buffer());
packet->setEncrypted(npacket->isEncrypted());
packet->setCompressed(npacket->isCompressed());
packet->setFragmentedEntry(npacket->isFragmentEntry());
break;
}
}
}
if(!packet) return false;
if(packet->isFragmentEntry()){
packet->setFragmentedEntry(false);
int deltaPacketIndex = 0;
while(this->connected){
std::shared_ptr<protocol::ServerPacket> nextElm = this->readQueue[packet->type().type()]->peekNext(++deltaPacketIndex);
if(!nextElm)
return false;
if(!nextElm) {
cerr << "Dropped fragment?" << endl;
packet = nullptr;
break;
}
packet->append_data({nextElm->data()});
if(nextElm->hasFlag(protocol::PacketFlag::Fragmented)) break; //Tail end
nextElm = nullptr;
}
this->readQueue[packet->type().type()]->pop_packets(deltaPacketIndex);
}
this->readQueue[packet->type().type()]->pop_packets(1);
if(packet->type() != PacketTypeInfo::Init1 && !this->compressionHandler->progressPacketIn(packet.get(), error)){
cerr << "Cant decompress packet! (" << error << ")" << endl;
packet = nullptr;
return true;
}
#if defined(DEBUG_PACKET_LOG)
cout << "Parsed packet " << packet->type().name() << " with id " << packet->packetId() << ". Data:" << endl;
hexDump((void *) packet->data().data_ptr(), packet->data().length(), 16, 8, [](std::string line) {
cout << "[IN] " << line << endl;
});
#endif
handleQueueLock.lock();
this->handleQueue.push_back(packet);
handleQueueLock.unlock();
return true;
}
using namespace std::chrono;
/**
 * Takes the oldest packet out of handleQueue.
 *
 * @param block If false, returns immediately when the queue is empty. If true,
 *              polls the queue every 5ms for at most 5 seconds.
 * @return The packet, or nullptr if none was available (in time).
 */
std::shared_ptr<protocol::ServerPacket> ServerConnection::readNextPacket(bool block) {
    // steady_clock is monotonic; a wall-clock adjustment can neither cut the
    // timeout short nor stretch it.
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);

    while(true) {
        if(std::chrono::steady_clock::now() > deadline) return nullptr;

        this->handleQueueLock.lock();
        if (!this->handleQueue.empty()) {
            std::shared_ptr<protocol::ServerPacket> packet = std::move(this->handleQueue.front());
            this->handleQueue.pop_front();
            this->handleQueueLock.unlock();
            return packet;
        }
        this->handleQueueLock.unlock();

        if (!block) return nullptr;
        usleep(5 * 1000);
    }
}
bool ServerConnection::setupSharedSecret(std::string alpha, std::string beta, std::string sharedKey, std::string &error) {
return this->cryptionHandler->setupSharedSecret(alpha, beta, sharedKey, error);
}
//Packet splitting is not working correctly! (On clientinit it doesn't wait for the second)
/**
 * Queues a packet for sending.
 *
 * Payloads longer than (500 - header length) are cut into fragments which are
 * sent through this function one by one. Everything else gets a packet id
 * (unless already branded), the client id, is passed through the crypto
 * handler and copied into the write queue.
 *
 * NOTE(review): every fragment gets setFragmentedEntry(true), while the
 * receive path in this file treats the next Fragmented flag as the tail -
 * confirm the intended flagging of middle fragments.
 */
void ServerConnection::sendPacket(ts::protocol::ClientPacket &packet) {
    int maxDataLength = 500 - packet.header().length();

    if(packet.data().length() > maxDataLength){
        /*
        Compression before splitting is disabled:
        packet.enableFlag(PacketFlag::Compressed);
        if(!this->compressionHandler->progressPacketOut(&packet, error)){
        cerr << "Compress error!" << endl;
        return;
        }
        packet.enableFlag(PacketFlag::Compressed);
        */
        std::vector<shared_ptr<ClientPacket>> fragments;
        fragments.reserve(8);

        //Split packets
        auto remaining = packet.data();
        const auto fragmentLimit = packet.type().max_length();
        while(remaining.length() > fragmentLimit * 2) {
            fragments.push_back(make_shared<ClientPacket>(packet.type(), remaining.view(0, fragmentLimit)));
            remaining = remaining.range(fragmentLimit);
        }
        if(remaining.length() > fragmentLimit) { //Divide rest by 2
            fragments.push_back(make_shared<ClientPacket>(packet.type(), remaining.view(0, remaining.length() / 2)));
            remaining = remaining.range(remaining.length() / 2);
        }
        fragments.push_back(make_shared<ClientPacket>(packet.type(), remaining));

        for(size_t i = 0; i < fragments.size(); i++) {
            fragments[i]->setFragmentedEntry(true);
            fragments[i]->enableFlag(PacketFlag::NewProtocol);
        }
        for(size_t i = 0; i < fragments.size(); i++)
            this->sendPacket(*fragments[i]);
        return;
    }

    if(!packet.memory_state.id_branded)
        packet.applyPacketId(idManager);
    packet.clientId(this->clientId);

    string error = "success";
    const bool encrypted = this->cryptionHandler->progressPacketOut(&packet, error, false);
    if (!encrypted) {
        cerr << "Invalid crypt -> " << error << endl;
        return;
    }

    // Copy the final wire bytes into a buffer owned by the write queue.
    buffer::RawBuffer outgoing(packet.buffer().length());
    memcpy(&outgoing.buffer[0], packet.buffer().data_ptr(), packet.buffer().length());

    this->bufferQueueLock.lock();
    this->writeQueue.push_back(outgoing);
#if defined(DEBUG_PACKET_LOG)
    cout << "Send packet " << packet.type().name() << " fragmented -> " << packet.isFragmentEntry() << " length " << packet.data().length() << " flags " << packet.flags() << " ID: " << packet.packetId() << endl;
    hexDump(outgoing.buffer, outgoing.length, outgoing.length, outgoing.length);
#endif
    this->bufferQueueLock.unlock();
}
/**
 * Builds the command and sends it as a Command (or CommandLow) packet.
 *
 * @param command The command to serialize via build().
 * @param low     true selects CommandLow; false selects Command and
 *                additionally enables the NewProtocol flag.
 */
void ServerConnection::sendCommand(ts::Command command, bool low) {
    auto serialized = command.build();
    protocol::ClientPacket commandPacket(low ? protocol::PacketTypeInfo::CommandLow : protocol::PacketTypeInfo::Command, pipes::buffer_view{(void*) serialized.data(), serialized.length()});
#ifdef DEBUG
    cout << "[Client -> Server][" << commandPacket.type().name() << "] " << commandPacket.data() << endl;
#endif
    if(low == false) {
        commandPacket.enableFlag(PacketFlag::NewProtocol);
    }
    sendPacket(commandPacket);
}
void ServerConnection::sendAcknowledge(uint16_t packetId, bool low) {
if(breakAck) return;
char buffer[2];
le2be16(packetId, buffer);
protocol::ClientPacket pkt(low ? protocol::PacketTypeInfo::AckLow : protocol::PacketTypeInfo::Ack, pipes::buffer_view(buffer, 2));
#ifdef DEBUG
cout << "Sending packet acknowledge for " << packetId << " (Encrypt: " << encriptAck << ")" << endl;
#endif
if(!encriptAck)
pkt.enableFlag(PacketFlag::Unencrypted);
if(!low) pkt.toggle(protocol::PacketFlag::NewProtocol, true);
sendPacket(pkt);
}