/*
 * This program is an implementation of the ISAKMP Internet Standard.
 * Copyright (C) 1997 Angelos D. Keromytis.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 * 
 * This code was written while the author was in Greece, in May/June
 * 1997.
 *
 * You may contact the author by:
 *   e-mail: angelos@dsl.cis.upenn.edu
 *  US-mail: Angelos D. Keromytis
 *           Distributed Systems Lab
 *           Computer and Information Science Department
 *           University of Pennsylvania
 *           Moore Building
 *           200 South 33rd Street
 *           Philadelphia, PA 19104	   
 */

#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include "constants.h"
#include "state.h"
#include "packet.h"
#include "defs.h"
#include "md5.h"
#include "sha1.h"

/*
 * This file does basic header checking and demux of
 * incoming packets.
 */

/* Printable form of an address; NULL is treated as an unknown family below */
extern char *get_address();

/*
 * Printable form of address x for log messages, falling back to a fixed
 * string when get_address() yields NULL.  get_address() may be evaluated
 * twice, so x must be free of side effects.
 */
#define PRINTADDRESS(x) \
    ((get_address(x) == (char *) NULL) ? \
     "(unknown address family)" : get_address(x))

/*
 * Receive a packet. If we pass buffer to a routine that does not return
 * failure indication (e.g. the packet handling routines), it's up to them
 * to free it; otherwise this routine does.
 */
void
comm_handle(int kernelfd, int sock)
{
    struct isakmp_generic *isag;
    u_char *buffer;
    int length, i, k;
    struct sockaddr_in sin;
    struct isakmp_hdr *isa, tisa;
    struct isakmp_sa *isasa;
    struct state *st;
    u_int32_t ipsecdoisit;
    int encrypted = 0, lastblocksize;
    u_int32_t des_cbc_keys[16][2];
    u_char lastblock[256];		/* Increase if block size larger */
    struct state *nst;
    MD5_CTX md5ctx;
    SHA1_CTX sha1ctx;
    
    length = sizeof(sin);
    bzero(&sin, sizeof(sin));
    sin.sin_family = AF_INET;
    
#ifndef linux
    /* Find how large is this message by peeking at the header */
    i = recvfrom(sock, &tisa, sizeof(tisa), MSG_PEEK, (struct sockaddr *)&sin,
		 &length);
    if (i == -1)
      log(1, "recvfrom() failed in comm_handle()", 0, 0, 0);
    else
      if (i != sizeof(tisa))
      {
	  log(0, "recvfrom() failed in comm_handle(), message too small",
	      0, 0, 0);
	  return;
      }

    tisa.isa_length = ntohl(tisa.isa_length);
    i = tisa.isa_length;

    /* Allocate exactly that much space */
    buffer = (u_char *) calloc(i, sizeof(u_char));
    if (buffer == (u_char *) NULL)
      exit_log("calloc() failed in comm_handle()", 0, 0, 0);

    /* Now really read the message */
    if ((i = recvfrom(sock, buffer, i, 0, (struct sockaddr *)&sin, &length))
	== -1)
    {
	log(1, "recvfrom() failed in comm_handle", 0, 0, 0);
	free(buffer);
	return;
    }

    if (i != tisa.isa_length)
    {
	log(0, "comm_handle(): mismatch between real and reported size (%d/%d)",
	    i, tisa.isa_length, 0);
	free(buffer);
	return;
    }
#else
    i = UDP_SIZE;

    /* Now really read the message */
    if ((i = recvfrom(sock, nullbuffer, i, 0, 
		      (struct sockaddr *)&sin, &length)) == -1)
    {
	log(1, "recvfrom() failed in comm_handle", 0, 0, 0);
	return;
    }

    if (i < sizeof(struct isakmp_hdr))
    {
	log(0, "too short packet received", 0, 0, 0);
	return;
    }

    bcopy(nullbuffer, &tisa, sizeof(struct isakmp_hdr));
    
    i = ntohl(tisa.isa_length);
    
    /* Allocate exactly that much space */
    buffer = (u_char *) calloc(i, sizeof(u_char));
    if (buffer == (u_char *) NULL)
      exit_log("calloc() failed in comm_handle()", 0, 0, 0);

    bcopy(nullbuffer, buffer, i);
#endif

#ifdef DEBUG
    log(0, "read %d bytes from %s, port %d", i,
	PRINTADDRESS(sin), get_port(sin));
#endif

    isa = (struct isakmp_hdr *) buffer;
    if (isa->isa_maj != ISAKMP_MAJOR_VERSION)
    {
	free(buffer);
	log(0, "invalid major version number %d from %s, port %d", 
	    isa->isa_maj, PRINTADDRESS(sin), get_port(sin));
	/* XXX Could send a notification back */
	return;
    }
    
    if (isa->isa_min != ISAKMP_MINOR_VERSION)
    {
	free(buffer);
	log(0, "invalid minor version number %d from %s, port %d", 
	    isa->isa_min, PRINTADDRESS(sin), get_port(sin));
	/* XXX Could send a notification back */
	return;
    }

    switch (isa->isa_xchg)
    {
	case ISAKMP_XCHG_IDPROT:
	case ISAKMP_XCHG_QUICK:
	case ISAKMP_XCHG_INFO:
/* XXX	case ISAKMP_XCHG_AGGR: */
	    break;
	default:
	    free(buffer);
	    log(0, "unsupport exchange type %d from %s, port %d",
		isa->isa_xchg, PRINTADDRESS(sin), get_port(sin));
	    /* XXX Could send a notification back */
	    return;
    }
    
    if ((isa->isa_xchg == ISAKMP_XCHG_IDPROT) &&
	(isa->isa_msgid != 0))
    {
	free(buffer);
	log(0, "message id should be zero (was %08x) in base mode",
	    isa->isa_msgid, 0, 0);
	/* XXX Could send notification back */
	return;
    }
    
    st = (struct state *) find_full_state(isa->isa_icookie, 
					  isa->isa_rcookie,
					  sin, isa->isa_msgid);
#ifdef DEBUG
    if (st)
      log(0, "full state object found", 0, 0, 0);
#endif

    if (isa->isa_flags & ISAKMP_FLAG_ENCRYPTION) 
    {
	if (st == (struct state *) NULL)
	{
#ifdef DEBUG
	    log(0, "received encrypted packet from %s, port %d, for which no state can be found", PRINTADDRESS(sin), get_port(sin), 0);
#endif

	    /* Find the Phase 1 object */
	    st = (struct state *) find_full_state(isa->isa_icookie,
						  isa->isa_rcookie, sin, 0);
	    if (st == (struct state *) NULL)
	    {
		log(0, "attempt to use non-existant (expired?) ISAKMP SA with %s, port %d", PRINTADDRESS(sin), get_port(sin));
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    if ((st->st_state != OAKLEY_MAIN_I_4) &&
		(st->st_state != OAKLEY_MAIN_R_3))
	    {
		log(0, "attempt to use non-fully established ISAKMP SA with %s, port %d", PRINTADDRESS(sin), get_port(sin));
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    nst = (struct state *) duplicate_state(st);
	    nst->st_msgid = isa->isa_msgid;
	    nst->st_state = OAKLEY_QUICK_R_1;
	    insert_state(nst);
	    
	    /* Quick Mode Initial IV */
	    if (nst->st_prf == 0)
	    {
		switch (nst->st_hash)
		{
		    case OAKLEY_MD5:
			nst->st_iv_len = 16;
			nst->st_iv = (u_char *) calloc(nst->st_iv_len,
						       sizeof(u_char));
			if (nst->st_iv == (u_char *) NULL)
			  exit_log("calloc() failed in comm_handle()", 
				   0, 0, 0);
			
			MD5Init(&md5ctx);
			MD5Update(&md5ctx, st->st_lastblock,
				  st->st_lastblock_len);
			MD5Update(&md5ctx, (u_char *) &(isa->isa_msgid),
				  sizeof(isa->isa_msgid));
			MD5Final(nst->st_iv, &md5ctx);
			break;
			
		    case OAKLEY_SHA:
			nst->st_iv_len = 20;
			nst->st_iv = (u_char *) calloc(nst->st_iv_len,
						       sizeof(u_char));
			if (nst->st_iv == (u_char *) NULL)
			  exit_log("calloc() failed in comm_handle()", 
				   0, 0, 0);
			
			SHA1Init(&sha1ctx);
			SHA1Update(&sha1ctx, st->st_lastblock,
				   st->st_lastblock_len);
			SHA1Update(&sha1ctx, (u_char *) &(isa->isa_msgid),
				   sizeof(isa->isa_msgid));
			SHA1Final(nst->st_iv, &sha1ctx);
			break;

		    default:
			exit_log("unknown/unsupport hash algorithm %d specified in comm_handle()", nst->st_hash, 0, 0);
		}
	    }
	    else
	    {
		/* XXX Handle 3DES MAC */
	    }

#ifdef DEBUG
	    log(0, "computed phase 2 IV: ");
	    for (k = 0; k < nst->st_iv_len; k++)
	      fprintf(stderr, "%02x ", nst->st_iv[k]);
	    fprintf(stderr, "\n");
#endif

	    st = nst;
	}
	
#ifdef DEBUG
	log(0, "received encrypted packet from %s, port %d",
	    PRINTADDRESS(sin), get_port(sin), 0);
#endif
	if ((st->st_skeyid_e == (u_char *) NULL) ||
	    (st->st_iv == (u_char *) NULL))
	{
	    log(0, "unexpected encrypted packet received from %s, port %d",
		PRINTADDRESS(sin), get_port(sin), 0);
	    free(buffer);
	    /* XXX Could send notification back */
	    return;
	}
	  
	/* Mark as encrypted */
	encrypted = 1;

#ifdef DEBUG
	log(0, "decrypting %d bytes using algorithm %d",
	    i - sizeof(struct isakmp_hdr), st->st_enc, 0);
#endif
	
	switch (st->st_enc)
	{
	    case OAKLEY_DES_CBC:
		/* XXX Detect weak keys */

		/* Copy last block in case we need it */
		lastblocksize = DES_CBC_BLOCK_SIZE;
		bcopy(buffer + i - lastblocksize, lastblock, lastblocksize);

#ifdef DEBUG
		log(0, "keeping last %d bytes, just in case", lastblocksize,
		    0, 0);
#endif

#ifdef DEBUG
		log(0, "new IV: ", 0, 0, 0);
		for (k = 0; k < lastblocksize; k++)
		  fprintf(stderr, "%02x ", lastblock[k]);
		fprintf(stderr, "\n");
#endif

		/* Decrypt */
		des_set_key(st->st_skeyid_e, des_cbc_keys);
		des_cbc_encrypt(buffer + sizeof(struct isakmp_hdr),
				buffer + sizeof(struct isakmp_hdr),
				i - sizeof(struct isakmp_hdr),
				des_cbc_keys, st->st_iv, 0);

		/* If not enough space, free it */
		if ((st->st_iv != (u_char *) NULL) &&
		    (st->st_iv_len < lastblocksize))
		{
		    free(st->st_iv);
		    st->st_iv = (u_char *) NULL;
		}

		/* Allocate enough space */
		if (st->st_iv == (u_char *) NULL)
		{
		    st->st_iv = (u_char *) calloc(lastblocksize,
						  sizeof(u_char));
		    if (st->st_iv == (u_char *) NULL)
		      exit_log("calloc() failed in comm_handle()", 0, 0, 0);
		}
		  
		/* Update the IV */
		bcopy(lastblock, st->st_iv, lastblocksize);
		st->st_iv_len = lastblocksize;
		break;
		  
		/* XXX Handle more */
	    default:
		exit_log("unknown encryption algorithm %d specified",
			 st->st_enc, 0, 0);
	}
	  
	/* Adjust for padding */
	k = sizeof(struct isakmp_hdr);
	isag = (struct isakmp_generic *) (buffer + k);
	while (isag->isag_np != ISAKMP_NEXT_NONE)
	{
	    k += ntohs(isag->isag_length);
	    if (k >= i)
	    {
		log(0, "malformed packet from %s, port %d", 
		    PRINTADDRESS(sin), get_port(sin), 0);
		free(st->st_iv);
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }
	    
	    isag = (struct isakmp_generic *) (buffer + k);
	}

	k += ntohs(isag->isag_length);
	if (k > i)
	{
	    log(0, "malformed packet from %s, port %d", 
		PRINTADDRESS(sin), get_port(sin), 0);
	    free(st->st_iv);
	    free(buffer);
	    /* XXX Could send notification back */
	    return;
	}
	
#ifdef DEBUG
	log(0, "removed %d bytes of padding", i - k, 0, 0);
#endif

	i = k;
	isa->isa_length = htonl(i);
    }

    if (st == (struct state *) NULL)
#ifdef DEBUG
    {
#endif
      st = (struct state *) find_half_state(isa->isa_icookie, sin);
#ifdef DEBUG
      if (st)
	log(0, "half state object found", 0, 0, 0);
      else
	log(0, "state object not found", 0, 0, 0);
    }
#endif
    if (st == (struct state *) NULL)	/* Begining of exchange */
    {
	switch (isa->isa_xchg)
	{
	    case ISAKMP_XCHG_IDPROT:
		if (isa->isa_np != ISAKMP_NEXT_SA)
		{
		    log(0, "invalid payload %d from %s, port %d",
			isa->isa_np, PRINTADDRESS(sin), get_port(sin));
		    /* XXX Could send notification back */
		    free(buffer);
		    return;
		}

		isasa = (struct isakmp_sa *) (buffer + 
					      sizeof(struct isakmp_hdr));

		switch (ntohl(isasa->isasa_doi))
		{
		    case ISAKMP_DOI_IPSEC:
			/* Check the situation */
			bcopy(buffer + sizeof(struct isakmp_hdr) +
			      sizeof(struct isakmp_sa), &ipsecdoisit,
			      IPSEC_DOI_SITUATION_LENGTH);
			ipsecdoisit = ntohl(ipsecdoisit);
			if (ipsecdoisit != SIT_IDENTITY_ONLY)
			{
			    free(buffer);
			    log (0, "unsupported IPsec DOI situation (%d) received from %s, port %d",
				 ipsecdoisit, PRINTADDRESS(sin), 
				 get_port(sin));
			    /* XXX Could send notification back */
			    return;
			}
			
			ipsecdoi_handle_rfirst(sock, buffer, i, sin);
			break;

		    default:
			log(0, "unknown/unsupported DOI %d from %s, port %d",
			    ntohl(isasa->isasa_doi), PRINTADDRESS(sin),
			    get_port(sin));
			/* XXX Could send notification back */
			free(buffer);
			return;
		}
		
		return;
		
	    default:
		log(0, "out of order packet from %s, port %d",
		    PRINTADDRESS(sin), get_port(sin), 0);
		/* XXX Could send notification back */
		free(buffer);
		return;
	}
    }

    /* Handle Informational Exchanges -- Delete/Notifications */
    if (isa->isa_xchg == ISAKMP_XCHG_INFO)
    {
	/* XXX Handle deletions */

	/* XXX Handle error messages */

#ifdef DEBUG
	log(0, "informational message from %s, port %d",
	    PRINTADDRESS(sin), get_port(sin), 0);
#endif
    }

#ifdef DEBUG
    log(0, "exchange state %d", st->st_state, 0, 0);
#endif

    /* XXX Handle Commit Bit set and unset it */

    switch (st->st_state)
    {
	case OAKLEY_MAIN_I_1:
	    if (isa->isa_np != ISAKMP_NEXT_SA)
	    {
#ifdef DEBUG
		log(0, "dropping out of sequence packet from %s, port %d",
		    PRINTADDRESS(sin), get_port(sin), 0);
#endif
		free(buffer);
		return;
	    }
	    
	    ipsecdoi_handle_i1(sock, buffer, i, sin, st);
	    return;
	    
	case OAKLEY_MAIN_I_2:
	    if (isa->isa_np != ISAKMP_NEXT_KE)
	    {
#ifdef DEBUG
		log(0, "dropping out of sequence packet from %s, port %d",
		    PRINTADDRESS(sin), get_port(sin), 0);
#endif
		free(buffer);
		return;
	    }

	    ipsecdoi_handle_i2(sock, buffer, i, sin, st);
	    return;
	    
	case OAKLEY_MAIN_I_3:
	    if (encrypted == 0)
	    {
		log(0, "packet from %s, port %d should have been encrypted",
		    PRINTADDRESS(sin), get_port(sin), 0);
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    if (isa->isa_np != ISAKMP_NEXT_ID)
	    {
#ifdef DEBUG
		log(0, "dropping out of sequence packet from %s, port %d",
		    PRINTADDRESS(sin), get_port(sin), 0);
#endif
		free(buffer);
		return;
	    }
	    
	    /* Keep last block */
	    st->st_lastblock = (u_char *) calloc(lastblocksize, 
						 sizeof(u_char));
	    if (st->st_lastblock == (u_char *) NULL)
	      exit_log("calloc() failed in comm_handle()", 0, 0, 0);
	    
	    st->st_lastblock_len = lastblocksize;
	    bcopy(lastblock, st->st_lastblock, lastblocksize);

#ifdef DEBUG
	    log(0, "last encrypted Phase 1 block: ", 0, 0, 0);
	    for (k = 0; k < st->st_lastblock_len; k++)
	      fprintf(stderr, "%02x ", st->st_lastblock[k]);
	    fprintf(stderr, "\n");
#endif

	    ipsecdoi_handle_i3(sock, buffer, i, sin, st);
	    return;
	  
	case OAKLEY_MAIN_R_1:
	    if (isa->isa_np != ISAKMP_NEXT_KE)
	    {
		/* Retransmit */
		if (isa->isa_np == ISAKMP_NEXT_SA)
		{
		    if (sendto(sock, st->st_packet, st->st_packet_len, 0,
			       &(st->st_peer), sizeof(st->st_peer)) !=
			st->st_packet_len)
		      log(1, "sendto() failed in comm_handle() for %s, port %d"
			  , PRINTADDRESS(st->st_peer), get_port(st->st_peer), 
			  0);
#ifdef DEBUG
		    else
		      log(0, "retransmitted %d bytes", 
			  st->st_packet_len, 0, 0);
#endif
		}
#ifdef DEBUG
		else
		  log(0, "dropping out of sequence packet from %s, port %d",
		      PRINTADDRESS(sin), get_port(sin), 0);
#endif
		free(buffer);
		return;
	    }
	    
	    ipsecdoi_handle_r1(sock, buffer, i, sin, st);
	    return;
	    
	case OAKLEY_MAIN_R_2:
	    if (encrypted == 0)
	    {
		log(0, "packet from %s, port %d should have been encrypted",
		    PRINTADDRESS(sin), get_port(sin), 0);
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    if (isa->isa_np != ISAKMP_NEXT_ID)
	    {
		/* Retransmit */
		if (isa->isa_np == ISAKMP_NEXT_KE)
		{
		    if (sendto(sock, st->st_packet, st->st_packet_len, 0,
			       &(st->st_peer), sizeof(st->st_peer)) !=
			st->st_packet_len)
		      log(1, "sendto() failed in comm_handle() for %s, port %d"
			  , PRINTADDRESS(st->st_peer), 
			  get_port(st->st_peer), 0);
#ifdef DEBUG
		    else
		      log(0, "retransmitted %d bytes", 
			  st->st_packet_len, 0, 0);
#endif
		}
#ifdef DEBUG
		else
		  log(0, "dropping out of sequence packet from %s, port %d",
		      PRINTADDRESS(sin), get_port(sin), 0);
#endif		
		free(buffer);
		return;
	    }
	    
	    ipsecdoi_handle_r2(sock, buffer, i, sin, st);
	    return;

	case OAKLEY_QUICK_I_1:
	    if (encrypted == 0)
	    {
		log(0, "packet from %s, port %d should have been encrypted",
		    PRINTADDRESS(sin), get_port(sin), 0);
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    if (isa->isa_np != ISAKMP_NEXT_HASH)
	    {
#ifdef DEBUG
		log(0, "dropping out of sequence packet from %s, port %d",
		    PRINTADDRESS(sin), get_port(sin), 0);
#endif
		free(buffer);
		return;
	    }

	    ipsecdoi_handle_quick_i1(sock, buffer, i, sin, st, kernelfd);
	    return;
	    
	case OAKLEY_QUICK_R_1:
	    if (encrypted == 0)
	    {
		log(0, "packet from %s, port %d should have been encrypted",
		    PRINTADDRESS(sin), get_port(sin), 0);
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    if (isa->isa_np != ISAKMP_NEXT_HASH)
	    {
#ifdef DEBUG
		log(0, "dropping out of sequence packet from %s, port %d",
		    PRINTADDRESS(sin), get_port(sin), 0);
#endif		
		free(buffer);
		return;
	    }

#ifdef DEBUG
	    log(0, "Packet dump: ");
	    for (k = 0; k < i; k++)
	    {
		fprintf(stderr, "%02x ", buffer[k]);
		if ((k + 1) % 16 == 0)
		  fprintf(stderr, "\n");
	    }

	    if (k % 16)
	      fprintf(stderr, "\n");
#endif

	    ipsecdoi_handle_quick_r1(sock, buffer, i, sin, st);
	    return;
	    
	case OAKLEY_QUICK_R_2:
	    if (encrypted == 0)
	    {
		log(0, "packet from %s, port %d should have been encrypted",
		    PRINTADDRESS(sin), get_port(sin), 0);
		free(buffer);
		/* XXX Could send notification back */
		return;
	    }

	    if (isa->isa_np != ISAKMP_NEXT_HASH)
	    {
		isag = (struct isakmp_generic *) (buffer +
						  sizeof(struct isakmp_hdr));
		if (isag->isag_np == ISAKMP_NEXT_SA)
		{
		    if (sendto(sock, st->st_packet, st->st_packet_len, 0,
			       &(st->st_peer), sizeof(st->st_peer)) !=
			st->st_packet_len)
		      log(1, "sendto() failed in comm_handle() for %s, port %d"
			  , PRINTADDRESS(st->st_peer), 
			  get_port(st->st_peer), 0);
#ifdef DEBUG
		    else
		      log(0, "retransmitted %d bytes", 
			  st->st_packet_len, 0, 0);
#endif
		}
#ifdef DEBUG
		else
		  log(0, "dropping out of sequence packet from %s, port %d",
		      PRINTADDRESS(sin), get_port(sin), 0);
#endif		
		free(buffer);
		return;
	    }
	    
	    ipsecdoi_handle_quick_r2(sock, buffer, i, sin, st, kernelfd);
	    return;
    }

    log(0, "unexpected packet received from %s, port %d", PRINTADDRESS(sin),
	get_port(sin), 0);
    free(buffer);
    return;
}
