// check_tftp.cpp - checks if a remote TFTP server is working by retrieving a file from it.
#include "check_tftp.h"

// Service label printed at the start of every status line emitted by main().
char *servicename = const_cast<char *>("TFTP");

/*
 * Retrieve `filename` from the TFTP server at `server` using socket `s`.
 *
 * Sends a read request, then receives packets and acknowledges each DATA
 * block until a block carrying fewer than 512 payload bytes arrives.
 * The source port of the first reply is adopted as the server's transfer
 * port (server->sin_port is overwritten).
 *
 * Returns:
 *   0 - the whole file was received and acknowledged;
 *   1 - the server replied with a non-DATA packet; "<code>: <message>"
 *       is appended to *error;
 *   2 - sending failed, or recvMsg returned fewer than 4 bytes
 *       (reported as a timeout); *error is set to a description.
 */
uint8_t recvFile(char *filename, string *error, int s, struct sockaddr_in *server)
{
	int tid = 0;
	struct sockaddr_in data;
	uint16_t numBlock;
	// One byte more than recvMsg is allowed to fill. After the memset, that
	// byte always stays 0. This keeps `buffer + 4` NUL-terminated even when
	// the server sends a full-size ERROR packet without a terminator.
	char buffer[BUFSIZE+5];
	ssize_t dataLength = 512;
	// NOTE(review): sendRRQ and sendACK return uint8_t, so the `< 0` checks
	// below can never be true and send failures go undetected. Fixing this
	// requires changing their prototypes in check_tftp.h.
	if(sendRRQ(filename,MODE_OCTAL,s,server) < 0)
	{
		*error = "Socket error while sending RRQ";
		return 2;
	}
	// A payload of exactly 512 bytes means more blocks follow; a shorter
	// one is the last block of the transfer.
	while(dataLength == 512)
	{
		memset(buffer,0x00,sizeof buffer);
		// Subtract the 4-byte header (2-byte opcode + 2-byte block number).
		dataLength = recvMsg(s,buffer,BUFSIZE+4,&data) - 4;
		if(dataLength < 0)
		{
			*error = "Timeout occurred";
			return 2;
		}
		if(!tid)
		{
			// The server answers from a new port; use it for all further ACKs.
			tid = ntohs(data.sin_port);
			server->sin_port = htons(tid);
		}
		// The opcode is 16-bit big-endian; DATA requires a zero high byte.
		if(buffer[0] != 0 || buffer[1] != DATA)
		{
			// Bytes 2-3 hold the 16-bit big-endian error code; render all of
			// it in decimal, not just the low byte as a single digit.
			uint16_t code = ((uint8_t)buffer[2] << 8) | (uint8_t)buffer[3];
			string digits;
			do
			{
				digits.insert(0, 1, (char)('0' + code % 10));
				code /= 10;
			} while(code != 0);
			*error += digits;
			*error += ": ";
			*error += buffer+4;
			return 1;
		}
		numBlock = 0;
		numBlock |= ((uint8_t)buffer[2] << 8);
		numBlock |= (uint8_t)buffer[3];
		if(sendACK(numBlock,s,server) < 0)
		{
			*error = "Socket error while sending ACK";
			return 2;
		}
	}
	return 0;
}

/*
 * Send a 4-byte ACK for block `numBlock` to `si` over socket `s`.
 * Packet layout: 0x00, ACK, then the block number high byte first.
 * Returns the result of sendMsg, converted to uint8_t.
 */
uint8_t sendACK(uint16_t numBlock, int s, struct sockaddr_in *si)
{
	char packet[4] = {
		0,
		(char)ACK,
		(char)(numBlock >> 8),
		(char)(numBlock & 0xFF)
	};
	return sendMsg(s, packet, 4, si);
}

/*
 * Send a read request for `filename` in transfer mode `mode` to `si`
 * over socket `s`.
 *
 * Packet layout: 0x00, RRQ, filename, NUL, mode, NUL.
 * Returns the result of sendMsg, converted to uint8_t.
 * NOTE(review): because of the uint8_t return type, a negative sendMsg
 * result can never be observed as negative by callers.
 */
uint8_t sendRRQ(char *filename, char *mode, int s, struct sockaddr_in *si)
{
	size_t filenameLength = strlen(filename);
	size_t modeLength = strlen(mode);
	size_t length = 2 + filenameLength + 1 + modeLength + 1;
	char *message = new char[length];
	message[0] = 0;
	message[1] = RRQ;
	// Copy each string together with its terminating NUL.
	memcpy(message + 2, filename, filenameLength + 1);
	memcpy(message + 2 + filenameLength + 1, mode, modeLength + 1);
	uint8_t result = sendMsg(s, message, length, si);
	// The packet buffer was previously leaked on every call.
	delete[] message;
	return result;
}

/* Print the program name and version, followed by a blank line. */
void printVersion()
{
	cout << "check_tftp v";
	cout << VERSION;
	cout << endl;
	cout << endl;
}

/*
 * Print usage information to stdout.
 * longVersion == false: only the usage line.
 * longVersion == true: version banner, description, usage line and the
 * description of every command-line option.
 */
void printHelp(bool longVersion)
{
	if(!longVersion)
	{
		cout << "Usage: " << endl << "check_tftp [-hV] | -H HOSTNAME -f FILENAME [-t TIMEOUT -p PORT]" << endl << endl;
		return;
	}

	printVersion();
	cout << "Checks if a remote TFTP server is working by retrieving a file from it." << endl << endl;
	printHelp(false);

	// Each entry: option flag line, then its indented description line.
	static const char *const options[][2] = {
		{ " -h", "    Print detailed help screen" },
		{ " -V", "    Print version information" },
		{ " -H", "    Hostname of the TFTP server" },
		{ " -p", "    Port where the server is listening. If not specified, uses default TFTP port" },
		{ " -t", "    Timeout for the retrieval operation. If not specified, uses 10 seconds timeout" },
		{ " -f", "    Filename to retrieve" }
	};
	const size_t optionCount = sizeof options / sizeof options[0];

	cout << "Options:" << endl;
	for(size_t i = 0; i < optionCount; ++i)
	{
		cout << options[i][0] << endl;
		cout << options[i][1] << endl;
	}
	cout << endl;
}

/*
 * Entry point: parse options, fetch the requested file and print a status line.
 *
 * Options: -h (detailed help), -V (version), -H host, -p port,
 *          -t timeout in seconds (default 10), -f filename.
 * Exit status: 0 file retrieved, 1 server replied with an error packet,
 *              2 transfer failure, 3 usage/argument error.
 */
int main(int argc, char **argv)
{
	int timeout = 10;
	int returnCode;
	char *filename = NULL;
	char *hostname = NULL;
	uint16_t port = DEFAULT_PORT;
	struct sockaddr_in si;
	// Initialised so that no indeterminate pointer is passed to setupSocket.
	// Also covers the case where setupSocket returns without assigning the
	// descriptor.
	struct hostent *host = NULL;
	int s = -1;

	int c;
	while((c = getopt(argc, argv, "VhH:p:f:t:")) != -1)
	{
		switch(c)
		{
			case 'h':
				printHelp(true);
				return 0;
			case 'V':
				printVersion();
				return 0;
			case 'H':
				hostname = optarg;
				break;
			case 'p':
			{
				// Validate before narrowing: out-of-range values used to be
				// silently truncated to 16 bits.
				int portValue = str2int(string(optarg));
				if(portValue <= 0 || portValue > 65535)
				{
					cout << "Invalid port: " << optarg << endl;
					return 3;
				}
				port = (uint16_t)portValue;
				break;
			}
			case 't':
				timeout = str2int(string(optarg));
				if(timeout <= 0)
				{
					cout << "Invalid timeout: " << optarg << endl;
					return 3;
				}
				break;
			case 'f':
				filename = optarg;
				break;
			case '?':
			default:
				printHelp(false);
				return 3;
		}
	}

	if(hostname == NULL)
	{
		cout << "Hostname not specified" << endl;
		return 3;
	}
	if(filename == NULL)
	{
		cout << "Filename not specified" << endl;
		return 3;
	}

	setupSocket(port,hostname,host,timeout,&si,&s);
	if(s < 0)
	{
		// No usable socket; do not attempt a transfer on an invalid descriptor.
		cout << servicename << " CRITICAL - Could not set up socket" << endl;
		return 2;
	}
	string error;
	returnCode = recvFile(filename,&error,s,&si);

	cout << servicename;
	if(!returnCode)
	{
		cout << " OK - " << filename << " retrieved successfully" << endl;
	}
	else if(returnCode == 1)
	{
		cout << " WARNING - server answered with ERROR " << error << endl;
	}
	else
	{
		cout << " CRITICAL - " << error << endl;
	}
	close(s);
	return returnCode;
}
//printHelp