/*-
 * Copyright (c) 1996, Trusted Information Systems, Incorporated
 * All rights reserved.
 *
 * Redistribution and use are governed by the terms detailed in the
 * license document ("LICENSE") included with the toolkit.
 */

/*
 *      Author: Kelly Djahandari, Trusted Information Systems, Inc.
 */

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <termios.h>
#include <string.h>
#include <signal.h>
#include <pwd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/file.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include "mbone-gw.h"
#include "wrapper.h"

/* Size of the buffer that holds the multicast address string (addr). */
#define MAXADDRLEN 20
/* Size of the buffer that holds the port string (port in main()). */
#define MAXPORTLEN 10

/* Size of the buffers used to receive messages from the firewall. */
#define MAXMSG 2048

/* ttl returned by get_ttl() when none is supplied or the value is invalid. */
#define DEFAULT_TTL 15

/*
 * Library symbols declared by hand rather than through their headers.
 * NOTE(review): inet_addr() is declared here as returning unsigned long;
 * confirm this matches the platform's own declaration.  errno and optarg
 * are not referenced by the code visible in this file.
 */
extern pid_t wait();
extern unsigned long inet_addr();
extern int errno;


extern char *optarg;

/* Forward declarations (old-style, no prototypes) of the local routines. */
static int	authenticate_user();
static int	auth_done();
static int	connect_fw();
static void 	get_addr_port();
static int 	get_fw_response();
static int 	get_nl();
static int 	get_ttl();
static void 	mcast_disable();
static int 	mcast_send();
static int 	process_response();
static int 	readfw();
static void 	sig_child();
static int 	writefw();

extern char 	*strtok();

/* Multicast group address string; filled in by main(), used by sig_child(). */
static char 	addr[MAXADDRLEN];
/* Socket connected to the firewall; set through connect_fw(&sock) in main(). */
static int	sock;
/* Set to 1 by main() when the application name ends in "sdr". */
static int 	is_sdr = 0;

/* Fallback for systems whose headers do not define _PASSWORD_LEN. */
#ifndef _PASSWORD_LEN
#define _PASSWORD_LEN 128
#endif

/*
 * Print the command-line synopsis on stderr and terminate the process
 * with exit status 1.  Does not return.
 */
static void
usage()
{
	static const char synopsis[] =
	    "usage: application [arguments] address/port\n";

	(void)fputs(synopsis, stderr);
	exit(1);
}

/*
 * MBone client wrapper entry point.
 *
 * Determines which MBone application is being run (from the name this
 * program was invoked under or, when invoked as WRAPPER_NAME, from the
 * first argument), registers the requested multicast group with the
 * firewall proxy, rewrites the destination argument so that it points at
 * the firewall, and then runs the real application as a child process.
 * The parent blocks on the firewall socket; when that read returns it
 * kills the child and returns 0.  Any setup failure exits with status 1.
 */
int
main(argc, argv)
int argc;
char *argv[];
{
	char port[MAXPORTLEN];
	pid_t	pid;
	unsigned short fw_ucast_port;
	unsigned short ttl;
	int	n;
	char	mreply[MAXMSG];	/* message from fw while the child runs */
	int i;
	int found;
	size_t len, dirlen, pathlen;
	char appl[MAX_APPL_PATH_STR]; /* full path of the real application */
	char *appl_str;
	char *newarg;
	char *newdest;

	if (argc < 1 || argv[0] == NULL) {
		usage();
	}

	/*
	 * Get the application name without directory information.  When
	 * argv[0] has no '/', it already is the bare name.  The regions
	 * overlap, so memmove() is required.
	 */
	appl_str = strrchr(argv[0], '/');
	if (appl_str != NULL) {
		appl_str++;  /* get beyond '/' */
		(void)memmove(argv[0], appl_str, strlen(appl_str) + 1);
	}

	/* Check if started by "wrapper mbone_app" */
	/* Remove wrapper from argv list */
	found = 0;
	if (!strncmp(argv[0], WRAPPER_NAME, strlen(WRAPPER_NAME))) {
		/* The wrapper form needs the application name as argv[1]. */
		if (argc < 2 || argv[1] == NULL) {
			usage();
		}
		/*
		 * NOTE(review): this is a prefix match, so any argument that
		 * merely starts with an approved name is accepted.
		 */
		for (i = 0; i < num_mbone_apps; i++) {
			if (!strncmp(argv[1], mbone_app[i],
			    strlen(mbone_app[i]))) {
				++argv;
				--argc;
				found++;
				break;
			}
		}
		if (!found) {
			(void)fprintf(stderr, "MBone application %s not in list of approved applications\n", argv[1]);
			exit(1);
		}
	}

	/* Change appl/argv[0] to the real application */
	dirlen = strlen(MBONE_APPL_DIRECTORY);
	len = strlen(argv[0]);

	/* Make sure we have room for directory, name and terminator */
	if (dirlen + len + 1 > sizeof(appl)) {
		(void)fprintf(stderr, "Path length too long\n");
		exit(1);
	}

	(void)memset(appl, 0, sizeof(appl));
	(void)memcpy(appl, MBONE_APPL_DIRECTORY, dirlen);
	(void)memcpy(appl + dirlen, argv[0], len + 1);
	pathlen = dirlen + len;

	newarg = (char *)malloc(pathlen + 1);
	if (newarg == (char *) 0) {
		(void)fprintf(stderr, "Out of memory\n");
		exit(1);
	}
	(void)memcpy(newarg, appl, pathlen + 1);
	argv[0] = newarg;

	/* do some argument processing.  Command line options will stay */
	/* the same (except for the address replacement) when given to the */
	/* "real" application */

	ttl = get_ttl(argc, argv);

	/* Check to make sure we are not running as root */
	if (getuid() == 0) {
		(void)fprintf(stderr, "Do not run as root\n");
		exit(1);
	}

	/* Is the application sdr?  (name ends in "sdr") */
	if (pathlen >= 3 && strncmp(&appl[pathlen - 3], "sdr", 3) == 0) {
		is_sdr = 1;
	}

	/* Are there enough arguments? */
	if (!is_sdr && (argc < 2)) {
		usage();
	}

	/* Start from empty strings so both are always NUL-terminated. */
	(void)memset(addr, 0, sizeof(addr));
	(void)memset(port, 0, sizeof(port));

	/* If sdr, use the well known sdr multicast address and port */
	if (is_sdr) {
		(void)strncpy(addr, SDR_ADDR, sizeof(addr) - 1);
		(void)strncpy(port, SDR_PORT, sizeof(port) - 1);
	}
	else {
		/* Get the address/port from the command line (last argument) */
		get_addr_port(argv[argc-1], (char *)addr, (char *)port);
	}

	/* Connect with firewall */
	if (connect_fw(&sock) < 0) {
		(void)fprintf(stderr, "Exiting %s\n", appl);
		exit(1);
	}
	(void)fprintf(stderr, "%s connected\n", appl);

	/* Send multicast information to firewall */
	if (mcast_send(appl, addr, port, ttl, argv, argc) < 0) {
		(void)fprintf(stderr, "Exiting %s\n", appl);
		exit(1);
	}

	/* Get firewall response */
	if (process_response(&fw_ucast_port) < 0) {
		mcast_disable(addr); /* don't care about errors */
		(void)fprintf(stderr, "Exiting %s\n", appl);
		exit(1);
	}

	/* if not sdr, substitute fw address for multicast address */
	/* substitute unicast port for port */
	if (!is_sdr) {
		/*
		 * Build the replacement in fresh storage: the original
		 * argument may be shorter than "<fw address>/<port>".
		 * 16 bytes cover '/', the port digits and the terminator.
		 */
		newdest = (char *)malloc(strlen(INSIDE_INTERFACE) + 16);
		if (newdest == (char *) 0) {
			mcast_disable(addr); /* don't care about errors */
			(void)fprintf(stderr, "Out of memory\n");
			exit(1);
		}
		(void)sprintf(newdest, "%s/%d",
				INSIDE_INTERFACE, fw_ucast_port);
		argv[argc-1] = newdest;
	}

	/* Establish SIGCHLD handler for when child terminates */
	(void)signal(SIGCHLD, sig_child);

	/* Fork/exec child process to run "real" application */
	/* Parent waits for anything from firewall */
	(void)memset(mreply, 0, sizeof(mreply));
	if ((pid = fork()) < 0) {
		perror("fork error");
		mcast_disable(addr); /* don't care about errors */
		exit(1);
	}
	else if (pid == 0) { /* child */
		/* start "real" application with same options, new fw address */
		(void)execvp(argv[0], &argv[0]);
		/* not reached unless error */
		(void)fprintf(stderr, "%s\n", argv[0]);
		perror("execvp error");
		/* Do not fall back into the parent's code path. */
		_exit(1);
	}
	else { /* parent */
		/* wait to read anything from firewall proxy */
		n = read(sock, (char *)mreply, MAXMSG-1);
		if (n > 0) {
			(void)fprintf(stderr, "%s\n", mreply);
		}
		/* kill the child process */
		/* The SIGCHLD handler (sig_child) sends MCAST_DROP to the fw */
		(void)kill(pid, SIGKILL);
	}
	return(0);
}


/*
 * Determine the time to live (ttl) for the session.
 *
 * The ttl can be given as a command line option ("-t 15", "-t15" or
 * "-ttl 15"), or as the fourth field of the destination specification
 * in the last argument (dest/port/fmt/ttl).  If neither is present, or
 * the value is outside 0..255, DEFAULT_TTL is used (an "invalid ttl"
 * message is printed for out-of-range or missing option values).
 *
 * argc, argv: the argument vector; argv[0] is the program name and is
 *             not examined for options.
 * Returns the ttl, always in the range 0..255.
 */
static int
get_ttl(argc, argv)
int argc;
char *argv[];
{
	static char sep[] = "/";	/* strtok() needs a real string */
	char *tok;
	char *addr_str;
	char *value;
	size_t speclen;
	int ttl;
	int i;
	int field;

	if (argc <= 1) {
		return(DEFAULT_TTL);
	}

	for (i = 1; i < argc; i++) {
		if (strncmp(argv[i], "-t", 2) != 0)
			continue;

		if (!strcmp(argv[i], "-t") || !strcmp(argv[i], "-ttl")) {
			/* value is the next argument, if there is one */
			value = (i + 1 < argc) ? argv[i+1] : (char *)0;
		}
		else {
			/* no space between -t and the ttl value */
			value = argv[i] + 2;
		}

		/* A missing value is reported like an out-of-range one. */
		ttl = (value != (char *)0) ? atoi(value) : -1;

		/* Check if ttl in proper range */
		if (ttl < 0 || ttl > 255 ) {
			(void)fprintf(stderr, "invalid ttl\n");
			ttl = DEFAULT_TTL;
		}
		return(ttl);
	}

	/* No option given: see if ttl is specified as dest/port/fmt/ttl. */
	/* strtok() modifies its input, so work on a private copy. */
	speclen = strlen(argv[argc-1]);
	addr_str = (char *)malloc(speclen + 1);
	if (addr_str == (char *) 0) {
		(void)fprintf(stderr, "Out of memory\n");
		exit(1);
	}
	(void)memcpy(addr_str, argv[argc-1], speclen + 1);

	/* First token is dest; step over port and fmt to reach ttl. */
	tok = strtok(addr_str, sep);
	for (field = 1; tok != (char *)0 && field < 4; field++)
		tok = strtok((char *)0, sep);

	ttl = (tok != (char *)0) ? atoi(tok) : DEFAULT_TTL;
	free(addr_str);

	/* Check if ttl in proper range */
	if (ttl < 0 || ttl > 255 ) {
		(void)fprintf(stderr, "invalid ttl\n");
		ttl = DEFAULT_TTL;
	}
	return(ttl);
}


/*
 * Split an "address/port" string into separate address and port strings.
 *
 * argv: the destination argument; everything before the first '/' is
 *       the address, everything after it is the port (any further
 *       "/fmt/ttl" fields stay attached to the port string).
 * addr: output buffer of at least MAXADDRLEN bytes.
 * port: output buffer of at least MAXPORTLEN bytes.
 *
 * Both outputs are always NUL-terminated.  If there is no '/' or the
 * address part does not fit in addr, usage() is called, which exits.
 */
static void
get_addr_port(argv, addr, port)
char *argv, *addr, *port;
{
	char *sep;
	size_t alen;

	/*
	 * Locate the separator in the full argument, not in a truncated
	 * copy, so that a long address cannot cut digits off the port.
	 */
	sep = strchr(argv, '/');
	if (sep == NULL) {
		usage();
		return;		/* not reached: usage() exits */
	}

	alen = (size_t)(sep - argv);
	if (alen > MAXADDRLEN-1) {
		usage();
		return;		/* not reached: usage() exits */
	}

	(void)memcpy(addr, argv, alen);
	addr[alen] = '\0';

	(void)strncpy(port, sep + 1, MAXPORTLEN-1);
	port[MAXPORTLEN-1] = '\0';
	return;
}


/*
 * Make a TCP connection to the firewall.
 *
 * sock: receives the connected socket descriptor on success.
 * Returns 0 on success.  On failure an error is printed with perror()
 * and -1 is returned; if the socket had already been created it is
 * closed again and *sock is set to -1.
 */
static int
connect_fw(sock)
int *sock;
{
	struct sockaddr_in fw_sin;  /* sockaddr of firewall */

	/* Address of the firewall's inside interface */
	(void)memset((char *)&fw_sin, 0, sizeof(fw_sin));
	fw_sin.sin_family = AF_INET;
	fw_sin.sin_addr.s_addr = inet_addr(INSIDE_INTERFACE);
	fw_sin.sin_port = htons(MBONESRV_PORT);

#ifdef BINDDEBUG        /* if debugging, use debugport */
	fw_sin.sin_port = htons(BINDDEBUGPORT);
#endif

	/* Allocate a socket for TCP communication with firewall */
	if ((*sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		perror("socket error");
		return(-1);
	}

	if (connect(*sock, (struct sockaddr *)&fw_sin, sizeof(fw_sin)) < 0) {
		/* report first: close() may change errno */
		perror("connect error");
		(void)close(*sock);
		*sock = -1;
		return(-1);
	}
	return(0);
}


/*
 * Send the multicast address/port/ttl/application information to the mbone
 * proxy on the firewall as an MCAST_ADD request.
 *
 * appl:       application path; copied into the request, and its final
 *             path component selects the rtp version.
 * addr, port: multicast group address and port as strings.
 * ttl:        time to live for the group.
 * argv, argc: the application's arguments (scanned for vat's -r option).
 *
 * Returns 0 when the whole request was written, -1 otherwise.
 */
static int
mcast_send(appl, addr, port, ttl, argv, argc)
char *appl, *addr, *port;
unsigned short ttl;
int argc;
char *argv[];
{
	struct mcast_info minfo;
	struct passwd *pw;
	char *base;
	char *user;
	int i, n;

	/* Create the information to send to firewall */
	(void)memset((char *)&minfo, 0, sizeof(minfo));
	minfo.type = htons(MCAST_ADD);
	minfo.mcast_group = inet_addr(addr);	/* already network order */
	minfo.mcast_port = htons((unsigned short) atoi(port));
	minfo.ttl = htons(ttl);

	(void)strncpy(minfo.mcast_appl, appl, MAXAPPLSTR-1);

	/*
	 * appl may carry a directory prefix; the application is identified
	 * by the last path component.
	 */
	base = strrchr(appl, '/');
	base = (base != NULL) ? base + 1 : appl;

	/* vic uses rtpv2 */
	/* vat uses rtpv2 if -r specified (default with sdr); else, rtpv0 */
	/* wb, nt, sdr don't use rtp */
	minfo.rtp = htons(0);
	if (!strncmp(base, "vic", 3))
		minfo.rtp = htons(2);
	else if (!strncmp(base, "vat", 3)) {
		for (i=0; i<argc; i++) {
			if (!strncmp(argv[i], "-r", 2)) {
				minfo.rtp = htons(2);
				break;
			}
		}
		if (minfo.rtp == htons(0))
			minfo.rtp = htons(1);
	}

	/*
	 * getlogin() returns NULL when there is no controlling terminal;
	 * fall back to the password database, and leave the (zeroed) user
	 * field empty if that fails too.
	 */
	user = getlogin();
	if (user == NULL) {
		pw = getpwuid(getuid());
		if (pw != NULL)
			user = pw->pw_name;
	}
	if (user != NULL)
		(void)strncpy(minfo.mcast_user, user, MAXUSRSTR-1);

	/* Send the multicast group address/ port to enable */
	n = writefw((char *)&minfo, (unsigned int)sizeof(minfo));
	/* test n < 0 separately: comparing -1 with a size_t would miss it */
	if (n < 0 || (size_t)n < sizeof(minfo)) {
		(void)fprintf(stderr, "sent only %x bytes instead of %x\n",
			(unsigned int)n, (unsigned int)sizeof(minfo));
		perror("write error");
		return(-1);
	}
	return(0);

}

/*
 * Get a response from the firewall.  If the firewall asks for a user name,
 * the authentication steps are performed first and the reply to the
 * multicast request is read afterwards.  A success message contains the
 * unicast port the firewall is using to receive data from the inside
 * client.
 *
 * fw_ucast_port: receives that port (host byte order) on success.
 * Returns 0 on success, -1 on any read, authentication or firewall error.
 */
static int
process_response(fw_ucast_port)
unsigned short *fw_ucast_port;
{
	static char *username = "Username: ";
	char	mreply[MAXMSG];		/* success or error from fw */
	struct mcast_reply fwreply;
	int n;
	int authrc;

	(void)memset((char *)mreply, 0, sizeof(mreply));
	n = read(sock, (char *)mreply, MAXMSG-1);
	if (n < 0) {
		perror("read error");
		return(-1);
	}
	else if (n == 0) {
		return(-1);
	}

	if (!strncmp(mreply, username, strlen(username))) {

		/* Lock stdout, otherwise if other
		 * apps are trying to authenticate, things get confused */
		flock(STDOUT_FILENO, LOCK_EX);

		/* Authentication required */
		(void)fprintf(stderr, "Authentication required\n");
		authrc = authenticate_user();

		/* Release the lock on the failure path as well. */
		flock(STDOUT_FILENO, LOCK_UN);

		if (authrc < 0) {
			(void)fprintf(stderr,
				"\nError authenticating user\n");
			return(-1);
		}

		/* Get the reply for the minfo */
		(void)memset((char *)mreply, 0, sizeof(mreply));
		n = read(sock, (char *)mreply, 2);
		if (n < 0) {
			perror("read error");
			return(-1);
		}
		if (n == 0) {
			/* firewall closed the connection */
			return(-1);
		}
		if (mreply[0] == '\r' && mreply[1]== '\n') {
			/* leftover line end: echo it, then read the reply */
			if (write(STDOUT_FILENO, mreply, 2) != 2) return(-1);
			if (read(sock, (char *)mreply, sizeof(mreply)-1) < 0)
				return(-1);
		}
		else {
			/* the two bytes already start the reply */
			if (read(sock,(char *)&mreply[2],sizeof(mreply)-3) < 0)
				return(-1);
		}

	}
	/* NOTE(review): sizeof() assumes GOOD_REPLY/BAD_REPLY are arrays or
	 * string literals, not pointers -- confirm in mbone-gw.h. */
	if (!strncmp(mreply, GOOD_REPLY, sizeof(GOOD_REPLY))) {
		/* SUCCESS !*/
		(void)memcpy((char *)&fwreply, (char *)mreply, sizeof(fwreply));
		*fw_ucast_port = ntohs(fwreply.fw_ucast_port);
	}
	else if (!strncmp(mreply, BAD_REPLY, sizeof(BAD_REPLY))) {
		/* ucast_port problem !*/
		/* NOTE(review): errno is not set by this condition, so the
		 * text perror() appends may be unrelated. */
		perror("Error from firewall");
		return(-1);
	}
	else {
		/* Error from firewall printed */
		(void)fprintf(stderr, "%s\n", mreply);
		return(-1);
	}
	return(0);

}


/*
 * Send a drop multicast membership (MCAST_DROP) message to the firewall.
 *
 * addr: multicast group address as a dotted string.
 * Failures are reported on stderr; nothing is returned to the caller.
 */
static void
mcast_disable(addr)
char *addr;
{
	struct mcast_info minfo;
	int	n;

	(void)memset((char *)&minfo, 0, sizeof(minfo));
	minfo.type = htons(MCAST_DROP);
	/*
	 * inet_addr() already yields network byte order (mcast_send() uses
	 * it unconverted); a further htonl() would swap the bytes again on
	 * little-endian hosts.
	 */
	minfo.mcast_group = inet_addr(addr);
	minfo.mcast_port = htons(0);

	/* Send the multicast group address to disable */
	n = writefw((char *)&minfo, (unsigned int)sizeof(minfo));
	/* test n < 0 separately: comparing -1 with a size_t would miss it */
	if (n < 0 || (size_t)n < sizeof(minfo)) {
		(void)fprintf(stderr,"sent only %x bytes instead of %x\n",
			(unsigned int)n, (unsigned int)sizeof(minfo));
		perror("write error");
		return;
	}
	return;
}

/*
 * Pass the authentication information between the user and the firewall
 * proxy.  All messages received from the proxy are printed.
 * The firewall proxy sends Username, Password, and Challenge without \r\n.
 * All other responses end with \r\n.
 *
 * Returns 0 on successful login, -1 on I/O failure or when the proxy
 * reports that no further attempts are possible (see auth_done()).
 */
static int
authenticate_user()
{
	int n;
	char buf[2048];
	char *passptr;
	size_t replylen;
	static char	prompt[] = "Username: ";
	static char     pprompt[] = "Password: ";
	static char     cprompt[] = "Challenge: ";
	int ret;

	/* Write the Username prompt to the user */
	if (write(STDOUT_FILENO, prompt, strlen(prompt)) != strlen(prompt))
		return(-1);

	/* Loop until successful login or unsuccessful and no more tries */
	while (1) {
		/* Get user input */
		n = read(STDIN_FILENO, buf, sizeof(buf)-1);
		if (n <= 0) return (-1);
		if (n == 1) continue;	/* just the line end: ask again */
		buf[n-1] = '\0';   /* remove the newline */
		/* send to firewall */
		n = writefw(buf, (unsigned int)strlen(buf));
		if (n != strlen(buf)) return(-1);

		/* read fw response, display to user */
		if (get_fw_response(buf, (int)sizeof(buf)) < 0) {
			return(-1);
		}

		/* Password or Challenge */
		if (!strncmp(buf, pprompt, strlen(pprompt)) ||
		    strstr(buf, cprompt)) {
			if (!strncmp(buf, pprompt, strlen(pprompt))){
				/* Password: read without echo */
				passptr = (char *)getpass("");
				if (passptr == NULL)
					return(-1);
				(void)memset(buf, 0, sizeof(buf));
				/* bounded copy; buf stays NUL-terminated */
				(void)strncpy(buf, passptr, sizeof(buf)-1);
				/* wipe getpass()'s copy of the password */
				(void)memset(passptr, 0, strlen(passptr));
			}
			else { 		/* Challenge */
				(void)memset(buf, 0, sizeof(buf));
				n = read(STDIN_FILENO, buf, sizeof(buf)-1);
				if (n <= 0)
					return(-1);
				/* remove nl */
				buf[n-1] = '\0';
			}

			/* send to firewall */
			n = writefw(buf, (unsigned int)strlen(buf));
			if (n != strlen(buf)) {
				/* do not leave the secret on the stack */
				(void)memset(buf, 0, sizeof(buf));
				return(-1);
			}
			if (get_fw_response(buf, (int)sizeof(buf)) < 0)
				return(-1);
		}
		/* Reply from proxy: fetch the line end if not yet seen */
		replylen = strlen(buf);
		if (replylen == 0 || buf[replylen-1] != '\n') {
			if ((ret = get_nl()) < 0) return(ret);
		}
		if ((ret = auth_done(buf)) <= 0) return(ret);
		if (get_fw_response(buf, (int)sizeof(buf)) < 0)
			return(-1);
		if ((ret = auth_done(buf)) <= 0) return(ret);
	}
	return(0); /* not reached */
}

/*
 * Classify a response line from the firewall proxy by its leading text.
 *
 * Returns  0 if the login was accepted,
 *         -1 if the login failed and no further attempt is possible,
 *          1 for anything else (failed, but the user may try again).
 */
static int
auth_done(buf)
char *buf;
{
	static const char accepted[] = "Login Accepted";
	static const char *const fatal[] = {
		"Cannot connect to authentication server",
		"Too many failed login attempts",
		"Lost connection to authentication server"
	};
	size_t k;

	if (strncmp(buf, accepted, sizeof(accepted) - 1) == 0)
		return(0);

	for (k = 0; k < sizeof(fatal) / sizeof(fatal[0]); k++) {
		if (strncmp(buf, fatal[k], strlen(fatal[k])) == 0)
			return(-1);
	}

	/* error, but try again */
	return(1);
}


/*
 * Fetch one response from the firewall into buf and echo it to the user
 * on standard output.
 *
 * buf:    destination buffer; cleared before reading.
 * bufsiz: number of bytes of buf to clear.
 * Returns 0 on success, -1 if the read fails or the echo is incomplete.
 */
static int
get_fw_response(buf, bufsiz)
char *buf;
int bufsiz;
{
	int rc;

	(void)memset(buf, 0, bufsiz);

	/* read from firewall */
	rc = readfw(buf);
	if (rc < 0)
		return(-1);

	/* write to user */
	rc = write(STDOUT_FILENO, buf, strlen(buf));
	return (rc != strlen(buf)) ? -1 : 0;
}


/*
 * Get the newline from firewall.  Print to user.
 */
static int
get_nl()
{
	int n;
	char tmp[2];

	/* read from firewall */
	n = readfw(tmp);
	if (n < 0) return(-1);
	if (tmp[1] != '\n') return(-1);
	/* write to user */
	n = write(STDOUT_FILENO, tmp, sizeof(tmp));
	if (n != sizeof(tmp)) return(-1);
	return(0);
}

/*
 * Read one response from the STREAM socket connected to the firewall,
 * a byte at a time, because a read() may input fewer bytes than
 * requested.  (Not needed with DATAGRAM sockets.)
 *
 * Reading stops after a "\r\n" pair or at a NUL byte.  Returns the
 * number of bytes stored before the stop point (the "\r\n" is counted,
 * the NUL is not), or -1 if a read does not deliver exactly one byte.
 *
 * NOTE(review): no buffer length is passed in, so the caller's buffer
 * must be large enough for whatever the firewall sends.
 */
static int
readfw(buf)
char *buf;
{
	char *cur = buf;	/* where the next byte is stored */

	for (;;) {
		if (read(sock, cur, 1) != 1)
			goto fail;

		if (*cur == '\0')
			return((int)(cur - buf));

		if (*cur == '\r') {
			/* look for the '\n' that completes the line end */
			cur++;
			if (read(sock, cur, 1) != 1)
				goto fail;
			if (*cur == '\n')
				return((int)(cur - buf) + 1);
		}

		cur++;
	}

fail:
	perror("read error");
	return(-1);
}

/*
 * Write siz bytes from buf to the STREAM socket connected to the
 * firewall, repeating the write() as needed because a single call may
 * output fewer bytes than requested.  (Not needed with DATAGRAM sockets.)
 *
 * Returns the number of bytes written, or the negative value returned
 * by the failing write().
 */
static int
writefw(buf, siz)
char *buf;
unsigned int siz;
{
	int remaining;		/* bytes still to be written */
	int rc;			/* result of the latest write() */

	for (remaining = siz; remaining > 0; remaining -= rc, buf += rc) {
		rc = write(sock, buf, remaining);
		if (rc < 0)
			return(rc);	/* error */
	}
	return(siz - remaining);
}


/*
 * SIGCHLD handler, installed by main(): runs when the child process
 * (which is running the MBone application) terminates.  Collects the
 * child's status and tells the firewall to drop the multicast group.
 */
static void
sig_child()
{
	int child_status;

	/* wait() failing is only reported; the drop is sent regardless */
	if (wait(&child_status) < 0)
		perror("error");

	mcast_disable(addr); /* don't care about errors */
}
