/* rule.cc -- implementation of the Rule class
   
   This program has no warranties of any kind. Use at own risk.

   Author: Tommi Syrjänen (tommi.syrjanen@hut.fi)
   
   $Id: rule.cc,v 1.1 1998/08/04 09:19:03 tssyrjan Exp $	 
*/

#include "global.h"
#include "rule.h"
#include "literal.h"
#include "list.h"
#include <string.h>
#include "graph.h"
#include <limits.h>
#include "symbol.h"
#include "predicate.h"
#include "iterate.h"


// Create a rule whose head is `lt`. The body starts out empty; literals
// and functions are attached with AddLiteral() and AddFunction().
Rule::Rule(Literal *lt)
  : positive(UNORDERED), negative(UNORDERED)
{
  line_end = 0;
  line_start = 0;
  head = lt;
  status = RT_STRONG;

  // Ranges and function terms in the head are rewritten right away.
  // SimplifyLiteral() uses `head`, so it has to be assigned first.
  if (head->has_function || head->has_range)
    SimplifyLiteral(head);
}

// Rewrite the range and function arguments of `lt` into fresh variables
// so that the rule can be grounded purely by variable binding.
//  - A range argument (T_RANGE) becomes a new variable "Range&N". A new
//    unary DM_INTERNAL predicate "range&N" is created. The numbers of the
//    range (bounds swapped if given in reverse) become its instances, and
//    the positive literal range&N(Range&N) is added to the rule body.
//  - A function argument (T_FUNCTION) becomes a new variable "Arg&N".
//    An "assign" function taking (Arg&N, original term) is added to the
//    rule's functions.
// The literal's args[] and vars[] entries are updated in place.
void Rule::SimplifyLiteral(Literal *lt)
{
  long i, j, start, end, tmp;
  long var = -1, pred = -1;
  Literal *new_lt = NULL;
  Instance new_instance = -1;
  Term *nt1 = NULL;
  Function *nt2 = NULL;
  InstFunc fp;

  // name buffers are allocated once and reused by every call
  static char *varname = NULL;
  static char *predname = NULL;

  if (!varname) {
    varname = new char[RANGE_MAX_NAME];
    predname = new char[RANGE_MAX_NAME];
    if (!varname || !predname)
      error(SYS_ERR, "malloc error");
  }

  for (i = 0; i < lt->arity; i++) {
    if (lt->args[i] == T_RANGE) {
      if (sys_data.num_ranges >= RANGE_MAX_NUMBER)
        int_error("maximum number of ranges exceeded", "");

      // get the right range number; sprintf() NUL-terminates both names
      sprintf(varname, "Range&%d", sys_data.num_ranges);
      sprintf(predname, "range&%d", sys_data.num_ranges);
      sys_data.num_ranges++;

      var = variable_table->Insert(varname);
      if (var < 0)
        int_error("cannot generate range variable '%s'", varname);
      lt->args[i] = T_VARIABLE;
      lt->vars[i] = var;

      // validate the index before it is used to address predicates[]
      pred = predicate_table->Insert(predname, 1);
      if (pred < 0)
        int_error("cannot generate range domain '%s'", predname);
      predicates[pred]->SetArity(1);
      predicates[pred]->SetName(strdup(predname));
      predicates[pred]->SetPred(pred);
      predicates[pred]->SetStatus(DM_INTERNAL);

      dependency_graph->AddNode(pred);
      dependency_graph->AddEdge(head->pred, pred);

      start = GET_VALUE(((Range *) lt->terms[i])->start);
      end = GET_VALUE(((Range *) lt->terms[i])->end);
      if (start > end) {
        tmp = start;
        start = end;
        end = tmp;
      }

      // add the values of range to the predicate
      for (j = start; j <= end; j++) {
        new_instance = MAKE_NUMBER(j);
        predicates[pred]->AddInstance(&new_instance);
      }
      // and finally construct a matching literal and add it
      new_lt = new Literal(var, pred);
      AddLiteral(new_lt, 0);
    } else if (lt->args[i] == T_FUNCTION) {
      sprintf(varname, "Arg&%d", sys_data.num_functions);
      sys_data.num_functions++;

      // put the variable in place
      var = variable_table->Insert(varname);
      if (var < 0)
        int_error("cannot generate function argument '%s'",
                  varname);
      lt->args[i] = T_VARIABLE;
      lt->vars[i] = var;

      // construct the new function: assign(Arg&N, <original term>)
      fp = function_table->GetFunction("assign");
      nt1 = new Term(T_VARIABLE, var, line_start);
      nt2 = new Function(fp, "assign", line_start);
      nt2->AddArgument(nt1);
      nt2->AddArgument(lt->terms[i]);

      if (nt2->has_range)
        RestrictFunction(nt2);
      functions.Insert(nt2);
    }
  }
}

// Add `lt` to the rule body: to the negative part when `neg` is nonzero,
// otherwise to the positive part. Ranges and function terms in the
// literal are first rewritten into variables by SimplifyLiteral().
void Rule::AddLiteral(Literal *lt, int neg)
{
  if (lt->has_function || lt->has_range)
    SimplifyLiteral(lt);

  if (!neg) {
    positive.Insert(lt);
    return;
  }
  negative.Insert(lt);
}

// Attach function term `t` to the rule. Functions are only recorded
// here; GroundRule() later decides when each one is tested.
void Rule::AddFunction(Function *t)
{
  this->functions.Insert(t);
}


// Replace every range argument of function term `tm`, recursing into
// nested function arguments that contain ranges, by a fresh variable
// "Range&N".
// Each such variable is restricted by a new unary DM_INTERNAL predicate
// "range&N" holding the numbers of the range (bounds swapped if given in
// reverse). The positive literal range&N(Range&N) is added to the body.
void Rule::RestrictFunction(Function *tm)
{
  long i, j, start, end, tmp;
  long var = -1, pred = -1;
  Literal *lt = NULL;
  Instance new_instance = -1;
  // name buffers are allocated once and reused by every call
  static char *varname = NULL;
  static char *predname = NULL;

  if (!varname) {
    varname = new char[RANGE_MAX_NAME];
    predname = new char[RANGE_MAX_NAME];
    if (!varname || !predname)
      error(SYS_ERR, "malloc error");
  }

  for (i = 0; i < tm->arity; i++) {
    if (tm->args[i]->type == T_RANGE) {

      if (sys_data.num_ranges >= RANGE_MAX_NUMBER)
        int_error("maximum number of ranges exceeded", "");

      // get the right range number; sprintf() NUL-terminates both names
      sprintf(varname, "Range&%d", sys_data.num_ranges);
      sprintf(predname, "range&%d", sys_data.num_ranges);
      sys_data.num_ranges++;

      var = variable_table->Insert(varname);
      if (var < 0)
        int_error("cannot generate range variable '%s'", varname);
      // Turn the range argument itself into a variable term. Its value
      // is the variable index, as for Term(T_VARIABLE, var, ...) in
      // SimplifyLiteral(). Writing tm->val would clobber the enclosing
      // function instead.
      tm->args[i]->type = T_VARIABLE;
      tm->args[i]->val = var;

      pred = predicate_table->Insert(predname, 1);

      if (pred < 0)
        int_error("cannot generate range domain '%s'", predname);
      predicates[pred]->SetArity(1);
      predicates[pred]->SetName(strdup(predname));
      predicates[pred]->SetPred(pred);
      // generated range domains are internal, as in SimplifyLiteral()
      predicates[pred]->SetStatus(DM_INTERNAL);
      dependency_graph->AddNode(pred);
      dependency_graph->AddEdge(head->pred, pred);

      start = GET_VALUE(((Range *) tm->args[i])->start);
      end = GET_VALUE(((Range *) tm->args[i])->end);
      if (start > end) {
        tmp = start;
        start = end;
        end = tmp;
      }

      // add the values of range to the predicate
      for (j = start; j <= end; j++) {
        new_instance = MAKE_NUMBER(j);
        predicates[pred]->AddInstance(&new_instance);
      }
      // and finally construct a matching literal and add it
      lt = new Literal(var, pred);
      AddLiteral(lt, 0);
    } else if (tm->args[i]->has_range) {
      // a nested function term that contains ranges somewhere inside
      RestrictFunction((Function*) tm->args[i]);
    }
  }
}

void Rule::GroundRule(int domain)
{
  debug(DBG_GROUND, 3, "Grounding rule. domain: %d", domain); 
  Literal *lt = NULL;
  Function *fun = NULL;
  FunctionList *funs = NULL;
  Iterator **indices = NULL;
  Predicate *pr = NULL;
  LiteralList *negs = NULL;
  Literal **lts = NULL;
  Instance *item = NULL;
  static Instance *null_item = NULL;
  Instance **values = NULL;
  int *inds = NULL;
  int max = 0, i = 0, existing = 0, ind = -1;
  LiteralList posg(ORDERED_SMALL), posn(UNORDERED);
  // respectively number of domain positive, nondomain positive and
  // domain negative
  int gpos = 0, npos = 0, gneg = 0, pos = 0, discard = 0; 

  if (!null_item) {
    null_item = new Instance;
    if (!null_item)
      error(SYS_ERR, "malloc error");
    *null_item = -1;
  }
  if (!var_pos) {
    var_pos = new Variable[variable_table->Size()];
    if (!var_pos)
      error(SYS_ERR, "malloc error");
  }

  
  // first we divide positive literals to domain and non-domain
  // predicates and sort domain predicates by the size of their
  // domains. 
  while ((lt = positive.Iterate())) {
    if ((predicates[lt->pred]->Status() == DM_DOMAIN) ||
	(predicates[lt->pred]->Status() == DM_INTERNAL)) {
      posg.Insert(lt, predicates[lt->pred]->Size());
      gpos++;
    } else {
      posn.Insert(lt);
      npos++;
    }
  }

  // allocate and initialize the data structures
  max = posg.Size();
  if (functions.Size() > 0) {
    funs = new FunctionList[max];
    if (!funs)
      error(SYS_ERR, "malloc error");
  }
  indices = new Iterator*[max];
  if (negative.Size() > 0) {
    negs = new LiteralList[max];
    if (!negs)
      error(SYS_ERR, "malloc error");
  }
  inds = new int[max];
  lts = new Literal*[max];
  values = new Instance*[max];
  if (!inds || !indices || !lts || !values)
    error(SYS_ERR, "malloc error");

  memset(variables, -1, variable_table->Size() * sizeof(Instance));
  memset(inds, -1, max * sizeof(int));
  memset(indices, 0, max * sizeof(Iterator*));
  memset(lts, 0, max * sizeof(Literal*));
  memset(var_pos, -1, variable_table->Size() * sizeof(Instance));
  for (i = 0; i < max; i++)
    values[i] = null_item;
    
  // set value of variables[i] to the first literal in positive
  // position where it appears in the domain predicates. Confusing,
  // huh. At same time choose an index for 
  pos = -1;
  while ((lt = posg.Iterate())) {
    pos++; existing = 0; ind = -1;
    lts[pos] = lt;
    pr = predicates[lt->pred];
    for (i = 0; i < lt->arity; i++) {
      if (lt->vars[i] >= 0) {
	if (variables[lt->vars[i]] < 0) {
	  variables[lt->vars[i]] = pos;
	} else { // choose as index if needed
	  if (!existing) {
	    if (pr->existing_indexes & (1 << i)) 
	      existing = 1;
	    ind = i;
	  }
	}
      }
    }
    // allocate the iterator
    if (ind < 0) {
      indices[pos] = new InstanceIterator(pr->atoms);
    } else {
      if (!existing)
	pr->CreateIndex(ind);
      indices[pos] = new IndexIterator(pr->indices[ind]);
    }
    if (!indices[pos])
      error(SYS_ERR, "malloc error");
    inds[pos] = ind;
    // and set the values array to point to right variable
    if (ind >= 0)
      values[pos] = &variables[lt->vars[ind]];
  }

  // then go through all functions to seek when the function can be
  // computed and add it to right list
  if (funs)
    while ((fun = functions.Iterate())) {
      pos = fun->GetPos();
      if (pos < 0)
	pos = max -1;
      if ((pos < 0) || (pos >= max))
	int_error("invalid function position '%d'", pos);
      fun->AddVars(pos);
      funs[pos].Insert(fun);
    }

  // and same to negative literals
  if (negs)
    while ((lt = negative.Iterate())) {
      if (!lt->ground) {
	pos = lt->GetPos();
	if ((pos < 0) || (pos >= max))
	  int_error("invalid negative literal position '%d'", pos);
	negs[pos].Insert(lt);
	// at the same time calculate the number of negative domains
	if (predicates[lt->pred]->Status() == DM_DOMAIN)
	  gneg++;
      }
    }

  // and then do the actual job
  memset(variables, -1, variable_table->Size() * sizeof(Instance));
  pos = 0;

  if (max == 0) {
    EmitGround(gpos, gneg);
    return;
  }
  while (1) {
    if (pos < 0)
      break; // all possible bindings have been checked

    discard = 0;
    item = indices[pos]->Iterate(*values[pos]);

    // if run out of items in this level go back up one level
    if (!item) {
      pos--;
      clear_pos(pos);
    } else {
      // bind the new variables
      if (BindLiteral(lts[pos], item, pos)) {
	// values were ok, then check all applicable functions and 
	// negative literals. These have to be in this order since
	// fun->Test() may set values of some variables
	while (funs && (fun = funs[pos].Iterate()))
	  if (!fun->Test(pos))
	    discard = 1;
	if (sys_data.print_domains != PR_ALL)
	  while (negs && (lt = negs[pos].Iterate()))
	    if (!lt->Test())
	      discard = 1;
      } else
	discard = 1; // binding failed

      if (!discard) {
	if (pos == max - 1) { // deepest level
	  if (domain)
	    head->CreateInstance();
	  else
	    EmitGround(gpos, gneg);
	  clear_pos(pos);
	} else {
	  // not yet at deepest level, so go deeper
	  pos++;
	}
      }else {
	// clear the variables and stay at this level
	clear_pos(pos);
      }
    }
  }
  delete [] inds;
  if (funs)
    delete [] funs;
  if (negs)
    delete [] negs;
  delete [] indices;
  delete [] lts;
}

// Try to extend the current variable bindings with instance `item` of
// literal `lt` at search depth `pos`.
// Returns 1 when every argument of `item` agrees with the literal's
// constants and already-bound variables. Returns 0 on the first mismatch.
// Variables bound here are recorded at depth `pos` in var_pos[]. On
// failure, some of them may already be set; the caller undoes them with
// clear_pos(pos).
int Rule::BindLiteral(Literal *lt, Instance *item, int pos)
{
  int i;

  // var_pos holds one Variable per variable (same size as GroundRule())
  if (!var_pos){
    var_pos = new Variable[variable_table->Size()];
    if (!var_pos)
      error(SYS_ERR, "malloc error");
  }

  for (i = 0; i < lt->arity; i++) {
    if (lt->vars[i] >= 0) {
      if (variables[lt->vars[i]] >= 0) { // variable has an old value
        if (item[i] != variables[lt->vars[i]])
          return 0; // mismatch
        else
          continue;
      } else { // set the variable value and position
        variables[lt->vars[i]] = item[i];
        var_pos[lt->vars[i]] = pos;
      }
    } else if (lt->cons[i] != item[i])
      return 0; // constant argument does not match
  }

  return 1;
}


// prints out one ground instance of the rule posg and negg are number
// of positive domain predicates and negative domain predicates in the
// body.
void Rule::EmitGround(int posg, int negg)
{
  FILE *ofile = sys_data.output_file;
  int first = 1, neg = 0, pos = 0;
  Instance head_atom = -1;
  Literal *lt = NULL;
  long cmp = -1;
  
  sys_data.ground_rules++;
  

  head_atom = head->EmitGround();
  DomainType st = DM_UNKNOWN;

  if (sys_data.emit_text) {
    // don't print implication if fact... Yes, this has to be
    // checked. 
    if ((positive.Size() - posg == 0) &&
	(negative.Size() - negg == 0) &&
	(sys_data.print_domains != PR_ALL) &&
	(sys_data.print_domains != PR_POSITIVE))
      ; // do nothing
    else {
      fprintf(ofile, " :- ");
      while ((lt = positive.Iterate())) {
	st = predicates[lt->pred]->Status();
	if ( (st == DM_INTERNAL) ||
	     ((st == DM_DOMAIN) && ((sys_data.print_domains != PR_ALL)
				    && (sys_data.print_domains !=
					PR_POSITIVE))))
	  ; // do nothing
	else {
	  if (!first)
	    fprintf(ofile, ", ");
	  first = 0;
	  lt->EmitGround();
	}
      }
      while ((lt = negative.Iterate())) {
	st = predicates[lt->pred]->Status();
	if ( (st == DM_INTERNAL) ||
	     ((st == DM_DOMAIN) &&
	      (sys_data.print_domains != PR_ALL)))
	  ; // do nothing
	else {
	  if (!first)
	    fprintf(ofile, ", ");
	  first = 0;
	  fprintf(ofile, "not ");
	  lt->EmitGround();
	}
      }
    }
    fprintf(ofile, ".\n");
  } else {
    switch (sys_data.print_domains) {
    case PR_ALL:
      pos = positive.Size();
      neg = negative.Size();
      break;
    case PR_HEADS:
    case PR_NONE:
      pos = positive.Size() - posg;
      neg = negative.Size() - negg;
      break;
    case PR_POSITIVE:
      pos = positive.Size();
      neg = negative.Size() - negg;
      break;
    }

    // print out the negative literals
    fprintf(ofile, "%d ", neg);
    while ((lt = negative.Iterate())) {
      st = predicates[lt->pred]->Status();
      if ( (st == DM_INTERNAL) ||
	   ((st == DM_DOMAIN) &&
	    (sys_data.print_domains != PR_ALL)))
	; // do nothing
      else {
	lt->EmitGround();
      }
    }
    // and positive
    fprintf(ofile, "%d ", pos);
    while ((lt = positive.Iterate())) {
      st = predicates[lt->pred]->Status();
      if ( (st == DM_INTERNAL) ||
	   ((st == DM_DOMAIN) && ((sys_data.print_domains != PR_ALL)
				  && (sys_data.print_domains !=
				      PR_POSITIVE))))
	; // do nothing
      else {
	lt->EmitGround();
      }
    }
    fprintf(ofile, "\n");
  }
  // if the head foo() has a complementary atom foo'() we have to emit
  // also rule " :- foo, foo'."
  if (predicates[head->pred]->complement >= 0) {
    cmp = predicates[head->pred]->complement;
    if (predicates[head->pred]->emitted->Lookup(&head_atom))
      ; // don't emit
    else {
      sys_data.ground_rules++;
      false_lit->EmitGround();
      if (sys_data.emit_text)
	fprintf(ofile, " :- ");
      else
	fprintf(ofile, " 0 2 ");
      head_atom = head->EmitGround();
      predicates[head->pred]->emitted->Insert(&head_atom);
      if (sys_data.emit_text)
	fprintf(ofile, ", ");
      else
	fprintf(ofile, " ");
      head_atom = head->EmitComplement();
      predicates[head->pred]->emitted->Insert(&head_atom);
      if (sys_data.emit_text)
	fprintf(ofile, ".\n");
      else
	fprintf(ofile, "\n");
    }
  }
}
      
// Undo every binding made at search depth `pos` or deeper. Both the
// variable value and its recorded depth are reset to -1.
void Rule::clear_pos(int pos)
{
  for (int v = 0; v < variable_table->Size(); v++) {
    if (var_pos[v] < pos)
      continue; // bound at a shallower level: keep it
    var_pos[v] = -1;
    variables[v] = -1;
  }
}

// Classify how well the rule's variables are restricted.
// Every variable starts at R_NONE. The head, the negative literals and
// the functions mark their variables with R_NEGATIVE. The positive
// literals mark theirs with R_DOMAIN (domain predicates) or R_POSITIVE.
// The Restrict() methods decide how these marks combine; presumably a
// stronger mark overrides a weaker one, to be confirmed in
// literal.cc/term.cc.
// Result:
//  - RT_NONE if any variable ends up R_NEGATIVE;
//  - otherwise RT_WEAK if any variable ends up R_POSITIVE;
//  - otherwise RT_STRONG.
// Non-strong rules are reported on stderr together with the offending
// variables.
RestrictType Rule::CheckRestrict()
{
  long i = 0;
  Literal *lt = NULL;
  Function *fun = NULL;
  RestrictType res = RT_STRONG;
  // scratch arrays shared between calls, grown when the table grows
  static long *weak = NULL, *none = NULL;
  static long capacity = 0;
  long numweak = 0, numnone = 0;
  long size = variable_table->Size();

  if (!weak || !none || capacity < size) {
    delete [] weak;
    delete [] none;
    weak = new long[size];
    none = new long[size];

    if (! weak || !none)
      error(SYS_ERR, "malloc error");
    capacity = size;
  }

  memset(variables, R_NONE, sizeof(Variable) * variable_table->Size());

  head->Restrict(R_NEGATIVE);

  while (( lt = negative.Iterate()))
    lt->Restrict(R_NEGATIVE);

  while (( fun = functions.Iterate()))
    fun->Restrict(R_NEGATIVE);

  while ((lt = positive.Iterate()))
    if (predicates[lt->pred]->Status() == DM_DOMAIN)
      lt->Restrict(R_DOMAIN);
    else
      lt->Restrict(R_POSITIVE);

  for (i = 0; i < variable_table->Size(); i++) {
    switch (variables[i]) {
    case R_NEGATIVE:
      res = RT_NONE;
      none[numnone++] = i;
      break;
    case R_POSITIVE:
      // a weak variable only downgrades a strong rule; RT_NONE stays
      if (res == RT_STRONG)
        res = RT_WEAK;
      weak[numweak++] = i;
      break;
    default:
      break;
    }
  }
  if (res < RT_STRONG) {
    fprintf(stderr, "%ld-%ld: %srestricted rule: ", (long) line_start,
            (long) line_end, ( (res == RT_NONE) ? "non" : "weakly "));
    PrintRule();
    // PrintRule() writes no newline; end the rule line before the lists
    fprintf(stderr, "\n");
    if (numnone > 0) {
      fprintf(stderr, "\tunrestricted variables: ");
      for (i = 0; i < numnone; i++)
        fprintf(stderr, "%s ", variable_table->symbols[none[i]]);
      fprintf(stderr, "\n");
    }
    if (numweak > 0) {
      fprintf(stderr, "\tweakly restricted variables: ");
      for (i = 0; i < numweak; i++)
        fprintf(stderr, "%s ", variable_table->symbols[weak[i]]);
      fprintf(stderr, "\n");
    }
  }
  return res;
}

// Print the rule in source-like form for diagnostics. The output is
// "head :- <positive literals>, not <negative literals>, <functions>."
// with no trailing newline. Separators are written to stderr;
// Literal::Print()/Function::Print() are assumed to write there too,
// which should be confirmed.
void Rule::PrintRule()
{
  int first = 1;
  Literal *lt = NULL;
  Function *ft = NULL;
  head->Print();

  fprintf(stderr, " :- ");

  while (( lt = positive.Iterate())) {
    if (!first)
      fprintf(stderr, ", ");
    first = 0;
    lt->Print();
  }
  while (( lt = negative.Iterate())) {
    if (!first)
      fprintf(stderr, ", ");
    fprintf(stderr, "not ");
    first = 0;
    lt->Print();
  }
  while ((ft = functions.Iterate())) {
    if (!first)
      fprintf(stderr, ", ");
    // without this, several functions in a literal-free body ran together
    first = 0;
    ft->Print();
  }
  fprintf(stderr, ".");
}

// Search depth at which each variable was bound during grounding (-1
// when unbound), indexed by variable number. The array is allocated
// lazily by GroundRule()/BindLiteral().
Variable *var_pos = 0;

