/* rec_node.c -- functions to store a database of recognition sequences

   Copyright (C) 2000 Carl Worth

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "stroke.h"
#include "rec_node.h"
#include "xmalloc.h"
#include "rec-interface.h"

/* How many grid cells are adjacent to each cell.
 *
 * i.e. max_branches[1] = 3 means that cell 1 is adjacent to 3 other
 * cells.  Index 0 is for a special node at the root of the tree, which
 * may branch to any of the 9 cells.
 *
 * new_rec_node() uses max_branches[digit] as the number of entries in
 * the next[] array it allocates for a node of that digit. */
int max_branches[] = {
  9, 3, 5, 3, 5, 8, 5, 3, 5, 3
};

/* A mapping used to traverse a tree of rec_node objects. Negative
 * numbers represent invalid paths, (ie non-adjacent cells). For
 * example, branch_mapping[1][5] = 2 means that if you arrived at a
 * node representing cell 1 and you want to follow the branch leading
 * to cell 5 you should use index 2 in this node's next array.
 * The 0 cell is a special node at the head of the tree so is never
 * accessible from any cell. */
#define INVALID_BRANCH -1
int branch_map[10][10] = {
  {INVALID_BRANCH,  0,  1,  2,  3,  4,  5,  6,  7,  8},
  {INVALID_BRANCH, INVALID_BRANCH,  0, INVALID_BRANCH,  1,  2, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH},
  {INVALID_BRANCH,  0, INVALID_BRANCH,  1,  2,  3,  4, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH},
  {INVALID_BRANCH, INVALID_BRANCH,  0, INVALID_BRANCH, INVALID_BRANCH,  1,  2, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH},
  {INVALID_BRANCH,  0,  1, INVALID_BRANCH, INVALID_BRANCH,  2, INVALID_BRANCH,  3,  4, INVALID_BRANCH},
  {INVALID_BRANCH,  0,  1,  2,  3, INVALID_BRANCH,  4,  5,  6,  7},
  {INVALID_BRANCH, INVALID_BRANCH,  0,  1, INVALID_BRANCH,  2, INVALID_BRANCH, INVALID_BRANCH,  3,  4},
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH,  0,  1, INVALID_BRANCH, INVALID_BRANCH,  2, INVALID_BRANCH},
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH,  0,  1,  2,  3, INVALID_BRANCH,  4},
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH,  0,  1, INVALID_BRANCH,  2, INVALID_BRANCH},
};

/* This table is used to find a valid intermediate cell which
 * connects two non-adjacent cells. Ambiguities favor the center cell
 * (5). Cells which are already adjacent are listed as INVALID_BRANCH
 * here, as are a cell paired with itself and every pairing that
 * involves the root cell 0. For example, to find the cell which
 * connects non-adjacent cells 1 and 6, simply look and see that
 * cell_connecting[1][6] == 5.
 *
 * The first index is the cell you are on, the second is the cell you
 * want to reach; the stored value is a cell number, not a next[]
 * index. */
int cell_connecting[10][10] = {
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH}, /* from 0 (root) */
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, 2, INVALID_BRANCH, INVALID_BRANCH, 5, 4, 5, 5}, /* from 1 */
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, 5, 5, 5}, /* from 2 */
  {INVALID_BRANCH, 2, INVALID_BRANCH, INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH, 5, 5, 6}, /* from 3 */
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH, 5}, /* from 4 */
  {INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH}, /* from 5 */
  {INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH}, /* from 6 */
  {INVALID_BRANCH, 4, 5, 5, INVALID_BRANCH, INVALID_BRANCH, 5, INVALID_BRANCH, INVALID_BRANCH, 8}, /* from 7 */
  {INVALID_BRANCH, 5, 5, 5, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH, INVALID_BRANCH}, /* from 8 */
  {INVALID_BRANCH, 5, 5, 6, 5, INVALID_BRANCH, INVALID_BRANCH, 8, INVALID_BRANCH, INVALID_BRANCH} /* from 9 */
};

/* Private functions
 *
 * Forward declarations for helpers defined further down in this file.
 * Each helper is declared exactly once.
 *
 * NOTE(review): free_rec_node is listed with the private helpers but is
 * declared without `static`, so it has external linkage -- confirm
 * whether other files rely on that before changing it. */
static inline int get_branch_index(struct rec_node *node, int node_digit, int branch);
static inline int find_cell_connecting(int from, int to);
static struct rec_node *new_rec_node(int digit);
void free_rec_node(struct rec_node *node, int digit);
static int add_intermediate_sequence(struct rec_node *node, int node_digit, char *seq_so_far, char *seq_to_add, action_list *data);
static void *lookup_intermediate_sequence(struct rec_node *node, int node_digit, char *seq);
static int inverse_branch_map(int node_digit, int branch_index);
static void print_rec_tree_with_prefix(struct rec_node *root, int root_digit, char **prefix_p, int *length_p);

/* Bookkeeping reported by print_node_stats(). */

/* Number of live nodes per digit: incremented in new_rec_node() and
 * decremented in free_rec_node(). */
static int node_stats[10];
/* Number of nodes that currently hold a data mapping. */
static int known_strokes;
/* Running total from which the average stroke length is computed:
 * add_intermediate_sequence() adds the sequence length for every new
 * mapping, and free_rec_node() adjusts it when a mapped node is freed. */
static int known_strokes_length;

/* File-local replacement for the standard assert(): if the condition is
 * false, print the condition text together with the file name and line
 * number to stderr and exit with status 1.  The definition is
 * unconditional, so it stays active regardless of NDEBUG.
 *
 * `x...` is the GNU named-variadic-macro syntax, so the whole argument
 * list (commas included) is treated as the condition. */
#define assert(x...) do { if (!(x)) { \
fprintf(stderr, "assertion (" #x ") failed at %s:%d\n", __FILE__, __LINE__); \
exit(1); \
} } while (0)

/* Translate a move from cell node_digit to cell branch into an index
 * into node->next[].
 *
 * node        the node being left (only inspected by the optional
 *             DEBUG_REMEMBER_NODE_DIGITS consistency check)
 * node_digit  the cell that node represents
 * branch      the cell to move to
 *
 * Returns the index from branch_map, or INVALID_BRANCH when the move is
 * not allowed.  An error is printed to stderr in the invalid case.
 * Values outside 0..9 are treated as invalid moves rather than being
 * used as table indices, because branch is derived by callers from raw
 * sequence characters. */
static inline int get_branch_index(struct rec_node *node, int node_digit, int branch) {
  int branch_index;

#ifdef DEBUG_REMEMBER_NODE_DIGITS
  assert(node_digit == node->digit);
#endif

  /* branch_map is 10x10; never index it with anything else. */
  if (node_digit < 0 || node_digit > 9 || branch < 0 || branch > 9) {
    fprintf(stderr, "%s:%d: Error: invalid branch from %d to %d\n", __FILE__, __LINE__, node_digit, branch);
    return INVALID_BRANCH;
  }

  branch_index = branch_map[node_digit][branch];

#ifdef DEBUG_VERIFY_BRANCH_BOUNDS
  assert(branch_index < max_branches[node_digit]);
#endif

  if (branch_index == INVALID_BRANCH) {
    fprintf(stderr, "%s:%d: Error: invalid branch from %d to %d\n", __FILE__, __LINE__, node_digit, branch);
  }

  return branch_index;
}

static inline int find_cell_connecting(int from, int to) {
  return cell_connecting[from][to];
}

/* Add the mapping seq_to_add => data to the tree rooted at root.
 *
 * The work is done by add_intermediate_sequence(), started at the root
 * (pseudo-cell 0) with an empty "sequence so far" scratch buffer.
 * Returns whatever add_intermediate_sequence() returns.
 *
 * Nothing is freed here: data is not copied, the tree keeps the pointer
 * that was passed in. */
int add_sequence(struct rec_node *root, char *seq_to_add, action_list *data) {
  char seq_so_far[MAX_SEQUENCE+1];
  int status;

  *seq_so_far = '\0';
  status = add_intermediate_sequence(root, 0, seq_so_far, seq_to_add, data);

  return status;
}

/* Insert every sequence described by seq_to_add below node and map each
 * of them to data.
 *
 * We accept two different char * sequences as follows:
 *
 * seq_so_far is a pointer to the sequence as it has been constructed
 * so far. It will consist only of digits [0-9], no '?'s. seq_so_far
 * can be used in error/warning messages to show what sequences caused
 * problems. It is extended in place while descending and restored to
 * its original length before returning, so the caller's buffer must
 * have room for MAX_SEQUENCE characters plus the terminator.
 *
 * seq_to_add is a pointer to the sequence that is yet to be
 * added. Besides digits it may contain:
 *   X?     the preceding digit X is optional: the remainder is added
 *          both without X (by recursion) and with X
 *   [abc]  a class: the remainder is added once for each member
 *   *      ignored
 * A digit equal to the current cell is skipped as a repeat.
 *
 * node / node_digit identify the tree position that seq_so_far leads
 * to.  data is stored by pointer, not copied.
 *
 * Returns 0 on success and INVALID_SEQUENCE if any expansion failed
 * (unmatched '[', a character that is not part of the syntax above, a
 * move between non-adjacent cells, or exceeding MAX_SEQUENCE).  An
 * existing mapping is never overwritten; a warning is printed if the
 * labels differ. */
static int add_intermediate_sequence(struct rec_node *node, int node_digit, char *seq_so_far, char *seq_to_add, action_list *data) {
  int i;
  int return_val = 0;
  int branch, branch_index;
  struct rec_node *next_node = node;
  int orig_seq_len = strlen(seq_so_far);

  /* NOTE(review): the DEBUG_TREE_CREATION messages below print data
   * through a (char *) cast although the rest of this function reads
   * the text from data->label -- confirm which is intended. */
#ifdef DEBUG_TREE_CREATION
  fprintf(stderr,"%s:%d: entering add_intermediate_sequence: \"%s\" + \"%s\" => %s\n",__FILE__,__LINE__,seq_so_far,seq_to_add,(char *) data);
#endif

  assert(node != NULL);

#ifdef DEBUG_REMEMBER_NODE_DIGITS
  assert(node_digit == node->digit);
#endif

  for (i=0; seq_to_add[i] != '\0'; i++) {
    /* Recurse on '[]' delimited classes of digits */
    if (seq_to_add[i] == '[') {
      char *end = strchr(seq_to_add + i, ']');

      /* The closing bracket must be verified before end is used for
       * anything, including sizing the scratch buffer below. */
      if (! end) {
        fprintf(stderr, "Error: '[' encountered with no matching ']'\n");
        return_val = INVALID_SEQUENCE;
        goto END_OF_SEQUENCE;
      }

      {
        /* Layout: one slot for the substituted class member, then
         * everything that followed the ']', then the terminator. */
        char rest_of_seq[strlen(end)+2];
        char *p;
        int err;

        rest_of_seq[0] = '\0';
        rest_of_seq[1] = '\0';
        strcpy(rest_of_seq + 1, end + 1);
        for (p = seq_to_add + i + 1; p != end; p++) {
          rest_of_seq[0] = p[0];
          err = add_intermediate_sequence(node, node_digit, seq_so_far, rest_of_seq, data);
          if (err)
            return_val = INVALID_SEQUENCE;
        }
      }
      /* The recursive calls handled the remainder of the sequence. */
      goto END_OF_SEQUENCE;
    }

    /* Lookahead once to recurse on '?' */
    if (seq_to_add[i+1] == '?') {
      int err = add_intermediate_sequence(node, node_digit, seq_so_far, seq_to_add+i+2, data);
      if (err)
        return_val = INVALID_SEQUENCE;
    }

    /* Then ignore the '?' when we iterate to it */
    /* Also, ignore any * */
    if (seq_to_add[i] == '?' || seq_to_add[i] == '*')
      continue;

    /* Anything else must be a digit: the value is used to index the
     * 10x10 branch table, so reject other characters outright. */
    if (seq_to_add[i] < '0' || seq_to_add[i] > '9') {
      fprintf(stderr, "%s:%d: Error: invalid character '%c' in sequence %s + %s\n", __FILE__, __LINE__, seq_to_add[i], seq_so_far, seq_to_add + i);
      return_val = INVALID_SEQUENCE;
      goto END_OF_SEQUENCE;
    }

    branch = seq_to_add[i] - '0';

    /* Skip this digit if it repeats the current digit */
    if (branch == node_digit)
      continue;

    if (strlen(seq_so_far) >= MAX_SEQUENCE) {
      fprintf(stderr,"%s:%d: creating sequence %s + %s would exceed MAX_SEQUENCE length (%d)\n", __FILE__, __LINE__, seq_so_far, seq_to_add + i, MAX_SEQUENCE);
      return_val = INVALID_SEQUENCE;
      goto END_OF_SEQUENCE;
    }

    /* Add new digit to seq_so_far */
    strncat(seq_so_far, seq_to_add + i, 1);

    branch_index = get_branch_index(node, node_digit, branch);
    if (branch_index == INVALID_BRANCH) {
#ifdef DEBUG_TREE_CREATION
      fprintf(stderr,"%s:%d: add_intermediate_sequence_intermediate_sequence: invalid branch occurred while mapping \"%s\" + \"%s\" => %s\n",__FILE__,__LINE__,seq_so_far,seq_to_add,(char *) data);
#endif
      return_val = INVALID_SEQUENCE;
      goto END_OF_SEQUENCE;
    }

    /* Descend, creating the child on first use. */
    next_node = node->next[branch_index];
    if (next_node == NULL) {
      next_node = new_rec_node(branch);
      node->next[branch_index] = next_node;
    }

    node = next_node;
    node_digit = branch;
  }

  /* The whole sequence was consumed: next_node is where it ends. */
  if (next_node) {
    if (next_node->data) {
      if (strcmp(next_node->data->label, data->label))
        fprintf(stderr, "Warning: Mapping exists %s => %s. Not changing to %s => %s\n", seq_so_far, next_node->data->label, seq_so_far, data->label);
    } else {
#ifdef DEBUG_TREE_CREATION
      fprintf(stderr, "Creating mapping %s => %s\n", seq_so_far, data->label);
#endif
      next_node->data = data;
      known_strokes++;
      known_strokes_length += strlen(seq_so_far);
    }
  }

 END_OF_SEQUENCE:
  /* Chop seq_so_far back down to its original size */
  seq_so_far[orig_seq_len] = '\0';

#ifdef DEBUG_TREE_CREATION
    fprintf(stderr,"%s:%d:  exiting add_intermediate_sequence: \"%s\" + \"%s\" => %s\n",__FILE__,__LINE__,seq_so_far,seq_to_add,(char *) data);
#endif

  return return_val;
}

/* Look up seq in the tree rooted at root.
 *
 * The search starts at the root, which stands for pseudo-cell 0.
 * Returns whatever lookup_intermediate_sequence() finds: the data
 * pointer stored for the sequence, or NULL. */
void *lookup_sequence(struct rec_node *root, char *seq) {
  void *found = lookup_intermediate_sequence(root, 0, seq);

  return found;
}

/* Walk the tree from node along the digits of seq and return the data
 * stored at the node where the walk ends.
 *
 * node / node_digit  starting position and the cell it represents
 * seq                string of digits; a digit equal to the current
 *                    cell is skipped as a repeat
 *
 * When seq asks for a move between non-adjacent cells, the connecting
 * cell from cell_connecting is inserted first and the original digit is
 * then retried from there.
 *
 * Returns the node's data pointer, or NULL when the path does not exist
 * in the tree, when no connecting cell is available, or when seq
 * contains a character that is not a digit. */
static void *lookup_intermediate_sequence(struct rec_node *node, int node_digit, char *seq) {
  int i;
  int branch, branch_index;

  for (i=0; seq[i] != '\0'; i++) {
    /* Ran off the tree on the previous step: no such sequence. */
    if (node == NULL)
      return NULL;

    /* The value indexes the 10x10 cell tables, so it must be a digit. */
    if (seq[i] < '0' || seq[i] > '9') {
      fprintf(stderr, "%s:%d: Error: invalid character '%c' in sequence %s\n", __FILE__, __LINE__, seq[i], seq);
      return NULL;
    }

    branch = seq[i] - '0';
    /* Skip this digit if it's a repeat */
    if (branch == node_digit)
      continue;

    branch_index = get_branch_index(node, node_digit, branch);
    if (branch_index == INVALID_BRANCH) {
      /* Work around the invalid branch bug in libstroke */
      branch = find_cell_connecting(node_digit, branch);
      if (branch == INVALID_BRANCH) {
        /* No intermediate cell exists for this pair; using the marker
         * as a cell number would index outside the tables. */
        fprintf(stderr, "%s:%d: Error: no cell connects %d to %d. Giving up.\n", __FILE__, __LINE__, node_digit, seq[i] - '0');
        return NULL;
      }
      fprintf(stderr, "%s:%d: Working around invalid branch from libstroke, replacing %d->%d with %d->%d->%d\n",__FILE__,__LINE__,node_digit,seq[i] - '0',node_digit,branch,seq[i] - '0');

      branch_index = get_branch_index(node, node_digit, branch);
      if (branch_index == INVALID_BRANCH) {
        fprintf(stderr, "%s:%d: Error: Failed to validate invalid branch. New branch %d->%d is still invalid. Giving up.\n", __FILE__, __LINE__, node_digit, branch);
        return NULL;
      }
      /* This workaround inserted a new digit, so we want to repeat the current one. */
      i--;
    }

    node = node->next[branch_index];
    node_digit = branch;
  }

  return node ? node->data : NULL;
}

/* Create an empty tree and return its root node.
 * The root is a node for digit 0, the pseudo-cell reserved for the head
 * of the tree. */
struct rec_node *new_rec_tree() {
  struct rec_node *root = new_rec_node(0);

  return root;
}

/* Release a tree created by new_rec_tree().
 * Hands root to free_rec_node() as a node of digit 0. */
void free_rec_tree(struct rec_node *root) {
  const int root_digit = 0;

  free_rec_node(root, root_digit);
}

/* Allocate a node representing cell `digit`.
 *
 * The node starts with no data and with a zero-filled next[] array of
 * max_branches[digit] entries.  The per-digit node counter is bumped.
 * Allocation goes through xmalloc/xcalloc. */
static struct rec_node *new_rec_node(int digit) {
  struct rec_node *fresh = xmalloc(sizeof *fresh);
  int slots = max_branches[digit];

#ifdef DEBUG_NODE_CREATION
  fprintf(stderr, "%s:%d: new_rec_node: Creating node of digit %d\n", __FILE__, __LINE__, digit);
#endif
#ifdef DEBUG_REMEMBER_NODE_DIGITS
  fresh->digit = digit;
#endif
  fresh->data = NULL;
  fresh->next = xcalloc(slots, sizeof(struct rec_node *));

  node_stats[digit] += 1;

  return fresh;
}

/* Recursive worker for free_rec_node().
 *
 * depth is the number of branches followed from the node that
 * free_rec_node() was called on down to this node.  When that starting
 * node is the tree root, depth equals the length of the sequence that
 * leads here, which is the amount add_intermediate_sequence() added to
 * known_strokes_length when it created the mapping. */
static void free_rec_subtree(struct rec_node *node, int digit, int depth) {
  int i;

  if (node == NULL)
    return;

  /* Children first; the branch array is checked before it is walked. */
  if (node->next) {
    for (i=0; i<max_branches[digit]; i++) {
      if (node->next[i])
        free_rec_subtree(node->next[i], inverse_branch_map(digit, i), depth + 1);
    }
    free(node->next);
  }

  if (node->data) {
    known_strokes--;
    /* Undo exactly what was added when the mapping was created: the
     * sequence length, not the length of the label. */
    known_strokes_length -= depth;
    /* TODO: I can't actually free data here because I probably have
       multiple pointers to it. I'm planning on putting data into its
       own object that will store type information, (keycode, modifier,
       string, action, etc.). While I'm at it and I can add a usage
       count as well -- increment when creating mappings, decrement
       here. If the decrement reduces the count to 0, *then* we can do
       the free. For now, let it leak... */
  }

  free(node);
  node_stats[digit]--;
}

/* Free node, which represents cell `digit`, together with all of its
 * descendants, and update the node and stroke statistics.
 *
 * A NULL node is accepted and ignored.  The data attached to nodes is
 * deliberately not freed (see the TODO in the worker).
 *
 * NOTE(review): the stroke-length statistic is exact when node is the
 * tree root, which is how free_rec_tree() calls this.  For a call on an
 * inner node the distance from the real root is not known here. */
void free_rec_node(struct rec_node *node, int digit) {
  free_rec_subtree(node, digit, 0);
}

/* Reverse lookup in branch_map: given a node's digit and an index into
 * its next[] array, return the cell that index leads to.
 *
 * Scans row node_digit from cell 0 upwards and returns the first cell
 * whose entry equals branch_index; returns INVALID_BRANCH if no entry
 * matches. */
static int inverse_branch_map(int node_digit, int branch_index) {
  const int *row = branch_map[node_digit];
  int cell = 0;

  while (cell < 10) {
    if (row[cell] == branch_index)
      return cell;
    cell++;
  }

  return INVALID_BRANCH;
}

/* Append `ending` to the heap string *string, growing it if needed.
 *
 * string  in/out: pointer to the buffer; updated if xrealloc moves it
 * length  in/out: current capacity of the buffer in bytes
 * ending  text to append
 *
 * The buffer is only ever grown to exactly the size required (both
 * strings plus the terminator). */
static void append_to_string(char **string, int *length, char *ending) {
  char *buf = *string;
  int needed = strlen(buf) + strlen(ending) + 1;

  if (*length < needed) {
    buf = xrealloc(buf, needed);
    *string = buf;
    *length = needed;
  }

  strcat(buf, ending);
}

/* Print the whole tree rooted at root to stdout; a NULL root prints
 * nothing.
 *
 * A heap-allocated prefix string is threaded through the recursive
 * printer.  Its capacity is kept in a function-static variable, so a
 * capacity reached in one call is the initial allocation size of the
 * next call. */
void print_rec_tree(struct rec_node *root) {
  static int prefix_string_length = 80;
  char *prefix;

  if (!root)
    return;

  prefix = xmalloc(prefix_string_length);
  *prefix = '\0';
  print_rec_tree_with_prefix(root, 0, &prefix, &prefix_string_length);
  free(prefix);
}

/* Print the subtree rooted at root, one node per line.
 *
 * root        node to print; must not be NULL
 * root_digit  the cell that root represents
 * prefix_p    in/out: heap string printed in front of this node's
 *             digit; may be reallocated while descending
 * length_p    in/out: capacity of *prefix_p
 *
 * Each line is the current prefix, the node's digit and, if the node
 * carries a mapping, the mapping's label in double quotes.  For the
 * children the prefix is extended by one character: '+' when this node
 * has two or more children, ' ' otherwise.  That character is removed
 * again before returning, so the caller sees its prefix unchanged. */
static void print_rec_tree_with_prefix(struct rec_node *root, int root_digit, char **prefix_p, int *length_p) {
  int i;
  int branches;

  /* root is guaranteed non-null due to checks before all calls */
#ifdef DEBUG_REMEMBER_NODE_DIGITS
  assert(root->digit == root_digit);
#endif

  branches = 0;
  for (i=0; i<max_branches[root_digit]; i++) {
    if (root->next[i])
      branches++;
  }

  /* This node is printed with the prefix as it was on entry. */
  printf("%s", *prefix_p);

  if (branches < 2) {
    append_to_string(prefix_p, length_p, " ");
  } else {
    append_to_string(prefix_p, length_p, "+");
  }
  printf("%d", root_digit);
  /* data is an action_list; its printable text lives in ->label. */
  if (root->data)
    printf(" \"%s\"", root->data->label);
  printf("\n");
  for (i=0; i<max_branches[root_digit]; i++) {
    if (root->next[i]) {
      print_rec_tree_with_prefix(root->next[i], inverse_branch_map(root_digit, i), prefix_p, length_p);
    }
  }

  /* Drop the character appended above. */
  (*prefix_p)[strlen(*prefix_p) - 1] = '\0';
}

/* Print a summary of the tree statistics to stdout: for every digit the
 * number of live nodes and the branch slots they account for, followed
 * by the totals, the number of known strokes and their average length.
 *
 * The average is reported as 0 when no strokes are known, so the
 * division is never performed with a zero divisor. */
void print_node_stats(void) {
  int i;
  int total_nodes = 0;
  int total_branches = 0;
  float average_length = 0.0f;

  for (i=0; i<10; i++) {
    int cnt, branches;
    cnt = node_stats[i];
    branches = cnt * max_branches[i];
    printf("Nodes of type %d:%5d x %d branches/node = %d branches\n", i, cnt, max_branches[i], branches);
    total_nodes += cnt;
    total_branches += branches;
  }

  if (known_strokes != 0)
    average_length = (float) known_strokes_length / known_strokes;

  printf(  "--------------------------------------------------------------\n");
  printf(  "    Total nodes: %5d     Total branches: %d\n", total_nodes, total_branches);
  printf(  "  Known strokes: %5d     Average stroke length: %f\n", known_strokes, average_length);
}
