/*
  symbol.cc -- implementation of symbol tables
  
  This program has no warranties of any kind. Use at own risk.
  
  Author: Tommi Syrjänen (tommi.syrjanen@hut.fi)
  
  $Id: symbol.cc,v 1.1 1998/08/04 09:19:01 tssyrjan Exp $	 
*/


#include "debug.h"
#include "symbol.h"
#include <stdio.h>
#include <string.h>
#include <ctype.h>
#include "term.h"
#include <stdlib.h>

/// Array primes[] contains a table of suitable primes for hash table
/// sizes, in strictly increasing order. The tables use max_size as an
/// index into this array and grow by stepping to the next entry, so
/// every entry has to be larger than the one before it (a repeated
/// entry would make a rehash shrink the table). The number of entries
/// is kept at 24. If you add more primes, remember to update
/// NUM_PRIMES also.
long int primes[] = { 5, 23, 53, 101, 211, 401, 809, 1601, 3203, 6421,
		      12007, 24001, 48017, 96001, 180001, 300007,
		      600011, 1200007, 2400001, 4800007, 9600037,
		      18000041, 33554393, 67108859 };

/// Creates an empty symbol table whose capacity is the first entry of
/// primes[], starting from index 1, that is not smaller than sz.
///
/// sz is the requested minimum number of slots. A request larger than
/// the biggest entry of primes[] is clamped to that entry instead of
/// scanning past the end of the array.
///
/// NOTE(review): symbols is not assigned here although
/// CreateSymbolArray() asserts that it is null -- confirm that it is
/// initialised elsewhere.
SymbolTable::SymbolTable(long sz)
{
  debug(DBG_SYMBOL, 1, "Creating SymbolTable");

  // max_size is an index into primes[], not a slot count.
  const long last_prime = (long) (sizeof(primes) / sizeof(primes[0])) - 1;
  max_size = 1;
  while (max_size < last_prime && primes[max_size] < sz)
    max_size++;
  size = 0;
  max_arity = -1;

  debug(DBG_SYMBOL, 4, "\tAllocating items");
  items = new SymbolNode*[primes[max_size]];
  if (!items)
    error(SYS_ERR, "malloc error");

  // Every slot starts out empty.
  memset(items, 0, sizeof(SymbolNode*)*primes[max_size]);
}

/// Deletes the node held in every slot and then the slot array
/// itself.
SymbolTable::~SymbolTable()
{
  debug(DBG_SYMBOL, 1, "destroying SymbolTable");
  debug(DBG_SYMBOL, 4, "\tdestroying items");

  // Deleting a null pointer is a no-op, so unused slots need no test.
  const long slot_count = primes[max_size];
  long slot = 0;
  while (slot < slot_count) {
    delete items[slot];
    slot++;
  }

  debug(DBG_SYMBOL, 4, "\tdestroying array");
  delete [] items;
}

unsigned long SymbolTable::Hash(char *key, int sz)
{
  debug(DBG_SYMBOL, 3, "Hashing %s, size %d", key, sz);
  unsigned long val = 0, b = 0;
  SymbolNode **tbl = NULL;
  long  int i;
  char *p = key;

  if (sz)
    tbl = tmp_nodes;
  else
    tbl = items;
  
  while (*p) {
    b <<= 5;
    b += *p++;
    b %= primes[max_size+sz];
  }
  
  for (i = 0; i < primes[max_size+sz]; i++) {
    val = b + i*i;
    val %= primes[max_size+sz];

    debug(DBG_SYMBOL, 4, "Probing (%s): %ld", key, val);
    if (tbl[val] == NULL) {
      debug(DBG_SYMBOL, 4, "Returning: %ld",val);
      return val;
    } else if (!strcmp(key, tbl[val]->symbol)) {
      debug(DBG_SYMBOL, 4, "Returning: %ld",val);
      return val;
    }
  }
  debug(DBG_SYMBOL, 4, "Item not found");
  return primes[max_size+sz]+1;
}

    
/// Inserts a symbol into the table unless it is already there.
///
/// key is the symbol text. With a non-negative arity the symbol is
/// stored under the name "key/arity"; with a negative arity the key is
/// stored as it is.
///
/// Returns the value of the symbol: the existing value if the name was
/// already present, otherwise the newly assigned one (new values are
/// taken from size, which is then incremented).
///
/// A key longer than SYMBOL_MAX_LENGTH is reported through error(); a
/// name that does not fit the internal buffer aborts the program. The
/// name is built in a static buffer, so the function is not reentrant.
long SymbolTable::Insert(char *key, int arity)
{
  long index = 0;
  static char name[BUFSIZ] = { 0 };
  int pos = 0;
  
  SymbolNode *new_node = NULL;
  debug(DBG_SYMBOL, 2, "Inserting symbol (%s, %d)", key, arity);

  if (strlen(key) > SYMBOL_MAX_LENGTH) 
    error(FATAL_ERR, "symbol name '%s' too long. max name length"
	  " '%d'", key, SYMBOL_MAX_LENGTH); 

  // Build the stored name. The key is always passed as an argument and
  // never as the format itself, so a '%' inside a symbol is harmless,
  // and snprintf keeps the write inside the buffer.
  if (arity < 0)
    pos = snprintf(name, sizeof(name), "%s", key);
  else
    pos = snprintf(name, sizeof(name), "%s/%d", key, arity);

  // snprintf reports the length it needed; anything that does not fit
  // (or an encoding error) is fatal, as before.
  if (pos < 0 || (size_t) pos >= sizeof(name)) {
    fprintf(stderr, "Buffer owerflow -- core dumped\n");
    abort();
  }

  index = Hash(name);
  assert(index < primes[max_size]);

  // Is the item inserted already?
  if (items[index]) {
    return items[index]->value;
  }
  
  if (arity > max_arity)
    max_arity = arity;
      
  new_node =  new SymbolNode;
  if (!new_node)
    error(SYS_ERR, "malloc error");

  new_node->symbol = strdup(name);
  if (!new_node->symbol)
    error(SYS_ERR, "malloc error");
  new_node->value = size++;
  items[index] = new_node;

  // Remember the result now: rehashing moves the node to another slot,
  // and the node is stored under name, not under the bare key.
  const long value = new_node->value;

  // If the table is half full, rehash it
  if ( size > primes[max_size]/2) {
    ReHash();
  }
  return value;
}

/// Looks a symbol up without inserting it.
///
/// key is the symbol text. With a non-negative arity the name
/// "key/arity" is searched for; with a negative arity the key itself.
///
/// Returns the value of the symbol, or -1 if it is not in the table.
/// A name that does not fit the internal buffer aborts the program.
/// The name is built in a static buffer, so the function is not
/// reentrant.
long SymbolTable::Lookup(char *key, int arity)
{
  debug(DBG_SYMBOL, 2, "Looking for symbol (%s)", key);
  long index;
  static char name[BUFSIZ] = { 0 };
  
  if (arity < 0)
    index = Hash(key);
  else {
    // snprintf never writes past the buffer and returns the length it
    // needed, so the overflow test happens before any damage is done.
    const int pos = snprintf(name, sizeof(name), "%s/%d", key, arity);
    if (pos < 0 || (size_t) pos >= sizeof(name)) {
      fprintf(stderr, "Buffer owerflow -- core dumped\n");
      abort();
    }
    index = Hash(name);
  }
  assert( index < primes[max_size]);
  
  if (items[index] != NULL)
    return items[index]->value;
  else
    return -1;
}

/// Returns the name stored in symbols[] for the symbol value key.
///
/// The symbol array must already exist and key must lie in [0, size);
/// both conditions are checked with assert. The returned pointer is
/// the table's own string, not a copy.
///
/// Defined without `inline`: this definition lives in the .cc file
/// only, and an inline function has to be defined in every translation
/// unit that calls it.
char *SymbolTable::LookupByValue(long key)
{
  debug(DBG_SYMBOL, 2, "Looking for value (%ld)", key);

  // check that symbol array is initialized 
  assert(symbols && key >= 0 && key < size); 
  return symbols[key];
}

void SymbolTable::CreateSymbolArray()
{
  debug(DBG_SYMBOL, 2, "Creating symbol array");
  assert (!symbols); //don't try to do this more than once
  long i;
  char *st;
  
  symbols = new char*[size];

  if (!symbols)
    error(SYS_ERR, "malloc error");
  
  debug(DBG_SYMBOL, 4, "\tInitializing symbols");
  for (i = 0; i < primes[max_size]; i++) {
    if (items[i]) {
      debug(DBG_SYMBOL, 5, "\t\t%ld: symbol: %s, val %ld", i,
	    items[i]->symbol, items[i]->value);
      // clear the arity information
      symbols[items[i]->value] = strdup(items[i]->symbol);
      st = symbols[items[i]->value];
      while (*st && *st != '/')
	st++;
      *st = '\0'; // results in 2 bytes of garbage
    }
  }
}

void SymbolTable::ReHash()
{
  long new_size = max_size+1, i, index = 0;

  debug(DBG_SYMBOL, 3, "Rehashing SymbolTable:\n\tOld max_size:"
	"%ld\n\tNew max_size: %ld\n", primes[max_size],
	primes[new_size]); 

  debug(DBG_SYMBOL, 4, "\tAllocating new table");
  tmp_nodes = new SymbolNode *[primes[new_size]];
  if (!tmp_nodes)
    error(SYS_ERR, "malloc error");

  memset(tmp_nodes, 0, sizeof(SymbolNode*)*primes[new_size]);

  debug(DBG_SYMBOL, 4, "\tRehashing:");
  for (i = 0; i < primes[max_size]; i++) {
    if (items[i] != NULL) {
      debug(DBG_SYMBOL, 5, "\t\t%ld: symbol: %s, val %ld", i,
	    items[i]->symbol, items[i]->value);      
      index = Hash(items[i]->symbol, 1);
      assert( index < primes[new_size]);
      tmp_nodes[index] = items[i];
    } 
  }
  max_size = new_size;
  debug(DBG_SYMBOL, 4, "Rehash: Deleting old nodes");
  delete [] items;
  items = tmp_nodes;
}

#ifdef DEBUG
/// Debug helper: writes every occupied slot to stderr as
/// "symbol: value", one per line, in slot order.
void SymbolTable::PrintItems()
{
  fprintf(stderr, "Table items:\n");

  const long slot_count = primes[max_size];
  for (long slot = 0; slot < slot_count; slot++) {
    if (!items[slot])
      continue;
    fprintf(stderr, "\t%s: %ld\n", items[slot]->symbol,
	    items[slot]->value);
  }
}
#endif
      


/// Creates an empty function table whose capacity is the first entry
/// of primes[], starting from index 1, that is not smaller than sz.
///
/// sz is the requested minimum number of slots. A request larger than
/// the biggest entry of primes[] is clamped to that entry instead of
/// scanning past the end of the array.
FunctionTable::FunctionTable(long sz)
{
  debug(DBG_SYMBOL, 1, "Creating FunctionTable");

  // max_size is an index into primes[], not a slot count.
  const long last_prime = (long) (sizeof(primes) / sizeof(primes[0])) - 1;
  max_size = 1;
  while (max_size < last_prime && primes[max_size] < sz)
    max_size++;
  size = 0;

  debug(DBG_SYMBOL, 4, "\tAllocating items");
  items = new FunctionNode*[primes[max_size]];
  if (!items)
    error(SYS_ERR, "malloc error");

  // Every slot starts out empty.
  memset(items, 0, sizeof(FunctionNode*)*primes[max_size]);
}

/// Deletes the node held in every slot and then the slot array
/// itself.
FunctionTable::~FunctionTable()
{
  debug(DBG_SYMBOL, 1, "destroying FunctionTable");
  debug(DBG_SYMBOL, 4, "\tdestroying items");

  // Deleting a null pointer is a no-op, so unused slots need no test.
  const long slot_count = primes[max_size];
  long slot = 0;
  while (slot < slot_count) {
    delete items[slot];
    slot++;
  }

  debug(DBG_SYMBOL, 4, "\tdestroying array");
  delete [] items;
}

unsigned long FunctionTable::Hash(char *key, int sz)
{
  debug(DBG_SYMBOL, 3, "Hashing %s, size %d", key, sz);
  unsigned long val = 0, b = 0;
  FunctionNode **tbl = NULL;
  long  int i;
  char *p = key;

  if (sz)
    tbl = tmp_nodes;
  else
    tbl = items;
  
  while (*p) {
    b <<= 5;
    b += *p++;
    b %= primes[max_size+sz];
  }
  
  for (i = 0; i < primes[max_size+sz]; i++) {
    val = b + i*i;
    val %= primes[max_size+sz];

    debug(DBG_SYMBOL, 4, "Probing (%s): %ld", key, val);
    if (tbl[val] == NULL) {
      debug(DBG_SYMBOL, 4, "Returning: %ld",val);
      return val;
    } else if (!strcmp(key, tbl[val]->symbol)) {
      debug(DBG_SYMBOL, 4, "Returning: %ld",val);
      return val;
    }
  }
  debug(DBG_SYMBOL, 4, "Item not found");
  return primes[max_size+sz]+1;
}

    
/// Marks an already registered function as valid.
///
/// key is the function name. Returns the slot index of the function,
/// or -1 (after a warning) if no function with that name is in the
/// table.
long FunctionTable::Insert(char *key)
{
  long index = 0;
  // Exactly one conversion for exactly one argument: debug() is used
  // with printf-style formats throughout this file, and a stray "%d"
  // would ask for an argument that is never passed.
  debug(DBG_SYMBOL, 2, "Inserting function (%s)", key);
  index = Hash(key);

  assert(index < primes[max_size]);

  // Check if the function is registered
  if (!items[index]) {
    warn(WARN_DECL, "Function '%s' not registered.\nTreating '%s' as"
	 " a predicate.", key, key); 
    return -1;
  }
  items[index]->valid = 1;
  return index;
}

long FunctionTable::Register(char *key, InstFunc p, int val)
{
  debug(DBG_SYMBOL, 2, "Registering function (%s, %ld)", key,
	(long) p);
  long index = 0;
  FunctionNode *new_node;
  index = Hash(key);
  
  assert(index < primes[max_size]);

  if (items[index]) {
    warn(WARN_DECL, "Trying to register function %s twice."
	 "Ignoring new registeration", key);
    return index;
  }

  new_node = new FunctionNode;
  if (!new_node)
    error(SYS_ERR, "malloc error");

  size++;
  debug(DBG_SYMBOL, 4, "\tInitializing new_node");
  new_node->symbol = strdup(key);
  new_node->valid = val;
  new_node->func = p;

  items[index] = new_node;
  return index;
}

long FunctionTable::Lookup(char *key)
{
  debug(DBG_SYMBOL, 2, "Looking for symbol (%s)", key);
  long index = Hash(key);

  assert( index < primes[max_size]);

  // check if exists and if valid
  if (!items[index] || !items[index]->valid)
    return -1;
  else
    return 1;
}

/// Makes the function symbol s refer to the function that is stored
/// for v.
///
/// s and v are either names or operator symbols: a string that does
/// not start with a letter is translated with get_function_name().
/// val is the valid flag handed to Register() when s is not yet in the
/// table.
///
/// Returns the slot index of s, or -1 (after a warning) when either
/// name cannot be translated or v is not in the table.
long FunctionTable::Define(char *s, char *v,  int val)
{
  debug(DBG_SYMBOL, 2, "Defining symbol %s -> %s", s, v);
  long symbol_index, value_index, sz = max_size;
  char *symbol = NULL, *value = NULL;

  // first we must check if 's' or 'v' are operators.
  // isalpha() is only defined for unsigned char values, hence the cast.
  if (isalpha((unsigned char) *s))
    symbol = s;
  else {
    symbol = get_function_name(s);
    if (symbol) {
      debug(DBG_SYMBOL, 4, "\t%s -> %s", s, symbol);
    }
  }

  if (isalpha((unsigned char) *v))
    value = v;
  else {
    value = get_function_name(v);
    if (value) {
      debug(DBG_SYMBOL, 4, "\t%s -> %s", v, value);
    }
  }

  if (!symbol) {
    warn(WARN_DECL, "Unknown function symbol '%s'. Ignoring it. ", s);
  }
  if (!value) {
    warn(WARN_DECL, "Unknown function symbol '%s'. Ignoring it", v);
  }

  // Hash() reads through its key, so an untranslated name has to stop
  // here instead of being hashed.
  if (!symbol || !value)
    return -1;
  
  symbol_index = Hash(symbol);
  value_index = Hash(value);

  assert ((symbol_index < primes[max_size]) && (value_index <
					     primes[max_size]));
  if (!items[value_index]) {
    warn(WARN_DECL, "Function %s not registered. Cannot define '%s'",
	 v, s); 
    return -1;
  }

  if (!items[symbol_index]) {
    debug(DBG_SYMBOL, 4, "\tExisting symbol not found");
    Register(symbol, items[value_index]->func, val);

    // recalculate the index if the table size changed meanwhile
    if (sz != max_size)
      symbol_index = Hash(symbol);
    
    assert (items[symbol_index]); // check that all worked
    items[symbol_index]->valid = 1;
  } else {
    debug(DBG_SYMBOL, 4, "\tExisting symbol found");
    items[symbol_index]->func = items[value_index]->func;
  }
  return symbol_index;
}

  
void FunctionTable::ReHash()
{
  long new_size = max_size+1, i, index = 0;

  debug(DBG_SYMBOL, 3, "Rehashing FunctionTable:\n\tOld max_size:"
	"%ld\n\tNew max_size: %ld\n", primes[max_size],
	primes[new_size]); 

  debug(DBG_SYMBOL, 4, "\tAllocating new table");
  tmp_nodes = new FunctionNode *[primes[new_size]];
  if (!tmp_nodes)
    error(SYS_ERR, "malloc error");

  memset(tmp_nodes, 0, sizeof(FunctionNode*)*primes[new_size]);

  debug(DBG_SYMBOL, 4, "\tRehashing:");
  for (i = 0; i < primes[max_size]; i++) {
    if (items[i] != NULL) {
      debug(DBG_SYMBOL, 5, "\t\t%ld: symbol: %s, %s", i, 
	    items[i]->symbol, (items[i]->valid) ?
	    "valid" : "invalid");
      
      index = Hash(items[i]->symbol, 1);
      assert( index < primes[new_size]);
      tmp_nodes[index] = items[i];
    } 
  }
  max_size = new_size;
  debug(DBG_SYMBOL, 4, "Rehash: Deleting old nodes");
  delete [] items;
  items = tmp_nodes;
}

/// Returns the function stored for key, or NULL when key is not in the
/// table or its valid flag is not set.
InstFunc FunctionTable::GetFunction(char *key)
{
  const long slot = Hash(key);
  FunctionNode *node = items[slot];

  if (node && node->valid)
    return node->func;
  return NULL;
}

/// qsort() comparison function for arrays of FunctionNode pointers.
///
/// it1 and it2 point at two array elements. Null elements sort after
/// all others and compare equal to each other; the remaining elements
/// are ordered by strcmp() on their symbol.
int compare_functions(const void *it1, const void *it2)
{
  FunctionNode *i1, *i2;

  i1 = *((FunctionNode **) it1);
  i2 = *((FunctionNode **) it2);

  // Two empty entries are equal. Reporting "greater" in both
  // directions would not be a consistent ordering for qsort().
  if (!i1 && !i2)
    return 0;
  if (!i1) 
    return 1;
  if (!i2)
    return -1;

  return strcmp(i1->symbol, i2->symbol);
}

void FunctionTable::PrintRegistered()
{
  int pos = 0, first = 1;
  char **tmp = NULL;
  long i;

  printf("Included files: library.cc");
  tmp = external_files;
  while (*tmp) {
    printf(", %s", *tmp);
    tmp++;
  }
  printf("\n");

  // sort using qsort
  qsort(items, primes[max_size], sizeof(FunctionNode*),
	compare_functions);
  
  printf("Registered functions:\n");
  pos = printf("   ");

  for (i = 0; i < primes[max_size]; i++) {
    if (items[i]) {
      if (!first) 
	pos += printf(", ");
      else {
	first = 0;
      }
      if (pos >= 70) {
	printf("\n");
	pos = printf("   ");
      }
      pos += printf("%s%s", items[i]->symbol,
		    (items[i]->valid ? "(*)" : "")); 
    } else {
      break;
    }

  }
  printf("\n(*) means that the function is an internal function\n");
}
  


#ifdef DEBUG
/// Debug helper: writes every occupied slot to stderr as
/// "symbol: valid" or "symbol: invalid", one per line, in slot order.
void FunctionTable::PrintItems()
{
  fprintf(stderr, "Table items:\n");

  const long slot_count = primes[max_size];
  for (long slot = 0; slot < slot_count; slot++) {
    if (!items[slot])
      continue;
    fprintf(stderr, "\t%s: %s\n", items[slot]->symbol,
	    items[slot]->valid ? "valid" : "invalid");
  }
}
#endif
      

