WebDNS/server/src/server.cpp

321 lines
8.4 KiB
C++

#include "./server.h"
#include "./handler.h"
#include "./net.h"
#include <event.h>
#include <functional>
#include <cstring>
#include <iostream>
#include <cassert>
#include <zconf.h>
#include <fcntl.h>
using namespace ts::dns;
/**
 * Destroying the server shuts it down first by delegating to stop().
 */
DNSServer::~DNSServer() {
    stop();
}
bool DNSServer::start(const std::vector<sockaddr_storage> &bindings, std::string &error) {
size_t successful_binds{0};
std::lock_guard lock{this->bind_lock};
if(this->started) {
error = "server already bound";
return false;
}
this->event_base = event_base_new();
if(!this->event_base) {
error = "failed to spawn event base";
goto error_exit;
}
this->event_base_ticker = evtimer_new(this->event_base, [](auto, auto, void* server){
static_cast<DNSServer*>(server)->event_cb_timer();
}, this);
if(!this->event_base_ticker) {
error = "failed to spawn heartbeat event";
goto error_exit;
}
this->_bindings.resize(bindings.size());
for(size_t index = 0; index < bindings.size(); index++) {
auto binding = this->_bindings[index] = std::make_shared<DNSServerBinding>();
binding->self = binding;
binding->server = this;
memcpy(&binding->address, &bindings[index], sizeof(sockaddr_storage));
if(!this->bind(*binding, error))
binding->error = error;
else
successful_binds++;
}
if(!successful_binds) {
error = "failed to bind to any address";
goto error_exit;
}
{
timeval timeout{10, 0};
event_add(this->event_base_ticker, &timeout);
}
this->event_base_executor = std::thread(std::bind(&DNSServer::event_executor, this));
this->started = true;
return true;
error_exit:
if(this->event_base_ticker) {
event_del_block(this->event_base_ticker);
event_free(this->event_base_ticker);
this->event_base_ticker = nullptr;
}
if(this->event_base) {
event_base_free(this->event_base);
this->event_base = nullptr;
}
return false;
}
void DNSServer::stop() {
std::unique_lock lock{this->bind_lock};
if(!this->started)
return;
this->started = false;
assert(this->event_base); //Must be set else the started flag was invalid
assert(this->event_base_ticker); //Must be set else the started flag was invalid
for(auto& binding : this->bindings())
this->unbind(*binding);
event_base_loopexit(this->event_base, nullptr);
{
timeval timeout{0, 0};
event_add(this->event_base_ticker, &timeout);
}
lock.unlock();
this->event_base_executor.join();
lock.lock();
event_free(this->event_base_ticker);
this->event_base_ticker = nullptr;
}
/**
 * Creates and configures the UDP socket for `binding` and registers its libevent events.
 *
 * The socket is created for the family of `binding.address`, gets packet info enabled
 * (plus IPV6_V6ONLY for IPv6), is bound to the address and switched to non-blocking mode.
 * Afterwards the persistent read event and the (one shot) write event are created and the
 * read event is activated.
 *
 * On failure every resource acquired so far (events and socket) is released again and
 * `binding.socket` is reset to 0.
 *
 * @param binding the binding to set up; `address` must already be filled in
 * @param error   receives a description of the failure when `false` is returned
 * @return `true` on success, `false` otherwise
 */
bool DNSServer::bind(DNSServerBinding &binding, std::string &error) {
    binding.socket = socket(binding.address.ss_family, SOCK_DGRAM, 0);
    if(binding.socket < 2) {
        /* NOTE(review): descriptors 0 and 1 are rejected as well; unbind() uses 0 as the "no socket" value. */
        if(binding.socket >= 0)
            ::close(binding.socket);
        binding.socket = 0;
        error = "failed to create socket";
        return false;
    }

    /* Formats the current errno as "<code>/<message>". */
    const auto errno_text = [] {
        const int code = errno;
        return std::to_string(code) + "/" + strerror(code);
    };

    /* Shared failure path: stores the message and releases the events and the socket. */
    const auto fail = [&binding, &error](const std::string& message) {
        error = message;

        if(binding.read_event) {
            event_del_noblock(binding.read_event);
            event_free(binding.read_event);
            binding.read_event = nullptr;
        }

        if(binding.write_event) {
            event_del_noblock(binding.write_event);
            event_free(binding.write_event);
            binding.write_event = nullptr;
        }

        ::close(binding.socket);
        binding.socket = 0;
        return false;
    };

    int enable{1};
    socklen_t address_length = sizeof(binding.address);
    if(binding.address.ss_family == AF_INET6) {
        address_length = sizeof(sockaddr_in6);

        if(setsockopt(binding.socket, IPPROTO_IPV6, IPV6_RECVPKTINFO, &enable, sizeof(enable)) < 0)
            return fail("failed to enable packet info (v6) (" + errno_text() + ")");

        if(setsockopt(binding.socket, IPPROTO_IPV6, IPV6_V6ONLY, &enable, sizeof(enable)) < 0)
            return fail("failed to enable ip v6 only (" + errno_text() + ")");
    } else {
        if(binding.address.ss_family == AF_INET)
            address_length = sizeof(sockaddr_in);

        if(setsockopt(binding.socket, IPPROTO_IP, IP_PKTINFO, &enable, sizeof(enable)) < 0)
            return fail("failed to enable packet info (" + errno_text() + ")");
    }

    /* Pass the length of the concrete address type; some platforms reject sizeof(sockaddr_storage). */
    if(::bind(binding.socket, reinterpret_cast<const sockaddr*>(&binding.address), address_length) < 0)
        return fail("failed to bind: " + errno_text());

    {
        const int flags = fcntl(binding.socket, F_GETFL, 0);
        if(flags < 0 || fcntl(binding.socket, F_SETFL, flags | O_NONBLOCK) < 0)
            return fail("failed to enable noblock");
    }

#ifdef WIN32
    {
        u_long enabled = 1; /* a non zero value enables the non-blocking mode */
        if(ioctlsocket(binding.socket, FIONBIO, &enabled) != NO_ERROR)
            return fail("failed to enable nonblock");
    }
#endif

    binding.read_event = event_new(this->event_base, binding.socket, EV_READ | EV_PERSIST, &DNSServer::event_cb_read, &binding);
    if(!binding.read_event)
        return fail("failed to create read event");

    binding.write_event = event_new(this->event_base, binding.socket, EV_WRITE, &DNSServer::event_cb_write, &binding);
    if(!binding.write_event)
        return fail("failed to create write event");

    if(event_add(binding.read_event, nullptr) != 0)
        return fail("failed to register read event");

    return true;
}
/**
 * Releases everything which belongs to `binding`: both events, the socket, all queued
 * outgoing buffers and finally the self reference of the binding.
 *
 * Calling this on a binding which has already been unbound (or never was bound) is harmless.
 *
 * @param binding the binding to tear down
 */
void DNSServer::unbind(DNSServerBinding &binding) {
    std::unique_lock lock{binding.io_lock};

    /* Detach the events while holding the lock, so send() stops scheduling new writes. */
    auto read_event = binding.read_event;
    auto write_event = binding.write_event;
    binding.read_event = nullptr;
    binding.write_event = nullptr;

    /*
     * event_del_block() waits for a callback which is currently running on the event thread.
     * The write callback (and send(), which may be invoked from within the read callback)
     * acquires io_lock, so the lock must not be held while waiting, else both sides would
     * wait for each other forever.
     */
    lock.unlock();
    if(read_event) {
        event_del_block(read_event);
        event_free(read_event);
    }

    if(write_event) {
        event_del_block(write_event);
        event_free(write_event);
    }
    lock.lock();

    if(binding.socket > 0) {
        ::shutdown(binding.socket, SHUT_RDWR);
        ::close(binding.socket);
    }
    binding.socket = 0;

    /* drop all messages which have not been sent yet */
    auto head = binding.write_buffer_head;
    while(head) {
        auto next = head->next;
        free(head);
        head = next;
    }
    binding.write_buffer_head = nullptr;
    binding.write_buffer_tail = nullptr;

    binding.self = nullptr;
}
void DNSServerBinding::send(const sockaddr_storage &address, const void *payload, size_t size) {
std::lock_guard lock{this->io_lock};
if(!this->write_event)
return;
auto buffer = (char*) malloc(sizeof(BindingBuffer) + size);
auto bbuffer = (BindingBuffer*) buffer;
bbuffer->next = nullptr;
bbuffer->size = size;
memcpy(&bbuffer->target, &address, sizeof(address));
memcpy(buffer + sizeof(BindingBuffer), payload, size);
if(this->write_buffer_tail) {
assert(!this->write_buffer_tail->next);
this->write_buffer_tail->next = bbuffer;
this->write_buffer_tail = bbuffer;
} else {
assert(!this->write_buffer_head);
this->write_buffer_head = bbuffer;
this->write_buffer_tail = bbuffer;
}
event_add(this->write_event, nullptr);
}
/**
 * Body of the event thread spawned by start().
 * Runs the libevent loop over and over again until the server is no longer flagged as started.
 */
void DNSServer::event_executor() {
    for(;;) {
        event_base_loop(this->event_base, EVLOOP_NO_EXIT_ON_EMPTY);

        /* the loop returned; only leave if the server has actually been stopped */
        std::lock_guard guard{this->bind_lock};
        if(!this->started)
            break;
    }
}
/**
 * Heartbeat callback: re-arms the ticker for another ten seconds while the server is started.
 */
void DNSServer::event_cb_timer() {
    std::lock_guard guard{this->bind_lock};
    if(!this->started)
        return;

    timeval interval{10, 0};
    event_add(this->event_base_ticker, &interval);
}
/**
 * libevent read callback of a binding.
 *
 * Drains pending datagrams from `fd` and passes each of them to the handler of the server.
 * If no handler is set the datagram is dropped with a log message.
 *
 * @param fd          the socket which became readable
 * @param ptr_binding the DNSServerBinding the event has been registered for
 */
void DNSServer::event_cb_read(evutil_socket_t fd, short, void *ptr_binding) {
    auto binding = static_cast<DNSServerBinding*>(ptr_binding);
    auto binding_ref = binding->self; /* keeps the binding alive while we're working with it */
    if(!binding_ref) return;

    constexpr size_t kBufferLength{1600}; /* IPv6 MTU is ~1.5k */
    /*
     * Upper bound of datagrams handled per callback, so a flood on one socket cannot starve
     * all other events. The read event is persistent and not edge triggered, hence it fires
     * again if data is still pending.
     */
    constexpr size_t kMaxReadsPerEvent{128};

    char buffer[kBufferLength];
    sockaddr_storage source_address{};

    for(size_t read_count = 0; read_count < kMaxReadsPerEvent; read_count++) {
        socklen_t source_address_length = sizeof(source_address);
        const ssize_t read_length = recvfrom(fd, buffer, kBufferLength, MSG_DONTWAIT, reinterpret_cast<sockaddr*>(&source_address), &source_address_length);
        if(read_length < 0) {
            if(errno == EAGAIN || errno == EWOULDBLOCK)
                break; /* nothing more to read */
            if(errno == EINTR)
                continue;

            std::cerr << "Failed to receive data: " << errno << "/" << strerror(errno) << "\n";
            break; /* this should never happen! */
        }

        /* A zero length datagram is no error (errno would be stale), but there is nothing to handle either. */
        if(read_length == 0)
            continue;

        auto handler = binding->server->handler;
        if(handler)
            handler->handle_message(binding_ref, source_address, buffer, read_length);
        else
            std::cerr << "Dropping " << read_length << " bytes from " << net::to_string(source_address, true) << " because we've no handler\n";
    }
}
/**
 * libevent write callback of a binding.
 *
 * Pops the queued outgoing buffers one by one and sends them via `fd`.
 * If the socket cannot take more data right now, the current buffer is put back to the
 * front of the queue and the write event gets scheduled again. Every other send failure
 * is logged and the message is dropped.
 *
 * @param fd          the socket which became writeable
 * @param ptr_binding the DNSServerBinding the event has been registered for
 */
void DNSServer::event_cb_write(evutil_socket_t fd, short, void *ptr_binding) {
    auto binding = static_cast<DNSServerBinding*>(ptr_binding);
    auto binding_ref = binding->self; /* keeps the binding alive while we're working with it */
    if(!binding_ref) return;

    while(true) {
        DNSServerBinding::BindingBuffer* buffer{nullptr};
        {
            /* only hold the lock for the queue manipulation, not for the syscall */
            std::lock_guard lock{binding->io_lock};
            buffer = binding->write_buffer_head;
            if(!buffer)
                break;

            binding->write_buffer_head = buffer->next;
            if(!binding->write_buffer_head)
                binding->write_buffer_tail = nullptr;
        }

        /* the payload directly follows the buffer header */
        const auto payload = reinterpret_cast<const char*>(buffer) + sizeof(DNSServerBinding::BindingBuffer);
        const ssize_t code = sendto(fd, payload, buffer->size, 0, reinterpret_cast<const sockaddr*>(&buffer->target), sizeof(buffer->target));
        if(code < 0) {
            const int error_code = errno;

            if(error_code == EAGAIN || error_code == EWOULDBLOCK) {
                /* The socket is full: keep the message and retry as soon as the socket is writeable again. */
                bool requeued{false};
                {
                    std::lock_guard lock{binding->io_lock};
                    if(binding->write_event) {
                        buffer->next = binding->write_buffer_head;
                        binding->write_buffer_head = buffer;
                        if(!binding->write_buffer_tail)
                            binding->write_buffer_tail = buffer;

                        event_add(binding->write_event, nullptr);
                        requeued = true;
                    }
                }

                if(requeued)
                    return;
                /* no write event any more (binding is being unbound): fall through and drop the message */
            }

            std::cerr << "Failed to send DNS response to " << net::to_string(buffer->target, true) << ": " << error_code << "/" << strerror(error_code) << "\n";
        }

        free(buffer);
    }
}