/* term.cc -- implementation for terms
      
   This program has no warranties of any kind. Use at own risk.

   Author: Tommi Syrjänen (tommi.syrjanen@hut.fi)

   $Id: term.cc,v 1.1 1998/08/04 09:18:59 tssyrjan Exp $	 

   */

#include "term.h"
#include "symbol.h"
#include "parse.h"
#include "debug.h"
#include <string.h>


// Internal names of the built-in functions.  This table is parallel to
// operators[] below: get_function_name() returns function_strings[i] for
// the operator token operators[i] (its scan starts at index 1).  The last
// entry, "***error***", has no operator counterpart.
// NOTE(review): string literals are bound to non-const char*, which is
// deprecated/ill-formed in modern C++; changing it requires updating the
// declarations in the headers as well.
char *function_strings[] = { "abs", "eq", "le", "ge", "lt", "gt",
			     "neq", "plus", "minus", "times", "div",
			     "mod", "assign", "***error***" };

// Source-level operator tokens, index-aligned with function_strings[].
char *operators[] = { "||", "==", "<=", ">=", "<", ">", "!=", "+",
		      "-", "*", "/", "%", "=" };

// Maps an operator token (e.g. "==") to the name of the internal function
// that implements it (e.g. "eq") by a linear scan of the parallel tables
// operators[] / function_strings[].  Index 0 is never examined because the
// scan starts at 1.  Returns NULL when fun is not a known operator.
char *get_function_name(char *fun)
{
  int idx = 1;

  while (idx < NUM_MAX_FUN) {
    if (strcmp(operators[idx], fun) == 0)
      return function_strings[idx];
    idx++;
  }
  return NULL;
}

// Builds a simple (non-function) term of type tt with value v that comes
// from source line ln.  Only variable terms are non-ground; the term never
// contains a range or a function call.
Term::Term(TermType tt, long v, long ln)
{
  lineno = ln;
  val = v;
  type = tt;
  // tt is a TermType, so it must be compared with the TermType constant
  // T_VARIABLE (ConstructArgument creates variables with T_VARIABLE), not
  // with the parse-node constant PT_VARIABLE from a different enumeration.
  ground = (tt == T_VARIABLE) ? 0 : 1;
  has_range = 0;
  has_function = 0;
}



// Converts one parse-tree node into a Term.
//  - PT_EXPR / PT_FUNCTION: the node already carries its Function term
//    (the tree is processed in postorder), which is returned as is.
//  - PT_VARIABLE / PT_CONSTANT: the name is resolved through
//    variable_table / constant_table; an unknown name is reported with
//    int_error().
//  - PT_NUMBER: becomes a constant term holding item->lval.
//  - PT_RANGE: becomes a Range term item->start..item->end.
// Any other node type is reported with int_error() and NULL is returned.
Term *Term::ConstructArgument(ParseNode *item)
{
  Term *result = NULL;
  long index;

  switch (item->type) {
  case PT_EXPR:
  case PT_FUNCTION:
    // children are handled before their parents, so the term exists
    return item->ext.fun;

  case PT_VARIABLE:
    index = variable_table->Lookup(item->sval);
    if (index < 0)
      int_error("unknown variable '%s'", item->sval);
    result = new Term(T_VARIABLE, index, item->lineno);
    break;

  case PT_NUMBER:
    result = new Term(T_CONSTANT, item->lval, item->lineno);
    break;

  case PT_CONSTANT:
    index = constant_table->Lookup(item->sval);
    if (index < 0)
      int_error("unknown constant '%s'", item->sval);
    result = new Term(T_CONSTANT, index, item->lineno);
    break;

  case PT_RANGE:
    result = new Range(item->start, item->end, item->lineno);
    break;

  default:
    int_error("invalid term type %s", parse_strings[item->type]);
    return NULL;
  }

  // legacy allocation check (a conforming operator new throws instead)
  if (!result)
    error(SYS_ERR, "malloc error");
  return result;
}

// Creates a call of the built-in function f, named n (e.g. "plus"), found
// on source line ln.  The call starts with no arguments; AddArgument()
// fills the zeroed argument vector of TERM_MAX_ARITY slots.
Function::Function(InstFunc f, char *n, long ln)
{
  // Term bookkeeping: an argument-less call is ground, contains no range
  // and is itself a function term.
  type = T_FUNCTION;
  lineno = ln;
  ground = 1;
  has_range = 0;
  has_function = 1;

  fun = f;
  name = n;
  arity = 0;

  // "assign" stores its value into a variable instead of being tested
  assign = (strcmp(name, "assign") == 0);
  // equality tests may receive non-numeric (symbolic) constants
  accept_constants = (strcmp(name, "eq") == 0 || strcmp(name, "neq") == 0);

  args = new Term*[TERM_MAX_ARITY];
  if (!args)
    error(SYS_ERR, "malloc error");
  memset(args, 0, sizeof(Term *) * TERM_MAX_ARITY);
}

// Releases the argument vector allocated in the constructor.  The
// argument terms themselves are not deleted here.
Function::~Function()
{
  // args was allocated with new[], so it must be released with delete[];
  // scalar delete on an array is undefined behaviour.
  delete [] args;
}

// Appends t as the next argument of this call.  The call becomes
// non-ground if t is non-ground and is marked as containing a range if
// t contains one.
// Returns 1 on success.  If all TERM_MAX_ARITY slots are already used, a
// fatal error is reported and 0 is returned without storing t.
int Function::AddArgument(Term *t)
{
  assert (t);
  // %p is the correct conversion for a pointer (%ld is not)
  debug(DBG_TERM, 2, "Adding argument %p, old arity %d", (void *) t, arity);

  // args holds exactly TERM_MAX_ARITY entries; never write past them
  if (arity >= TERM_MAX_ARITY) {
    error(FATAL_ERR, "too many arguments for function '%s' at line %ld",
          name, lineno);
    return 0;
  }

  args[arity++] = t;
  if (!t->ground)
    ground = 0;
  if (t->has_range)
    has_range = 1;
  return 1;
}

// Evaluates the call.  Each argument is reduced to an integer: constants
// directly, variables through their current binding in variables[], and
// nested calls recursively.  The values are passed to fun together with
// the arity.  A variable bound to a non-numeric constant is a fatal
// runtime error unless this function accepts constants (eq / neq).
Instance Function::Call()
{
  int parms[TERM_MAX_ARITY] = { 0 };

  for (int k = 0; k < arity; k++) {
    Term *arg = args[k];

    if (arg->type == T_CONSTANT) {
      // non-numerical constants are weeded out before evaluation
      assert (IS_NUMBER(arg->val));
      parms[k] = GET_VALUE(arg->val);
    } else if (arg->type == T_VARIABLE) {
      if (!IS_NUMBER(variables[arg->val]) && !accept_constants)
        error(FATAL_ERR, "Runtime error: non-numerical constant '%s'"
              " passed as an argument to function '%s' at line %ld",
              constant_table->symbols[variables[arg->val]], name,
              lineno);
      parms[k] = GET_VALUE(variables[arg->val]);
    } else if (arg->type == T_FUNCTION) {
      parms[k] = ((Function *) arg)->Call();
    } else {
      int_error("misformed function call tree", "");
    }
  }

  // fun always receives six value slots; slots beyond arity stay zero.
  // NOTE(review): assumes TERM_MAX_ARITY >= 6 (*** TERM_MAX_ARITY ***)
  return fun(arity,
             parms[0], parms[1], parms[2],
             parms[3], parms[4], parms[5]);
}

// Returns the largest value stored in variables[] for the variables used
// in this call, recursing into nested calls.  If no arguments are
// variables, the result is 0.  Returns -1 if any of those variables,
// including ones inside a nested call, has a negative (undefined) entry.
// Constant arguments do not contribute.
int Function::GetPos()
{
  int res = 0, i = 0, pos = 0;

  for (i = 0; i < arity; i++) {
    switch (args[i]->type) {
    case T_FUNCTION:
      // Evaluate the recursive call only once (max may be a macro) and
      // propagate "undefined variables" instead of letting max() drop
      // the -1.
      pos = ((Function*)args[i])->GetPos();
      if (pos < 0)
        return -1;
      res = max(res, pos);
      break;
    case T_VARIABLE:
      pos = variables[args[i]->val];
      if (pos < 0)  // there are some undefined variables
        return -1;
      res = max(res, pos);
      break;
    case T_CONSTANT:
      break;
    default:
      int_error("misformed function call tree", "");
    }
  }
  return res;
}

// Evaluates this call as a test at position pos.
// For an assignment (exactly two arguments, the second a nested call), the
// nested call is evaluated and its result is stored as a number in
// variables[] for the first argument.  pos is recorded in var_pos[] and 1
// is returned.  Storing a negative value is a fatal runtime error.
// Otherwise the result of Call() is returned, logically inverted when the
// call is negated.
Instance Function::Test(int pos)
{
  if (!assign) {
    Instance outcome = Call();
    if (negative)
      return !outcome;
    return outcome;
  }

  if ((arity != 2) || (args[1]->type != T_FUNCTION))
    int_error("misformed assign statement", "");

  Function *rhs = (Function *) args[1];
  Instance value = rhs->Call();
  if (value < 0)
    error(FATAL_ERR, "runtime error: tried to store negative"
          " value '%ld' to variable '%s'.\n Value was"
          " returned by %s().", value,
          variable_table->symbols[args[0]->val], rhs->name);

  variables[args[0]->val] = MAKE_NUMBER(value);
  var_pos[args[0]->val] = pos;
  return 1;
}

void Function::AddVars(int pos)
{
  int i;

  for (i = 0; i < arity; i++) {
    switch (args[i]->type) {
    case T_FUNCTION:
      ((Function*) args[i])->AddVars(pos);
      break;
    case T_VARIABLE:
      if (variables[args[i]->val] < 0)
	variables[args[i]->val] = pos;
      break;
    case T_CONSTANT:
      break;
    default:
      int_error("misformed function call tree", "");
    }
  }
}

void Function::Restrict(int t)
{
  int i = 0;
  // treat assignments as a domain predicate
  if (assign) 
    variables[args[0]->val] = R_DOMAIN;
  else {
    for (i = 0; i < arity; i++) {
      if ((args[i]->type == T_VARIABLE) &&
	  (variables[args[i]->val] < t))
	variables[args[i]->val] = t;
      else if (args[i]->type == T_FUNCTION)
	((Function *)args[i])->Restrict(t);
    }
  }
}

void Function::Print()
{
  int first = 1;
  int i;
  Range *r;
  fprintf(stderr, "%s(", name);
  for (i = 0; i < arity; i++) {
    if (!first)
      fprintf(stderr, ",");
    first = 0;
    switch (args[i]->type) {
    case T_VARIABLE:
      fprintf(stderr, "%s", variable_table->symbols[args[i]->val]);
      break;
    case T_CONSTANT:
      if (IS_NUMBER(args[i]->val))
	fprintf(stderr, "%ld", GET_VALUE(args[i]->val));
      else
	fprintf(stderr, "%s", constant_table->symbols[args[i]->val]);
      break;
    case T_FUNCTION:
      ((Function*) args[i])->Print();
      break;
    case T_RANGE:
      r = (Range *) args[i];
      fprintf(stderr, "%ld..%ld", r->start, r->end);
      break;
    default:
      fprintf(stderr, "** error **");
    }
  }
  fprintf(stderr, ")");
}
// Builds the range term s..e from source line ln.  A range is ground, is
// flagged as containing a range, and contains no function call.
Range::Range(long s, long e, long ln)
{
  type = T_RANGE;
  lineno = ln;

  start = s;
  end = e;

  has_function = 0;
  has_range = 1;
  ground = 1;
}
