#include <stdio.h>
#include "ptrie.h"

/*
 * Bit-addressing constants.  Keys are inspected one 8-bit byte at a time
 * and bits inside a byte are numbered from the most significant end, so
 * bit 0 is the top bit of the byte and bit 7 is the bottom bit.
 *
 *   HIGH_BIT   - position (counted from the least significant end) of the
 *                top bit of a byte.
 *   BIT_MASK   - octal 07; "n & BIT_MASK" gives the bit number within a
 *                byte for the absolute bit index n.
 *   WORD_SIZE  - number of bits consumed per key byte.
 *   WORD_SHIFT - "n >> WORD_SHIFT" gives the byte offset for the absolute
 *                bit index n.
 */
#define HIGH_BIT   7
#define BIT_MASK   07
#define WORD_SIZE  8
#define WORD_SHIFT 3

/*
 * Non-zero when bit b of d is set, with b counted from the most
 * significant bit of a byte (b == 0 tests 0x80).  b must lie in the range
 * 0..HIGH_BIT; anything larger would produce a negative shift count.
 */
#define is_bit_set(d, b) ((d) & (1 << (HIGH_BIT - (b))))

/*
 * One node of the trie.  The header node returned by ptrie_create() uses
 * the same layout, with an empty key and NULL data.
 */
typedef struct _ptrie_node {
    unsigned char *key;     /* heap-allocated, NUL-terminated copy of the key */
    void *data;             /* caller-supplied payload; NULL on the header node */
    int branch_bit;         /* absolute index of the key bit tested at this node */
    int visited;            /* cleared by ptrie_create(); not otherwise used in this file */
    /*
     * Links followed when the tested bit is clear (left) or set (right).
     * A link whose target has a branch_bit less than or equal to this
     * node's is treated as an upward link that ends a search.
     */
    struct _ptrie_node *left, *right;
} ptrie_node, *ptrie_ptr;

/*
 * Create an empty trie.
 *
 * The trie is represented by a header node with an empty key, no data and
 * branch bit 0, whose two links both point back at itself.
 *
 * Returns the new trie handle, or NULL when memory cannot be allocated.
 * The handle must eventually be released with ptrie_free().
 */
ptrie
#ifdef __STDC__
ptrie_create(void)
#else /* !__STDC__ */
ptrie_create()
#endif /* !__STDC__ */
{
    ptrie_ptr self = (ptrie_ptr)malloc(sizeof(ptrie_node));

    if (self == NULL)
      return((ptrie)NULL);

    /*
     * The header carries a real (empty) key string so that searches ending
     * here can compare against it like any other node.
     */
    self->key = (unsigned char *)malloc(sizeof(unsigned char));
    if (self->key == NULL) {
        /* Do not leak the node when the key allocation fails. */
        free((char *)self);
        return((ptrie)NULL);
    }
    self->key[0] = '\0';

    self->data = NULL;
    self->branch_bit = 0;
    self->visited = 0;
    self->left = self->right = self;

    return((ptrie)self);
}

/*
 * Release a trie (or subtree) and everything it owns.
 *
 * root      - handle from ptrie_create(), or an interior node when called
 *             recursively.  NULL is accepted and ignored.
 * free_func - called once with the data pointer of every node whose data
 *             is not NULL.  May be NULL, in which case the data pointers
 *             are left untouched.
 *
 * Children are released before their parent, so upward links always refer
 * to nodes that are still alive while they are being inspected.
 */
void
#ifdef __STDC__
ptrie_free(ptrie root, void (*free_func)())
#else /* !__STDC__ */
ptrie_free(root, free_func)
ptrie root;
void (*free_func)();
#endif /* !__STDC__ */
{
    ptrie_ptr next = (ptrie_ptr)root;
    int descend_left, descend_right;

    if (next == NULL)
      return;

    /*
     * A link is a real child only when it leads to a larger branch bit;
     * otherwise it points back up the trie (or at this node) and must not
     * be followed.  Both links are classified before anything is freed.
     */
    descend_left = (next->left->branch_bit > next->branch_bit);
    descend_right = (next->right->branch_bit > next->branch_bit);

    if (descend_right)
      ptrie_free((ptrie)(next->right), free_func);
    if (descend_left)
      ptrie_free((ptrie)(next->left), free_func);

    if (next->data != NULL && free_func != NULL)
      (*free_func)(next->data);

    /* The key buffer is owned by the node and must go with it. */
    free((char *)next->key);
    free((char *)next);
}

/*
 * Allocate and fill in a detached node.
 *
 * key        - key bytes to copy; at most key_len bytes are taken.
 * data       - payload pointer stored as-is (not copied).
 * key_len    - number of key bytes to store, excluding the terminator.
 * branch_bit - absolute index of the bit this node will test.
 *
 * The node's links are left NULL for the caller to set.  Returns NULL when
 * memory cannot be allocated.
 */
static ptrie_ptr
#ifdef __STDC__
new_ptrie_node(unsigned char *key, void *data, int key_len, int branch_bit)
#else /* !__STDC__ */
new_ptrie_node(key, data, key_len, branch_bit)
unsigned char *key;
void *data;
int key_len, branch_bit;
#endif /* !__STDC__ */
{
    ptrie_ptr nn = (ptrie_ptr)malloc(sizeof(ptrie_node));

    if (nn == NULL)
      return(NULL);

    nn->key = (unsigned char *)malloc(sizeof(unsigned char) * (key_len + 1));
    if (nn->key == NULL) {
        /* Do not leak the node when the key allocation fails. */
        free((char *)nn);
        return(NULL);
    }

    /* strncpy() does not terminate on truncation, so do it explicitly. */
    (void)strncpy((char *)nn->key, (char *)key, key_len);
    nn->key[key_len] = '\0';
    nn->data = data;

    nn->branch_bit = branch_bit;
    nn->visited = 0;
    nn->left = nn->right = NULL;

    return(nn);
}

/*
 * Insert a key/data pair into the trie.
 *
 * root - handle from ptrie_create().
 * key  - NUL-terminated key; a private copy is stored in the new node.
 * data - payload pointer stored as-is.
 *
 * If the key is already present the trie is left unchanged and the
 * existing data is kept.  A NULL root or key, or a failed allocation,
 * also leaves the trie unchanged.
 *
 * Bits past the end of the key are treated as zero, so a short key is
 * never read beyond its terminator while descending.
 *
 * NOTE(review): the header node and ordinary nodes can both end up with
 * branch_bit 0, so keys whose first byte is >= 0x80 appear not to be
 * handled correctly -- confirm before storing non-ASCII keys.
 */
void
#ifdef __STDC__
ptrie_insert(ptrie root, unsigned char *key, void *data)
#else /* !__STDC__ */
ptrie_insert(root, key, data)
ptrie root;
unsigned char *key;
void *data;
#endif /* !__STDC__ */
{
    ptrie_ptr parent, new_node, curr;
    unsigned char *key_pos, *old_pos;
    unsigned char ch;
    int key_len, byte, bit, bit_diff;

    if (root == NULL || key == NULL)
      return;

    key_len = (int)strlen((char *)key);

    /*
     * First descent: follow the key's bits until a link leads upward.
     * The node reached that way is the closest existing candidate.
     */
    curr = ((ptrie_ptr)root)->left;
    do {
        parent = curr;
        byte = curr->branch_bit >> WORD_SHIFT;
        bit = curr->branch_bit & BIT_MASK;
        ch = (byte < key_len) ? key[byte] : 0;
        curr = (is_bit_set(ch, bit)) ? curr->right : curr->left;
    } while (parent->branch_bit < curr->branch_bit);

    /* Key already stored: nothing to do. */
    if (strcmp((char *)key, (char *)curr->key) == 0)
      return;

    /*
     * Find the byte where they differ.  The strings are known to be
     * different, so this stops at or before the shorter one's terminator.
     */
    key_pos = key;
    old_pos = curr->key;
    bit_diff = 0;
    while (*old_pos == *key_pos) {
        key_pos++;
        old_pos++;
        bit_diff += WORD_SIZE;
    }

    /*
     * Find the bit where they differ, counting from the top of the byte.
     */
    for (bit = *old_pos ^ *key_pos; !(bit & (1 << HIGH_BIT)); bit <<= 1)
      bit_diff++;

    /*
     * If the first descent went past the differing bit, walk down again
     * from the header and stop at the first node that tests a bit at or
     * beyond bit_diff; the new node is spliced in above it.
     */
    if (parent->branch_bit > bit_diff) {
        curr = (ptrie_ptr)root;
        do {
            parent = curr;
            byte = curr->branch_bit >> WORD_SHIFT;
            bit = curr->branch_bit & BIT_MASK;
            ch = (byte < key_len) ? key[byte] : 0;
            curr = (is_bit_set(ch, bit)) ? curr->right : curr->left;
        } while (curr->branch_bit < bit_diff);
    }

    new_node = new_ptrie_node(key, data, key_len, bit_diff);
    if (new_node == NULL)
      return;

    /*
     * On the side selected by the new key's own bit the node points back
     * at itself; the other side takes over the displaced link.
     */
    if (is_bit_set(*key_pos, (bit_diff & BIT_MASK))) {
        new_node->left = curr;
        new_node->right = new_node;
    } else {
        new_node->left = new_node;
        new_node->right = curr;
    }

    /* Hook the new node under parent on the side the key selects there. */
    byte = parent->branch_bit >> WORD_SHIFT;
    ch = (byte < key_len) ? key[byte] : 0;

    if (is_bit_set(ch, (parent->branch_bit & BIT_MASK)))
      parent->right = new_node;
    else
      parent->left = new_node;
}

/*
 * Look up a key.
 *
 * root - handle from ptrie_create().
 * key  - NUL-terminated key to search for.
 *
 * Returns the data pointer stored with the key, or NULL when the key is
 * absent (or when root or key is NULL).  A key that was inserted with NULL
 * data is therefore indistinguishable from a missing key.
 *
 * Bits past the end of the key are treated as zero, so a key shorter than
 * the ones already stored is never read beyond its terminator.
 */
void *
#ifdef __STDC__
ptrie_find(ptrie root, unsigned char *key)
#else /* !__STDC__ */
ptrie_find(root, key)
ptrie root;
unsigned char *key;
#endif /* !__STDC__ */
{
    ptrie_ptr parent, curr;
    unsigned char ch;
    int key_len, byte, bit;

    if (root == NULL || key == NULL)
      return(NULL);

    key_len = (int)strlen((char *)key);
    curr = ((ptrie_ptr)root)->left;

    /*
     * Follow the key's bits downward; the walk ends as soon as a link
     * leads to a node whose branch bit is not larger than its parent's.
     */
    do {
        parent = curr;
        byte = curr->branch_bit >> WORD_SHIFT;
        bit = curr->branch_bit & BIT_MASK;
        ch = (byte < key_len) ? key[byte] : 0;
        curr = (is_bit_set(ch, bit)) ? curr->right : curr->left;
    } while (parent->branch_bit < curr->branch_bit);

    /* Only the tested bits are known to match; confirm the whole key. */
    if (strcmp((char *)key, (char *)curr->key) != 0)
      return(NULL);
    return(curr->data);
}

/*
 * Visit every node reachable through downward links from root.
 *
 * For each node whose data pointer is not NULL, action_func is called as
 * action_func(user_data, data).  A node's left subtree is walked first,
 * then its right subtree, and the node itself is reported last.
 */
void
#ifdef __STDC__
ptrie_walk(ptrie root, void *user_data, void (*action_func)(void *, void *))
#else /* !__STDC__ */
ptrie_walk(root, user_data, action_func)
ptrie root;
void *user_data;
void (*action_func)();
#endif /* !__STDC__ */
{
    ptrie_ptr node = (ptrie_ptr)root;
    ptrie_ptr lhs = node->left;
    ptrie_ptr rhs = node->right;

    /*
     * Only links that lead to a strictly larger branch bit are children;
     * anything else points back up the trie and is skipped.
     */
    if (lhs->branch_bit > node->branch_bit)
      ptrie_walk((ptrie)lhs, user_data, action_func);

    if (rhs->branch_bit > node->branch_bit)
      ptrie_walk((ptrie)rhs, user_data, action_func);

    if (node->data == NULL)
      return;

    (*action_func)(user_data, node->data);
}
