/*
 * This program is an implementation of the ISAKMP Internet Standard.
 * Copyright (C) 1997 Angelos D. Keromytis.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 * 
 * This code was written while the author was in Greece, in May/June
 * 1997.
 *
 * You may contact the author by:
 *   e-mail: angelos@dsl.cis.upenn.edu
 *  US-mail: Angelos D. Keromytis
 *           Distributed Systems Lab
 *           Computer and Information Science Department
 *           University of Pennsylvania
 *           Moore Building
 *           200 South 33rd Street
 *           Philadelphia, PA 19104	   
 */

#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <fcntl.h>
#include "constants.h"
#include "state.h"
#include "defs.h"

/*
 * This file contains the functions that manage the
 * state hash table.
 */

/* Number of buckets in the state hash table. */
#define STATE_TABLE_SIZE 32

/*
 * Bucket heads of the state hash table.  Each bucket is a doubly linked
 * list threaded through st_next/st_prev (see insert_state()).
 */
static struct state *statetable[STATE_TABLE_SIZE];

/*
 * Initialize the global group modulus and generator (presumably the
 * Oakley Diffie-Hellman group -- TODO confirm).  The modulus is parsed
 * from the hexadecimal string GROUPDEFAULT and the generator from the
 * decimal string DEFAULTGENERATOR; a parse failure is reported through
 * exit_log().
 */
void
init_nums(void)
{
    int failed;

    failed = mpz_init_set_str(&groupmodulo, GROUPDEFAULT, 16);
    if (failed)
      exit_log("mpz_init_set_str() failed in init_nums()", 0, 0, 0);

    failed = mpz_init_set_str(&groupgenerator, DEFAULTGENERATOR, 10);
    if (failed)
      exit_log("mpz_init_set_str() failed in init_nums()", 0, 0, 0);
}

/*
 * Initialize global variables.  This had to live somewhere, so it lives
 * in this file.
 *
 * - If no port has been set (our_port == 0), default it to PORT.
 * - Fill secret_of_the_day with SECRET_VALUE_LENGTH random bytes (the
 *   secret used for responder cookies) and schedule EVENT_REINIT_SECRET
 *   after EVENT_REINIT_SECRET_DELAY so that it gets refreshed.
 */
void
init_vars(void)
{
    if (our_port == 0)
      our_port = PORT;

    /* 
     * Generate the secret value for responder cookies, and
     * schedule an event for refresh.
     */
    get_rnd_bytes(secret_of_the_day, SECRET_VALUE_LENGTH);
    event_schedule(EVENT_REINIT_SECRET, EVENT_REINIT_SECRET_DELAY, NULL, 0);
}

/*
 * Compute the state-table bucket for an (icookie, rcookie, peer) triple.
 *
 * The hash is the byte sum of both cookies, the address family, and the
 * first SA_HASH_BYTES bytes of sa.sa_data, reduced modulo
 * STATE_TABLE_SIZE.  SA_HASH_BYTES matches the number of sa_data bytes
 * the lookup functions compare when matching a peer, so every compared
 * byte contributes to the bucket choice.
 *
 * sa_data is declared as plain char, whose signedness is
 * implementation-defined.  Its bytes are read as u_char so that
 * high-bit bytes neither sign-extend into the sum nor print as
 * "ffffffxx" in the debug dump.
 *
 * Returns a value in [0, STATE_TABLE_SIZE).
 */
static u_int
state_hash(u_char *icookie, u_char *rcookie, struct sockaddr sa)
{
    enum { SA_HASH_BYTES = 6 };
    u_int i, j, bucket;
    
#ifdef DEBUG2
    log(0, "ICOOKIE: ", 0, 0, 0);
    for (i = 0; i < COOKIE_SIZE; i++)
      fprintf(stderr, "%02x ", icookie[i]);
    fprintf(stderr, "\n");
    log(0, "RCOOKIE: ", 0, 0, 0);
    for (i = 0; i < COOKIE_SIZE; i++)
      fprintf(stderr, "%02x ", rcookie[i]);
    fprintf(stderr, "\n");
    log(0, "sockaddr: ", 0, 0, 0);
    for (i = 0; i < SA_HASH_BYTES; i++)
      fprintf(stderr, "%02x ", (u_char) sa.sa_data[i]);
    fprintf(stderr, "\n");
#endif

    for (i = 0, j = 0; j < COOKIE_SIZE; j++)
      i += icookie[j] + rcookie[j];

    i += sa.sa_family;

    /* Read address bytes as unsigned so negative chars don't sign-extend. */
    for (j = 0; j < SA_HASH_BYTES; j++)
      i += (u_char) sa.sa_data[j];

    bucket = i % STATE_TABLE_SIZE;

#ifdef DEBUG
    /* Same argument count as every other log() call in this file. */
    log(0, "state hash entry %d", (int) bucket, 0, 0);
#endif

    return bucket;
}

/*
 * Allocate a new, zero-filled state object.  A failed allocation is
 * reported through exit_log() (presumably fatal -- TODO confirm that
 * exit_log() does not return).
 *
 * Returns the new object; the caller owns it and releases it with
 * free_state().
 */
struct state *
get_state(void)
{
    struct state *fresh = (struct state *) calloc(1, sizeof *fresh);

    if (fresh == NULL)
      exit_log("calloc() failed in get_state()", 0, 0, 0);

    return fresh;
}

/*
 * Reset the state hash table so that every bucket is empty.
 */
void
init_state(void)
{
    int bucket = STATE_TABLE_SIZE;

    while (bucket-- > 0)
      statetable[bucket] = (struct state *) NULL;
}

/*
 * Push a state object onto the front of the hash bucket selected by its
 * cookies and peer address.
 *
 * delete_state() recomputes the bucket from st_icookie, st_rcookie and
 * st_peer.  Those fields therefore should not change while the object is
 * linked in.  NOTE(review): callers presumably delete and re-insert when
 * the rcookie is learned -- confirm.
 */
void
insert_state(struct state *st)
{
    struct state **head;

    head = &statetable[state_hash(st->st_icookie, st->st_rcookie, st->st_peer)];

    st->st_prev = (struct state *) NULL;
    st->st_next = *head;
    if (*head != (struct state *) NULL)
      (*head)->st_prev = st;
    *head = st;
}

/*
 * Unlink a state object from its hash bucket without freeing it.
 *
 * When the object heads its bucket (st_prev == NULL), the bucket is
 * located by re-hashing st_icookie, st_rcookie and st_peer.  Those fields
 * must therefore match what they were when insert_state() was called.
 * The object's own st_next/st_prev fields are left untouched.
 */
void
delete_state(struct state *st)
{
    struct state *before = st->st_prev;
    struct state *after = st->st_next;

    if (before == (struct state *) NULL)
    {
	/* st is the bucket head: the bucket now starts at its successor. */
	statetable[state_hash(st->st_icookie, st->st_rcookie, st->st_peer)] = after;
    }
    else
      before->st_next = after;

    if (after != (struct state *) NULL)
      after->st_prev = before;
}

/*
 * Release a state object together with everything it owns:
 *
 * - the saved packet;
 * - any GMP integers whose st_*_in_use flag is set;
 * - the heap buffers for the proposal, SA, nonces, SKEYID material,
 *   identities, IV/last block, SPI and keying material;
 * - finally, the object itself.
 *
 * The hash table is not touched.  Presumably the caller has already
 * unlinked the object with delete_state().
 *
 * NOTE(review): the SKEYID and KEYMAT buffers are freed without being
 * wiped first; consider clearing them before release.
 */
void
free_state(struct state *st)
{
    /*
     * The packet buffer is only released when a non-zero length was
     * recorded.  NOTE(review): confirm who owns st_packet when
     * st_packet_len == 0.
     */
    if ((st->st_packet_len != 0) &&
	(st->st_packet != (u_char *) NULL))
      free(st->st_packet);

    /* GMP integers are cleared only if they were initialized. */
    if (st->st_gi_in_use)
      mpz_clear(&(st->st_gi));
    if (st->st_gr_in_use)
      mpz_clear(&(st->st_gr));
    if (st->st_sec_in_use)
      mpz_clear(&(st->st_sec));
    if (st->st_shared_in_use)
      mpz_clear(&(st->st_shared));

    /* free(NULL) is a no-op, so unset buffers need no guard. */
    free(st->st_proposal);
    free(st->st_sa);
    free(st->st_ni);
    free(st->st_nr);
    free(st->st_skeyid);
    free(st->st_skeyid_d);
    free(st->st_skeyid_a);
    free(st->st_skeyid_e);
    free(st->st_myidentity);
    free(st->st_peeridentity);
    free(st->st_iv);
    free(st->st_lastblock);
    free(st->st_spi);
    free(st->st_keymat);

    free(st);
}

/*
 * Duplicate a state object, for Phase 2.
 */
struct state *
duplicate_state(struct state *st)
{
    struct state *nst;
    
    nst = get_state();

    nst->st_phase1 = st;

    bcopy(st->st_icookie, nst->st_icookie, COOKIE_SIZE);
    bcopy(st->st_rcookie, nst->st_rcookie, COOKIE_SIZE);
    nst->st_peer = st->st_peer;

    nst->st_doi = st->st_doi;
    nst->st_situation = st->st_situation;

    /* Copy SKEYID_D */
    nst->st_skeyid_d = (u_char *) calloc(st->st_skeyid_d_len, sizeof(u_char));
    if (nst->st_skeyid_d == (u_char *) NULL)
      exit_log("calloc() failed in duplicate_state()", 0, 0, 0);
    
    bcopy(st->st_skeyid_d, nst->st_skeyid_d, st->st_skeyid_d_len);
    nst->st_skeyid_d_len = st->st_skeyid_d_len;
    
    /* Copy SKEYID_E */
    nst->st_skeyid_e = (u_char *) calloc(st->st_skeyid_e_len, sizeof(u_char));
    if (nst->st_skeyid_e == (u_char *) NULL)
      exit_log("calloc() failed in duplicate_state()", 0, 0, 0);
    
    bcopy(st->st_skeyid_e, nst->st_skeyid_e, st->st_skeyid_e_len);
    nst->st_skeyid_e_len = st->st_skeyid_e_len;
    
    /* Copy SKEYID_A */
    nst->st_skeyid_a = (u_char *) calloc(st->st_skeyid_a_len, sizeof(u_char));
    if (nst->st_skeyid_a == (u_char *) NULL)
      exit_log("calloc() failed in duplicate_state()", 0, 0, 0);
    
    bcopy(st->st_skeyid_a, nst->st_skeyid_a, st->st_skeyid_a_len);
    nst->st_skeyid_a_len = st->st_skeyid_a_len;
    
    /* Copy goal */
    nst->st_goal = st->st_goal;

    nst->st_hash = st->st_hash;
    nst->st_enc = st->st_enc;
    nst->st_prf = st->st_prf;
    nst->st_auth = st->st_auth;
    
    return nst;
}

/*
 * Look up the state object that matches all of the following:
 * - both cookies;
 * - the peer address (same sa_family and the same first 6 bytes of
 *   sa_data);
 * - the message ID.
 *
 * Only the bucket selected by state_hash() is searched.
 *
 * Returns the first match in that bucket, or NULL if there is none.
 */
struct state *
find_full_state(u_char *icookie, u_char *rcookie, struct sockaddr sa,
		u_int32_t msgid)
{
    enum { PEER_CMP_BYTES = 6 };	/* sa_data bytes compared per peer */
    struct state *st;
    
    st = statetable[state_hash(icookie, rcookie, sa)];

#ifdef DEBUG
    log(0, "find_full_state() hash %d pointer %p", 
	state_hash(icookie, rcookie, sa), st, 0);
#endif

    for (; st != (struct state *) NULL; st = st->st_next)
    {
	if (st->st_msgid != msgid)
	  continue;
	if (st->st_peer.sa_family != sa.sa_family)
	  continue;
	if (bcmp(sa.sa_data, st->st_peer.sa_data, PEER_CMP_BYTES) != 0)
	  continue;
	if (bcmp(icookie, st->st_icookie, COOKIE_SIZE) != 0)
	  continue;
	if (bcmp(rcookie, st->st_rcookie, COOKIE_SIZE) != 0)
	  continue;
	break;			/* every field matched */
    }

    return st;
}

/*
 * Find an established ISAKMP SA with the given peer.
 *
 * The whole table is scanned, bucket 0 first.  The first object is
 * returned that satisfies all of the following:
 * - its state is OAKLEY_MAIN_I_4 or OAKLEY_MAIN_R_3;
 * - its message ID is 0;
 * - its protocol is PROTO_ISAKMP;
 * - its peer has the same sa_family and first 6 bytes of sa_data.
 *
 * Returns NULL if no object qualifies.
 */
struct state *
find_phase1_state(struct sockaddr sa)
{
    enum { PEER_CMP_BYTES = 6 };	/* sa_data bytes compared per peer */
    struct state *st;
    int bucket;

    for (bucket = 0; bucket < STATE_TABLE_SIZE; bucket++)
    {
	for (st = statetable[bucket]; st != (struct state *) NULL; st = st->st_next)
	{
	    if ((st->st_state != OAKLEY_MAIN_I_4) &&
		(st->st_state != OAKLEY_MAIN_R_3))
	      continue;
	    if (st->st_msgid != 0)		/* not an ISAKMP SA */
	      continue;
	    if (st->st_protoid != PROTO_ISAKMP)
	      continue;
	    if (st->st_peer.sa_family != sa.sa_family)
	      continue;
	    if (bcmp(sa.sa_data, st->st_peer.sa_data, PEER_CMP_BYTES) != 0)
	      continue;
	    return st;				/* host we want */
	}
    }

    return (struct state *) NULL;
}

/*
 * Find a state object without knowing the rcookie, i.e. one inserted
 * before any message from the responder arrived.
 *
 * The bucket is chosen by hashing with an all-zero rcookie.  Within that
 * bucket, the first object whose icookie matches and whose peer has the
 * same sa_family and first 6 bytes of sa_data is returned.
 *
 * NOTE(review): neither st_rcookie nor st_msgid is checked.  An object
 * with a non-zero rcookie that happens to share the bucket, icookie and
 * peer would also match -- confirm this is intended.
 *
 * Returns the match, or NULL if there is none.
 */
struct state *
find_half_state(u_char *icookie, struct sockaddr sa)
{
    enum { PEER_CMP_BYTES = 6 };	/* sa_data bytes compared per peer */
    u_char zero_rcookie[COOKIE_SIZE];
    struct state *st;

    bzero(zero_rcookie, COOKIE_SIZE);

    for (st = statetable[state_hash(icookie, zero_rcookie, sa)];
	 st != (struct state *) NULL;
	 st = st->st_next)
    {
	if ((st->st_peer.sa_family == sa.sa_family) &&
	    (bcmp(sa.sa_data, st->st_peer.sa_data, PEER_CMP_BYTES) == 0) &&
	    (bcmp(icookie, st->st_icookie, COOKIE_SIZE) == 0))
	  return st;
    }

    return (struct state *) NULL;
}
