/** @file
 * \brief Implementation file for the Socket class
 * \author Imanol Barba Sabariego
 * \date 13/06/2013
 *
 * This file implements the methods of the Socket class declared in Socket.h
 */

#include "Socket.h"
#include "SocketException.h"

// POSIX networking API used directly by this file.
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>

// C++ standard library.
#include <cstring>
#include <fstream>
#include <sstream>
#include <string>

using namespace std;

namespace
{
    // SendUnencrypted() writes the payload size as a decimal `int`, which never
    // needs more than 10 digits. Anything longer is treated as a protocol error
    // so a peer cannot make the length header grow without bound.
    const size_t kMaxLengthDigits = 10;
}

// Starts with no descriptor and, when encryption is enabled (RSALENGTH > 0),
// zeroed AES key/IV buffers for both directions.
Socket::Socket()
{
    sock = -1;
    if(RSALENGTH > 0)
    {
        memset(myKey, 0x00, AESLENGTH);
        memset(myIV, 0x00, AES::BLOCKSIZE);
        memset(theirKey, 0x00, AESLENGTH);
        memset(theirIV, 0x00, AES::BLOCKSIZE);
    }
}

// Returns the underlying socket descriptor.
int Socket::getSock()
{
    return sock;
}

// Creates a TCP/IPv4 socket and enables SO_KEEPALIVE and TCP_NODELAY on it.
// Throws SocketException if the socket cannot be created.
void Socket::Create()
{
    int optval = 1;
    // NOTE(review): descriptor 0 is rejected here and is also what Close()
    // stores to mean "closed"; that convention is kept as-is.
    if((sock = socket(AF_INET, SOCK_STREAM, 0)) <= 0)
    {
        throw SocketException ( "TCP: Could not create socket" );
    }
    // Best effort: failures to set these options are not reported.
    setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof optval);
    setsockopt(sock, SOL_TCP, TCP_NODELAY, &optval, sizeof optval);
}

// Binds the socket to the dotted-quad IPv4 `address` and `port`.
// Throws SocketException if bind() fails.
void Socket::Bind(string address, int port)
{
    sockAddr.sin_family = AF_INET;
    sockAddr.sin_port = htons(port);
    sockAddr.sin_addr.s_addr = inet_addr(address.c_str());

    if(bind(sock, reinterpret_cast<struct sockaddr *>(&sockAddr), sizeof(struct sockaddr)) != 0)
    {
        stringstream sstream;
        sstream << "TCP: Could not bind to address " << address << " on port " << port;
        throw SocketException ( sstream.str() );
    }
}

// Puts the socket in listening state with the given backlog.
// Throws SocketException if listen() fails.
void Socket::Listen(int backlog)
{
    if(listen(sock, backlog) != 0)
    {
        throw SocketException ( "TCP: Could not listen to socket" );
    }
}

// Accepts one incoming connection and stores its descriptor and peer address
// in `clientSock`. Throws SocketException if accept() fails.
void Socket::Accept(Socket &clientSock)
{
    // accept() expects a socklen_t in/out argument; use the real type instead
    // of casting the address of an int.
    socklen_t size = sizeof(struct sockaddr);
    clientSock.sock = accept(sock, reinterpret_cast<struct sockaddr *>(&clientSock.sockAddr), &size);
    if(clientSock.sock == -1)
    {
        throw SocketException ( "TCP: Could not accept incoming connection" );
    }
}

// Resolves `hostname` and connects the socket to its first IPv4 address on
// `port`. Throws SocketException when resolution or connect() fails.
void Socket::Connect(string hostname, int port)
{
    struct hostent *hostPtr = gethostbyname(hostname.c_str());
    if(hostPtr == NULL)
    {
        throw SocketException (string("Could not resolve hostname ").append(hostname));
    }
    // The first entry is copied into an in_addr below, so it must exist and
    // must really be an IPv4 address.
    if(hostPtr->h_addrtype != AF_INET || hostPtr->h_addr_list == NULL || hostPtr->h_addr_list[0] == NULL)
    {
        throw SocketException ( "Invalid address" );
    }

    struct sockaddr_in newSockAddr;
    memset(&newSockAddr, 0, sizeof newSockAddr);
    newSockAddr.sin_family = AF_INET;
    newSockAddr.sin_port = htons(port);
    // The resolver already returns the address in network byte order, so it is
    // copied directly instead of being formatted as text and parsed back.
    memcpy(&newSockAddr.sin_addr, hostPtr->h_addr_list[0], sizeof newSockAddr.sin_addr);

    if(connect(sock, reinterpret_cast<struct sockaddr *>(&newSockAddr), sizeof(struct sockaddr)) != 0)
    {
        stringstream sstream;
        sstream << "Could not connect to " << hostname << " on port " << port;
        throw SocketException ( sstream.str());
    }
}

// Reads exactly `length` bytes into `buff`, looping over partial reads.
// Returns the number of bytes read. Throws SocketException when recv()
// reports an error or returns 0.
int Socket::Receive(char *buff, int length)
{
    int total = 0;
    while(total < length)
    {
        const int bytes = recv(sock, buff + total, length - total, 0);
        if(bytes <= 0)
        {
            throw SocketException ( "TCP: Could not read from socket." );
        }
        total += bytes;
    }
    return total;
}

// Writes exactly `length` bytes from `buff`, looping over partial writes.
// Returns the number of bytes written. Throws SocketException on send() error.
int Socket::Send(const char *buff, int length)
{
    int total = 0;
    while(total < length)
    {
        const int bytes = send(sock, buff + total, length - total, 0);
        if(bytes < 0)
        {
            throw SocketException ( "TCP: Could not write to socket." );
        }
        total += bytes;
    }
    return total;
}

// Closes the descriptor. Throws SocketException if the socket is not open.
void Socket::Close()
{
    if(sock > 0)
    {
        close(sock);
        sock = 0;
    }
    else
    {
        throw SocketException ( "TCP: Could not close socket." );
    }
}

// Sends `text` in clear as: decimal payload size, a NUL terminator, then the
// payload bytes. An empty payload is sent as "0" followed by NUL only.
void Socket::SendUnencrypted(const string& text)
{
    const int length = static_cast<int>(text.length());

    stringstream sstream;
    sstream << length;
    const string header = sstream.str();

    // +1 so the terminating NUL that delimits the header goes on the wire too.
    Send(header.c_str(), static_cast<int>(header.length()) + 1);
    if(length > 0)
    {
        Send(text.c_str(), length);
    }
}

// Sends `text`. With encryption disabled (RSALENGTH <= 0) this is
// SendUnencrypted(); otherwise the AES-encrypted decimal size of the
// ciphertext is sent first, followed by the ciphertext itself.
const Socket& Socket::operator << ( const std::string& text)
{
    if(RSALENGTH <= 0)
    {
        SendUnencrypted(text);
        return *this;
    }

    const string cipher = encryptAES(text);
    const int size = static_cast<int>(cipher.length());

    stringstream sstream;
    sstream << size;
    const string header = encryptAES(sstream.str());

    Send(header.c_str(), static_cast<int>(header.length()));
    Send(cipher.c_str(), size);
    return *this;
}

// Receives a message produced by SendUnencrypted() into `text`.
// Throws SocketException on read failure or on a malformed length header.
void Socket::ReceiveUnencrypted(string& text)
{
    text.clear();

    // Read the NUL-terminated decimal length header one byte at a time.
    string header;
    while(true)
    {
        char c;
        Receive(&c, 1);    // throws if the byte cannot be read
        if(c == '\0')
        {
            break;
        }
        if(header.size() >= kMaxLengthDigits)
        {
            throw SocketException ( "TCP: Invalid message length received." );
        }
        header += c;
    }

    // The length comes from the peer: refuse anything that is not a
    // non-negative integer before using it as a buffer size.
    int length = 0;
    stringstream sstream(header);
    if(!(sstream >> length) || length < 0)
    {
        throw SocketException ( "TCP: Invalid message length received." );
    }
    if(length == 0)
    {
        return;
    }

    // A std::string owns the buffer, so nothing leaks if Receive() throws.
    string payload(static_cast<size_t>(length), '\0');
    Receive(&payload[0], length);
    text.swap(payload);
}

// Receives a message into `text`. With encryption disabled (RSALENGTH <= 0)
// this is ReceiveUnencrypted(); otherwise one AES block holding the encrypted
// ciphertext size is read and decrypted, then that many bytes are read and
// decrypted into `text`.
const Socket& Socket::operator >> ( std::string& text )
{
    if(RSALENGTH <= 0)
    {
        ReceiveUnencrypted(text);
        return *this;
    }

    // The encrypted length header is read as exactly one AES block.
    string cipher(static_cast<size_t>(AES::BLOCKSIZE), '\0');
    Receive(&cipher[0], AES::BLOCKSIZE);
    const string header = decryptAES(cipher);

    // Peer-controlled value: validate before sizing a buffer with it.
    int length = 0;
    stringstream sstream(header);
    if(!(sstream >> length) || length < 0)
    {
        throw SocketException ( "TCP: Invalid message length received." );
    }

    // String-owned buffer: released automatically if Receive()/decryptAES() throw.
    cipher.assign(static_cast<size_t>(length), '\0');
    if(length > 0)
    {
        Receive(&cipher[0], length);
    }
    text = decryptAES(cipher);
    return *this;
}

// Generates a fresh RSA private key of RSALENGTH bits and derives the matching
// public key from its modulus and public exponent.
void Socket::generateRSAKeys()
{
    if(RSALENGTH > 0)
    {
        privateKey.Initialize(rng, RSALENGTH);
        RSAFunction publicKey(privateKey);
        myPublicKey.Initialize(publicKey.GetModulus(), publicKey.GetPublicExponent());
    }
}

// Fills this side's AES key and IV with random bytes from `rng`.
void Socket::generateAESKeys()
{
    if(RSALENGTH > 0)
    {
        rng.GenerateBlock(myKey, AESLENGTH);
        rng.GenerateBlock(myIV, AES::BLOCKSIZE);
    }
}

// Sends this side's RSA public key in clear as "<modulus> <exponent>".
// Returns false if the send fails, true otherwise (including when encryption
// is disabled and nothing is sent).
bool Socket::sendPublicKey()
{
    if(RSALENGTH > 0)
    {
        stringstream sstream;
        sstream << myPublicKey.GetModulus();
        sstream << " ";
        sstream << myPublicKey.GetPublicExponent();
        const string data = sstream.str();
        try
        {
            SendUnencrypted(data);
        }
        catch(const SocketException&)
        {
            return false;
        }
    }
    return true;
}

// Receives the peer's RSA public key ("<modulus> <exponent>", in clear) with a
// 10 second receive timeout and validates it. Returns false if nothing valid
// arrives, true otherwise (including when encryption is disabled).
bool Socket::receivePublicKey()
{
    if(RSALENGTH > 0)
    {
        string data;
        struct timeval tv;
        tv.tv_sec = 10;
        tv.tv_usec = 0;
        setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv);

        bool received = true;
        try
        {
            ReceiveUnencrypted(data);
        }
        catch(const SocketException&)
        {
            received = false;
        }

        // Clear the timeout on both paths so a failed exchange does not leave
        // the socket with a 10 second receive limit.
        tv.tv_sec = 0;
        setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv);

        if(!received)
        {
            return false;
        }

        stringstream sstream;
        sstream.str(data);
        Integer modulus, exponent;
        sstream >> modulus;
        sstream >> exponent;
        theirPublicKey.Initialize(modulus, exponent);
        if(!theirPublicKey.Validate(rng, 1))
        {
            return false;
        }
    }
    return true;
}

// Copies externally supplied RSA keys and AES key/IV into this socket.
// `key` must point to AESLENGTH bytes and `iv` to AES::BLOCKSIZE bytes.
void Socket::setKeys(RSAFunction *pubkey, InvertibleRSAFunction *privkey, byte *key, byte *iv)
{
    if(RSALENGTH > 0)
    {
        myPublicKey = *pubkey;
        privateKey = *privkey;
        memcpy(myIV, iv, AES::BLOCKSIZE);
        memcpy(myKey, key, AESLENGTH);
    }
}

// Returns a pointer to this socket's RSA private key.
InvertibleRSAFunction* Socket::getPrivateKey()
{
    return &privateKey;
}

// Returns a pointer to this socket's RSA public key.
RSAFunction* Socket::getPublicKey()
{
    return &myPublicKey;
}

// Sends this side's AES key and IV ("<key> <iv>" as raw bytes) encrypted with
// the peer's RSA public key. Exactly MAXLENGTH bytes are written.
void Socket::sendAES()
{
    if(RSALENGTH > 0)
    {
        string key(reinterpret_cast<const char *>(myKey), AESLENGTH);
        key += ' ';
        key.append(reinterpret_cast<const char *>(myIV), AES::BLOCKSIZE);

        string cipher = encryptRSA(key);
        // recvAES() reads exactly MAXLENGTH bytes. Size the buffer to match so
        // Send() never reads past the end of the ciphertext when the peer's
        // key yields a shorter one.
        // NOTE(review): presumably MAXLENGTH equals the RSA ciphertext size - confirm in Socket.h.
        cipher.resize(static_cast<size_t>(MAXLENGTH), '\0');
        Send(cipher.data(), MAXLENGTH);
    }
}

// Receives MAXLENGTH bytes, decrypts them with this side's RSA private key and
// stores the peer's AES key and IV. Throws SocketException if the decrypted
// data is too short to hold "<key> <iv>".
void Socket::recvAES()
{
    if(RSALENGTH > 0)
    {
        // String-owned buffer: nothing leaks if Receive()/decryptRSA() throw.
        string cipher(static_cast<size_t>(MAXLENGTH), '\0');
        Receive(&cipher[0], MAXLENGTH);
        const string key = decryptRSA(cipher);

        // Layout written by sendAES(): AESLENGTH key bytes, one separator
        // byte, AES::BLOCKSIZE IV bytes. Check the size before copying so a
        // short blob cannot cause an out-of-bounds read.
        const size_t ivOffset = static_cast<size_t>(AESLENGTH) + 1;
        if(key.size() < ivOffset + static_cast<size_t>(AES::BLOCKSIZE))
        {
            throw SocketException ( "TCP: Received malformed AES key material." );
        }
        memcpy(theirKey, key.data(), AESLENGTH);
        memcpy(theirIV, key.data() + ivOffset, AES::BLOCKSIZE);
    }
}

// Returns this side's AES key buffer.
byte* Socket::getAESKey()
{
    return myKey;
}

// Returns this side's AES IV buffer.
byte* Socket::getAESIV()
{
    return myIV;
}

// Encrypts `text` with the peer's RSA public key (OAEP/SHA) and returns the
// ciphertext.
string Socket::encryptRSA(string& text)
{
    RSAES_OAEP_SHA_Encryptor e(theirPublicKey);
    string cipher;
    StringSource ss1(text, true, new PK_EncryptorFilter(rng, e, new StringSink(cipher)));
    return cipher;
}

// Decrypts `crypt` with this side's RSA private key (OAEP/SHA) and returns the
// recovered plaintext.
string Socket::decryptRSA(string& crypt)
{
    RSAES_OAEP_SHA_Decryptor d(privateKey);
    string recovered;
    StringSource ss2(crypt, true, new PK_DecryptorFilter(rng, d, new StringSink(recovered)));
    return recovered;
}

// Encrypts `text` with AES in CBC mode using this side's key and IV.
// NOTE(review): the same key/IV pair is used on every call until
// generateAESKeys()/setKeys() replaces it.
string Socket::encryptAES(const string& text)
{
    string cipher;
    AES::Encryption aesEncryption(myKey, AESLENGTH);
    CBC_Mode_ExternalCipher::Encryption cbcEncryption(aesEncryption, myIV);
    StreamTransformationFilter stfEncryptor(cbcEncryption, new StringSink(cipher));
    stfEncryptor.Put(reinterpret_cast<const unsigned char *>(text.c_str()), text.length());
    stfEncryptor.MessageEnd();
    return cipher;
}

// Decrypts `crypt` with AES in CBC mode using the peer's key and IV.
string Socket::decryptAES(const string& crypt)
{
    string recovered;
    AES::Decryption aesDecryption(theirKey, AESLENGTH);
    CBC_Mode_ExternalCipher::Decryption cbcDecryption(aesDecryption, theirIV);
    CryptoPP::StreamTransformationFilter stfDecryptor(cbcDecryption, new CryptoPP::StringSink(recovered));
    stfDecryptor.Put(reinterpret_cast<const unsigned char *>(crypt.c_str()), crypt.size());
    stfDecryptor.MessageEnd();
    return recovered;
}

// Loads `key` from the file `filename`.
void Socket::LoadKey(const string& filename, PublicKey *key)
{
    ByteQueue queue;
    FileSource file(filename.c_str(), true);
    file.TransferTo(queue);
    queue.MessageEnd();
    key->Load(queue);
}

// Saves `key` to the file `filename`.
void Socket::SaveKey(const string& filename, const PublicKey *key)
{
    ByteQueue queue;
    key->Save(queue);
    FileSink file(filename.c_str());
    queue.CopyTo(file);
    file.MessageEnd();
}

// Loads the RSA key pair from the files `pub` and `priv`. If either file
// cannot be opened, or either loaded key fails validation, a new pair is
// generated and written to those files. Fresh AES key/IV are generated in
// both cases.
void Socket::loadKeys(string pub, string priv)
{
    if(RSALENGTH > 0)
    {
        // Both files are needed: loading from a missing one would throw.
        ifstream pubkey(pub.c_str());
        ifstream privkey(priv.c_str());
        const bool bothPresent = pubkey.is_open() && privkey.is_open();
        pubkey.close();
        privkey.close();

        if(bothPresent)
        {
            LoadKey(pub, getPublicKey());
            LoadKey(priv, getPrivateKey());
            // Both halves must validate; one bad key invalidates the pair.
            if(getPublicKey()->Validate(rng, 1) && getPrivateKey()->Validate(rng, 1))
            {
                generateAESKeys();
                return;
            }
        }

        generateRSAKeys();
        SaveKey(priv, getPrivateKey());
        SaveKey(pub, getPublicKey());
        generateAESKeys();
    }
}