2019-07-17 13:37:18 -04:00
# include <netinet/tcp.h>
# include <log/LogUtils.h>
# include <misc/memtracker.h>
# include "crypt.h"
# define DEFINE_HELPER
# include "LicenseRequest.h"
# include "License.h"
# include <csignal>
using namespace std ;
using namespace std : : chrono ;
using namespace ts ;
using namespace license ;
# define DEBUG_LICENSE_CLIENT
# define CERR(message) LICENSE_FERR(this, CouldNotConnectException, message)
// Creates a request bound to `license` that will connect to `remoteAddr`.
// The request data must carry `info` (asserted in debug builds); the address is
// copied byte-wise into remote_address.
LicenceRequest::LicenceRequest(const std::shared_ptr<LicenseRequestData>& license, const sockaddr_in& remoteAddr) : data(license) {
#ifdef DEBUG_LICENSE_CLIENT
    memtrack::allocated<LicenceRequest>(this);
#endif
    assert(license->info);
    memcpy(&this->remote_address, &remoteAddr, sizeof(remoteAddr));
}
// Destroys the request: closes the connection via abortRequest() (the timeout
// argument presumably has a default in the header), waits for a deferred close
// thread if one was spawned, then frees the owned future.
LicenceRequest::~LicenceRequest() {
#ifdef DEBUG_LICENSE_CLIENT
    memtrack::freed<LicenceRequest>(this);
#endif
    this->abortRequest();

    // closeConnection() hands the shutdown to closeThread when it runs on the
    // dispatch thread; make sure that work has finished before we go away.
    if (auto* pending_close = this->closeThread) {
        pending_close->join();
        delete pending_close;
        this->closeThread = nullptr;
    }

    delete this->currentFuture;
    this->currentFuture = nullptr;
}
// Returns the future for this request's response. The first call allocates the
// future and starts the connection via beginRequest(); later calls return a copy
// of the same future without starting anything.
threads::Future<std::shared_ptr<LicenseRequestResponse>> LicenceRequest::requestInfo() {
    using ResponseFuture = threads::Future<std::shared_ptr<LicenseRequestResponse>>;
    {
        lock_guard guard(this->lock);
        if (this->currentFuture) {
            return *this->currentFuture;
        }
        this->currentFuture = new ResponseFuture();
    }

    // Started outside the scope above; beginRequest() takes the lock itself.
    this->beginRequest();
    return *this->currentFuture;
}
//Basic IO
void LicenceRequest : : handleEventWrite ( int fd , short event , void * ptrClient ) {
auto * client = static_cast < LicenceRequest * > ( ptrClient ) ;
buffer : : RawBuffer * buffer = nullptr ;
{
lock_guard lock ( client - > lock ) ;
if ( ( event & EV_TIMEOUT ) > 0 ) { //Connect timeout
LICENSE_FERR ( client , ConnectionException , " Connect timeout " ) ;
return ;
}
if ( client - > state = = protocol : : CONNECTING ) {
client - > handleConnected ( ) ;
}
if ( client - > state = = protocol : : UNCONNECTED | | ! client - > event_write )
return ;
buffer = TAILQ_FIRST ( & client - > writeQueue ) ;
if ( ! buffer ) return ;
auto writtenBytes = send ( fd , & buffer - > buffer [ buffer - > index ] , buffer - > length - buffer - > index , MSG_NOSIGNAL | MSG_DONTWAIT ) ;
buffer - > index + = writtenBytes ;
if ( buffer - > index > = buffer - > length ) {
TAILQ_REMOVE ( & client - > writeQueue , buffer , tail ) ;
delete buffer ;
}
if ( ! TAILQ_EMPTY ( & client - > writeQueue ) )
event_add ( client - > event_write , nullptr ) ;
}
}
void LicenceRequest : : sendPacket ( const protocol : : packet & packet ) {
if ( this - > state = = protocol : : UNCONNECTED | | this - > state = = protocol : : DISCONNECTING ) {
if ( this - > verbose )
logError ( " Tried to send a packet to an unconnected remote! " ) ;
return ;
}
packet . prepare ( ) ;
auto buffer = new buffer : : RawBuffer ( packet . data . length ( ) + sizeof ( packet . header ) ) ;
memcpy ( buffer - > buffer , & packet . header , sizeof ( packet . header ) ) ;
memcpy ( & buffer - > buffer [ sizeof ( packet . header ) ] , packet . data . data ( ) , packet . data . length ( ) ) ;
if ( ! this - > cryptKey . empty ( ) )
xorBuffer ( & buffer - > buffer [ sizeof ( packet . header ) ] , packet . data . length ( ) , this - > cryptKey . data ( ) , this - > cryptKey . length ( ) ) ;
{
lock_guard lock ( this - > lock ) ;
TAILQ_INSERT_TAIL ( & this - > writeQueue , buffer , tail ) ;
if ( this - > event_write )
event_add ( this - > event_write , nullptr ) ;
}
}
void LicenceRequest : : handleEventRead ( int fd , short , void * ptrClient ) {
auto * client = static_cast < LicenceRequest * > ( ptrClient ) ;
auto buffer = std : : unique_ptr < void , decltype ( free ) * > { malloc ( 1024 ) , free } ;
sockaddr_in remoteAddr { } ;
socklen_t remoteAddrSize = sizeof ( remoteAddr ) ;
auto read = recvfrom ( fd , buffer . get ( ) , 1024 , MSG_NOSIGNAL | MSG_DONTWAIT , reinterpret_cast < sockaddr * > ( & remoteAddr ) , & remoteAddrSize ) ;
if ( read < 0 ) {
if ( errno = = EWOULDBLOCK ) return ;
if ( client - > event_read )
event_del_noblock ( client - > event_read ) ;
LICENSE_FERR ( client , ConnectionException , " Invalid read: " + string ( strerror ( errno ) ) + " / " + to_string ( errno ) ) ;
return ;
} else if ( read = = 0 ) {
if ( client - > event_read )
event_del_noblock ( client - > event_read ) ;
LICENSE_FERR ( client , ConnectionException , " IO error ( " + to_string ( errno ) + " ): " + string ( strerror ( errno ) ) ) ;
return ;
}
client - > handleMessage ( string ( ( char * ) buffer . get ( ) , read ) ) ;
}
// Option values passed by address to setsockopt() in beginRequest(); never modified.
static constexpr int enabled = 1;
static constexpr int disabled = 0;
void LicenceRequest : : beginRequest ( ) {
lock_guard lock ( this - > lock ) ;
TAILQ_INIT ( & this - > writeQueue ) ;
this - > file_descriptor = socket ( AF_INET , SOCK_STREAM | SOCK_NONBLOCK , 0 ) ;
if ( this - > file_descriptor < 0 ) CERR ( " Socket setup failed " ) ;
signal ( SIGPIPE , SIG_IGN ) ;
auto state = : : connect ( this - > file_descriptor , reinterpret_cast < const sockaddr * > ( & this - > remote_address ) , sizeof ( this - > remote_address ) ) ;
if ( state < 0 & & errno ! = EINPROGRESS ) CERR ( " connect() failed ( " + string ( strerror ( errno ) ) + " ) " ) ;
if ( setsockopt ( this - > file_descriptor , SOL_SOCKET , SO_REUSEADDR , & enabled , sizeof ( enabled ) ) < 0 ) CERR ( " could not set reuse addr " ) ;
if ( setsockopt ( this - > file_descriptor , IPPROTO_TCP , TCP_CORK , & disabled , sizeof ( disabled ) ) < 0 ) CERR ( " could not set no push " ) ;
if ( fcntl ( this - > file_descriptor , F_SETFD , fcntl ( this - > file_descriptor , F_GETFL , 0 ) | FD_CLOEXEC | O_NONBLOCK ) < 0 ) CERR ( " Failed to set FD_CLOEXEC and O_NONBLOCK " ) ;
this - > event_base = event_base_new ( ) ;
this - > event_read = event_new ( this - > event_base , this - > file_descriptor , EV_READ | EV_PERSIST , LicenceRequest : : handleEventRead , this ) ;
this - > event_write = event_new ( this - > event_base , this - > file_descriptor , EV_WRITE , LicenceRequest : : handleEventWrite , this ) ;
this - > state = protocol : : CONNECTING ; //First set connected, then we could enable the event loop
event_dispatch = std : : thread ( [ & ] ( ) {
signal ( SIGPIPE , SIG_IGN ) ;
{ /* now we could start listening */
lock_guard _lock ( this - > lock ) ;
if ( ! this - > event_read | | ! this - > event_write ) return ;
event_add ( this - > event_read , nullptr ) ;
timeval connect_timeout { 5 , 0 } ;
event_add ( this - > event_write , & connect_timeout ) ;
}
event_base_dispatch ( this - > event_base ) ;
} ) ;
}
void LicenceRequest : : handleConnected ( ) {
this - > state = protocol : : HANDSCAKE ;
uint8_t handshakeBuffer [ 4 ] ;
handshakeBuffer [ 0 ] = 0xC0 ;
handshakeBuffer [ 1 ] = 0xFF ;
handshakeBuffer [ 2 ] = 0xEE ;
handshakeBuffer [ 3 ] = LICENSE_PROT_VERSION ;
this - > sendPacket ( protocol : : packet { protocol : : PACKET_CLIENT_HANDSHAKE , string ( ( const char * ) handshakeBuffer , 4 ) } ) ; //Initialise packet
}
void LicenceRequest : : handleMessage ( const std : : string & message ) {
2019-11-22 14:51:00 -05:00
this - > buffer + = message ;
if ( this - > buffer . length ( ) < sizeof ( protocol : : packet : : header ) )
return ;
2019-07-17 13:37:18 -04:00
protocol : : packet packet { protocol : : PACKET_DISCONNECT , " " } ;
2019-11-22 14:51:00 -05:00
memcpy ( & packet . header , this - > buffer . data ( ) , sizeof ( protocol : : packet : : header ) ) ;
if ( packet . header . length < = this - > buffer . length ( ) - sizeof ( protocol : : packet : : header ) ) {
packet . data = this - > buffer . substr ( sizeof ( protocol : : packet : : header ) , packet . header . length ) ;
this - > buffer = this - > buffer . substr ( sizeof ( protocol : : packet : : header ) + packet . header . length ) ;
} else {
return ;
}
2019-07-17 13:37:18 -04:00
if ( ! this - > cryptKey . empty ( ) ) {
xorBuffer ( ( char * ) packet . data . data ( ) , packet . data . length ( ) , this - > cryptKey . data ( ) , this - > cryptKey . length ( ) ) ;
}
if ( packet . header . packetId = = protocol : : PACKET_SERVER_HANDSHAKE ) {
this - > handlePacketHandshake ( packet . data ) ;
} else if ( packet . header . packetId = = protocol : : PACKET_DISCONNECT ) {
this - > handlePacketDisconnect ( packet . data ) ;
} else if ( packet . header . packetId = = protocol : : PACKET_SERVER_VALIDATION_RESPONSE ) {
this - > handlePacketLicenseInfo ( packet . data ) ;
} else if ( packet . header . packetId = = protocol : : PACKET_SERVER_PROPERTY_ADJUSTMENT ) {
this - > handlePacketInfoAdjustment ( packet . data ) ;
} else
LICENSE_FERR ( this , ConnectionException , " Invalid packet id ( " + to_string ( packet . header . packetId ) + " ) " ) ;
2019-11-22 14:51:00 -05:00
if ( ! this - > buffer . empty ( ) & & this - > state ! = protocol : : DISCONNECTING & & this - > state ! = protocol : : UNCONNECTED )
this - > handleMessage ( " " ) ;
2019-07-17 13:37:18 -04:00
}
void LicenceRequest : : disconnect ( const std : : string & message ) {
if ( this - > state ! = protocol : : UNCONNECTED & & this - > state ! = protocol : : DISCONNECTING )
this - > sendPacket ( { protocol : : PACKET_DISCONNECT , message } ) ;
this - > closeConnection ( ) ;
//TODO flush?
}
// Tears the connection down. It frees the read/write events, shuts down and
// closes the socket, deletes any unsent buffers in writeQueue, stops and joins
// the event_dispatch thread, and frees the event base.
// Returns immediately if the state is already UNCONNECTED.
// When called on the event_dispatch thread itself (e.g. from a libevent
// callback), the state is set to DISCONNECTING and the real work is handed to a
// new closeThread, because this function joins event_dispatch further down.
// That thread is joined by the destructor.
void LicenceRequest::closeConnection() {
    event* event_read, *event_write;
    {
        lock_guard lock(this->lock);
        if (this->state == protocol::UNCONNECTED) return; // already closed
        if (this->event_dispatch.get_id() == this_thread::get_id()) { //We could not close in the same thread as we read/write (we're joining it later)
            if (this->state == protocol::DISCONNECTING) return; // a close thread has already been scheduled
            this->state = protocol::DISCONNECTING;
            this->closeThread = new threads::Thread(THREAD_SAVE_OPERATIONS, [&]() { this->closeConnection(); });
#ifdef DEBUG_LICENSE_CLIENT
            if (this->verbose) {
                debugMessage(" Running close in a new thread ");
                this->closeThread->name(" License request close ");
            }
#endif
            return;
        }
        // Claim the events under the lock; the write callback checks event_write
        // for null (under the same lock) before using it.
        this->state = protocol::UNCONNECTED;
        event_read = this->event_read;
        event_write = this->event_write;
        this->event_write = nullptr;
        this->event_read = nullptr;
    }
    if (event_read) {
        event_del_block(event_read);
        event_free(event_read);
    }
    if (event_write) {
        event_del_block(event_write);
        event_free(event_write);
    }
    /* close before base shutdown (else epoll hangup) */
    if (this->file_descriptor > 0) {
        shutdown(this->file_descriptor, SHUT_RDWR);
        close(this->file_descriptor);
    }
    this->file_descriptor = 0;
    {
        // Discard anything that was queued but never written.
        lock_guard lock(this->lock);
        ts::buffer::RawBuffer* buffer;
        while ((buffer = TAILQ_FIRST(&this->writeQueue))) {
            TAILQ_REMOVE(&this->writeQueue, buffer, tail);
            delete buffer;
        }
    }
    {
        // Ask the loop to exit: one request with a 1 s delay, one with no delay.
        if (this->event_base) {
            timeval seconds{1, 0};
            event_base_loopexit(this->event_base, &seconds);
            event_base_loopexit(this->event_base, nullptr);
        }
        if (this->event_dispatch.joinable()) {
            this->event_dispatch.join();
        }
        // Only freed after the dispatch thread has returned from the loop.
        if (this->event_base) {
            event_base_free(this->event_base);
            this->event_base = nullptr;
        }
    }
#ifdef DEBUG_LICENSE_CLIENT
    if (this->verbose)
        debugMessage(" Executing close done ");
#endif
}
void LicenceRequest : : abortRequest ( const std : : chrono : : system_clock : : time_point & timeout ) {
this - > closeConnection ( ) ;
}