// predicate.cc -- implementation of predicate data type
// Copyright (C) 1999-2000 Tommi Syrjnen <Tommi.Syrjanen@hut.fi>
//  
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//  
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//  
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//  

#include "../config.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>

#ifndef PREDICATE_H
#include "predicate.h"
#endif
#ifndef SYMBOL_H
#include "symbol.h"
#endif
#ifndef ITERATE_H
#include "iterate.h"
#endif
#ifndef ERROR_H
#include "error.h"
#endif

// Construct an empty, not yet defined predicate.
//  - arity is -1 until SetArity() is called (CheckRestrict() asserts
//    arity >= 0).
//  - pred and complement are -1 until the predicate is registered with
//    DefinePredicate() / linked by DefineComplement(); both functions
//    compare against these values to detect "already defined".
Predicate::Predicate()
  : indices(NULL), atoms(NULL), existing_indexes(0), arity(-1),
    rules(UNORDERED), special_rules(UNORDERED),
    it(NULL), name(NULL), status(DM_UNKNOWN) 
{
  // predicate-table indices: not assigned yet
  pred = complement = -1;

  // status flags, all initially off
  follows = hidden = special = has_rule = is_warned = 0;

  // bookkeeping pointers
  emitted = NULL;
  predicate_weight_list = weight_list = NULL;

  first_line = 0;
}


// Release what this predicate owns: the per-argument index objects and
// their array, the instance iterator, the instance set, and the name.
// A name equal to FALSE_NAME is not freed (presumably it is not heap
// allocated -- NOTE(review): confirm against where FALSE_NAME is set).
Predicate::~Predicate()
{
  if (indices != NULL) {
    for (int pos = 0; pos < arity; pos++)
      delete indices[pos];   // deleting an unused (NULL) slot is a no-op
    delete [] indices;
  }

  delete it;      // delete on NULL is a no-op
  delete atoms;

  if (name != NULL && strcmp(name, FALSE_NAME) != 0)
    free(name);
}


// Return the complement name of predicate/atom name `nm`, which is `nm`
// with a "'" appended. Returns NULL when `nm` cannot be complemented:
// it is empty, it already ends in "'", or it starts with '_' (the prefix
// used for internal names such as "_int..." and "_dummy"). A name longer
// than BUFSIZ - 2 characters is reported as a fatal error.
//
// NOTE: the result points to a static buffer that is overwritten by the
// next call. Callers that keep it must copy it, e.g. with strdup().
char *complement_name(char *nm)
{
  static char comp_name[BUFSIZ] = { 0 };
  const size_t len = strlen(nm);

  // An empty name has no last character to inspect (nm[len-1] would read
  // before the string) and has no meaningful complement.
  if (len == 0)
    return NULL;

  // The buffer must hold the name, the "'" and the terminating NUL.
  if (len > BUFSIZ - 2) {
    error(FATAL_ERR, "too long predicate name '%s'", nm);
    return NULL;  // never overflow comp_name, even if error() returns
  }

  if (nm[len-1] == '\'' || nm[0] == '_')
    return NULL;
  
  memcpy(comp_name, nm, len);
  comp_name[len] = '\'';
  comp_name[len + 1] = '\0';
  return comp_name;
}

// Determine how well the ordinary rules of this predicate are restricted.
// Each rule is checked with Rule::CheckRestrict(print), and the weakest
// answer wins: RT_NONE if any rule is RT_NONE, else RT_WEAK if any rule
// is RT_WEAK, else RT_STRONG. This relies on the enum ordering
// RT_NONE < RT_WEAK < RT_STRONG (NOTE(review): confirm in the header).
// Special and internal predicates are always RT_STRONG.
RestrictType Predicate::CheckRestrict(int print)
{
  debug(DBG_RESTRICT, 2, "Checking restrict for %s", name);
  assert(arity >= 0);
  RestrictType result = RT_STRONG, tmp = RT_STRONG;
  Rule *rl = NULL;

  if ((status == DM_SPECIAL) ||
      (status == DM_INTERNAL))
    return RT_STRONG;
  
  while ((rl = rules.Iterate())) {
    tmp = rl->CheckRestrict(print);
    switch (tmp) {
    case RT_NONE:
      result = RT_NONE;
      break;
    case RT_WEAK:
      if (result > RT_NONE)
	result = RT_WEAK;
      break;
    default:
      break;
    }
  }

  // "%d" (not "$d", which is no conversion at all and printed literally);
  // the enum is cast explicitly to match the conversion.
  debug(DBG_RESTRICT, 2, "\tResult: %d", (int) result);
  return result;
}

// Attach rule `rl`, whose head is an atom of this predicate.
// A rule with no positive, negative or function literals in its body is a
// fact: its head instance is added directly with AddInstance() and `rl`
// is deleted, so the caller must not use `rl` afterwards. Any other rule
// is stored in `rules`. With regular model semantics and `cmp` set, the
// same fact instance (or rl->CreateComplement()) is also added to the
// complement predicate.
void Predicate::AddRule(Rule *rl, int cmp)
{
  has_rule = 1;
  // `complement` is still -1 when no complement predicate was defined
  // (for example, for names starting with '_', for which no complement is
  // created). Indexing predicates[-1] would be out of bounds, so the
  // complement is skipped in that case.
  // check if the rule is a fact
  if ((rl->positive.Size() == 0) && (rl->negative.Size() == 0) &&
      (rl->functions.Size() == 0)) {
    AddInstance(rl->head->cons);

    if (sys_data.use_regular_models && cmp && (complement >= 0)) {
      predicates[complement]->AddInstance(rl->head->cons);
    }
    
    delete rl;
  } else {
    rules.Insert(rl);
    if (sys_data.use_regular_models && cmp && (complement >= 0)) {
      Rule *r = rl->CreateComplement();
      predicates[complement]->rules.Insert(r);
    }
  }
}

// Attach special rule `rl` to this predicate and mark the predicate as
// having (special) rules. If this predicate is the internal false
// predicate and is currently a domain predicate, its status becomes
// DM_FALSE. With regular model semantics and `cmp` set,
// rl->CreateComplement() is also added to the complement predicate's
// special rules.
void Predicate::AddSpecialRule(Rule *rl, int cmp)
{
  has_rule = 1;
  special = 1;

  special_rules.Insert(rl);

  // check if this is the internal false predicate that is otherwise a
  // domain predicate and false.
  if (false_lit && (false_lit->pred == pred) && (status == DM_DOMAIN)) {
    status = DM_FALSE;
  }

  // `complement` is still -1 when no complement predicate was defined
  // (e.g. for names starting with '_'). predicates[-1] would be out of
  // bounds, so the complement is skipped in that case.
  if (sys_data.use_regular_models && cmp && (complement >= 0)) {
    Rule *r = rl->CreateComplement();
    predicates[complement]->special_rules.Insert(r);
  }
}


void Predicate::AddInstance(Instance *it)
{
  if (arity == 0)
    follows = 1;
  else 
    atoms->Insert(it);
}

// Push a new weight entry (weight `nw` for literal `lt`) onto the front of
// one of this predicate's singly linked weight lists. When `global` is
// nonzero the entry goes to predicate_weight_list; otherwise it goes to
// weight_list.
void Predicate::AddWeight(Weight *nw, Literal *lt, int global)
{
  WeightNode **list_head = global ? &predicate_weight_list : &weight_list;
  WeightNode *node = new WeightNode;

  node->wt = nw;
  node->lt = lt;
  node->prev = *list_head;
  *list_head = node;
}


// Set the number of arguments of this predicate. For a positive arity,
// an instance set (sized by sys_data.domain_size and the arity) and an
// iterator over it are created on the first call. On later calls both
// are reset with Clear() instead.
void Predicate::SetArity(int ar)
{
  arity = ar;
  if (ar <= 0)
    return;

  if (atoms) {
    atoms->Clear(ar);
    it->Clear();
    return;
  }

  atoms = new InstanceSet(sys_data.domain_size, ar);
  it = new InstanceIterator(atoms);
  if (!atoms || ! it)
    error(SYS_ERR, "malloc error");
}

void Predicate::CreateIndex(int ind)
{
  Instance *item = NULL;
  assert((ind >= 0) && (ind < arity));

  if (!indices) {
    indices = new Index*[arity];
    if (!indices)
      error(SYS_ERR, "malloc error");
    for (long i = 0; i < arity; i++ )
      indices[i] = NULL;
  }
  if (existing_indexes & (1 << ind)) 
    int_error("trying to create same index two times in predicate"
	      " '%s'", name);

  indices[ind] = new Index(arity, ind);
  if (!indices[ind])
    error(SYS_ERR, "malloc error");

  existing_indexes |= (1 << ind);

  it->Clear();
  while ((item = it->Iterate())) 
    indices[ind]->Insert(item);
}


void Predicate::Emit()
{
  Instance *item = NULL;
  Rule *rl = NULL;
  Literal *head = NULL;
  static Instance *items = NULL;
  int i = 0;
  
  // check if this is a domain predicate that shouldn't be printed
  if (((sys_data.print_domains == PR_NONE) &&
       (status == DM_DOMAIN)) || (status == DM_INTERNAL) ||
      (status == DM_SPECIAL))
    return;

  if (it)
    it->Clear();

  if (!items) {
    items = new Instance[predicate_table->MaxArity()+1];
    if (!items)
      error(SYS_ERR, "malloc error");
    memset(items, 0, sizeof(Instance) * predicate_table->MaxArity()+1); 
  }
  
  if (arity == 0) {
    if (follows) {
      // predicate is a ground fact. Construct a rule and emit it.
      head = new Literal(pred, NULL, 0, 0);
      if (!head)
	error(SYS_ERR, "malloc error");
      rl = new Rule(head, BASICRULE);
      if (!rl)
	error(SYS_ERR, "malloc error");
      rl->EmitGround(0,0,1);
      delete rl;
    } else if (false_lit && (false_lit->pred == pred) &&
	       (sys_data.output_version < 2)) {
      // construct the rule false :- false. (this is needed because 
      // otherwise smodels 1 won't work when false is not true :-)
      head = false_lit;
      rl = new Rule(head, BASICRULE);
      if (!rl)
	error(SYS_ERR, "malloc error");
      rl->AddLiteral(head); 
      rl->EmitGround(0,0,0);
      delete rl;
    }

  }
  // print all the ground instances 
  if ((arity > 0) && atoms && (atoms->Size() > 0)) { 
    head = new Literal(pred, items, arity, 0);
    if (!head)
      error(SYS_ERR, "malloc error");
    
    rl = new Rule(head, BASICRULE);
    if (!rl)
      error(SYS_ERR, "malloc error");
    
    while ((item = it->Iterate())) {
      for (i = 0; i < arity; i++) {
	head->cons[i] = item[i];
      }
      rl->EmitGround(0,0,1);
    }
    delete rl;
  }
  
  rules.ClearIterator();
  // and the normal rules
  if ((status != DM_DOMAIN) || (sys_data.print_domains == PR_ALL)) {
    
    while ((rl = rules.Iterate())) {
      if ((sys_data.print_domains == PR_ALL) && rl->ground)
	rl->EmitGround(0,0,0);
      else
	rl->GroundRule(GRD_RULE);
    }
    while ((rl = special_rules.Iterate()))
      rl->GroundSpecialRule(GRD_RULE);
    
  }
  
}

// Turn this predicate into a domain predicate and compute its extension
// by grounding each ordinary rule with GRD_DOMAIN. Internal range
// predicates (status DM_INTERNAL) keep their status.
void Predicate::CalculateDomain()
{
  debug(DBG_GROUND, 2, "Calculating domain for %s", name);

  if (status != DM_INTERNAL)
    SetStatus(DM_DOMAIN);

  // ground the rules one at a time, starting from the first
  rules.ClearIterator();
  for (Rule *rl = rules.Iterate(); rl != NULL; rl = rules.Iterate())
    rl->GroundRule(GRD_DOMAIN);
}
      

void Predicate::EmitAll()
{
  long i = 0;
 
  for( i = 0; i < predicate_table->Size(); i++)
    predicates[i]->Emit();
}

void Predicate::PrintAllRules()
{
  long i = 0;
  Rule *rl = NULL;
  
  for (i = 0; i < predicate_table->Size(); i++) {
    while ((rl = predicates[i]->rules.Iterate())) {
      rl->PrintRule();
    }
  }
}


// Run CheckRestrict(1) on every predicate in the table (no early exit, so
// every predicate is checked) and combine the answers. The result is
// RT_NONE if any predicate is RT_NONE, RT_WEAK if any is RT_WEAK and the
// running result is still above RT_WEAK, and RT_STRONG otherwise.
RestrictType Predicate::CheckAllRestricts()
{
  RestrictType overall = RT_STRONG;

  for (long idx = 0; idx < predicate_table->Size(); ++idx) {
    const RestrictType found = predicates[idx]->CheckRestrict(1);
    if (found == RT_NONE)
      overall = RT_NONE;
    else if (found == RT_WEAK && overall > RT_WEAK)
      overall = RT_WEAK;
  }
  return overall;
}
	

long Predicate::DefinePredicate(char *nm, int ar, long lineno)
{
  long p = -1;

  p = predicate_table->Insert(nm, ar);
  if (p < 0)
    int_error("unknown predicate '%s'", nm);
  if (predicates[p]->pred == p)
    return p;  // already defined

  if (!strcmp(nm, FALSE_NAME))
    predicates[p]->SetStatus(DM_FALSE);
  
  dependency_graph->AddNode(p);
  predicates[p]->SetName(nm);
  predicates[p]->SetPred(p);
  predicates[p]->SetArity(ar);
  predicates[p]->first_line = lineno;
  if (sys_data.use_regular_models && !strchr(nm, '\'')) {
    if (!strcmp(nm, FALSE_NAME))
      predicates[p]->complement = p;
    else
      DefineComplement(p);
  }

  // check if there is a predicate with same name but different arity 
  if (sys_data.warnings & WARN_ARITY) {
    long p2 = predicate_table->CheckName(nm, ar);
    if (p2 >= 0) {
      warn(WARN_ARITY, predicates[p2]->first_line,
	   "predicate '%s' is used with %d argument%s at line %ld, while " 
	   "it is also used with %d argument%s at line %ld.",
	   nm, predicates[p2]->Arity(),
	   (predicates[p2]->Arity() == 1) ? "" : "s",
	   predicates[p2]->first_line,
	   ar,
	   (ar == 1) ? "" : "s",
	   lineno);
      predicates[p2]->is_warned = 1;
    }
  }
  
  return p;
}

// generate a dummy head that cannot be true in any case. 
long Predicate::ConstructDummy()
{
  static char *buf = 0;
  int first = 0;
  
  if (!buf) {
    buf = strdup("_dummy");
    first = 1;
  }
  long p = DefinePredicate(buf, 0, 1);
  if (first) {
    Literal *lt = new Literal();
    lt->pred = p;
    lt->negative = 1;
    compute->rl->AddLiteral(lt);
  }
  return p;
}

// Define a new zero-arity internal predicate named "_int<N>", where N is
// the current value of sys_data.internal_atoms (which is then
// incremented), give it status `tp`, and return its table index.
// A name that would not fit in RANGE_MAX_NAME bytes is an internal error.
long Predicate::DefineSystemPredicate(DomainType tp)
{
  static char buf[RANGE_MAX_NAME] = { 0 }; 
  const int pos = 4;        // length of the "_int" prefix
  char number[32];          // any long fits in decimal (<= 20 chars + NUL)
  int val = 0;
  long pred = -1;
  if (!buf[0]) {
    sprintf(buf, "_int");
  }

  // Format the counter into scratch space first and check that prefix,
  // digits and the terminating NUL all fit in buf before writing to it.
  // Writing straight into buf and checking afterwards, as before, could
  // already have overrun the buffer.
  val = sprintf(number, "%ld", sys_data.internal_atoms++);
  if (val + pos >= RANGE_MAX_NAME)
    int_error("too many special rules","");
  else
    strcpy(&buf[pos], number);

  pred =  DefinePredicate(strdup(buf), 0, 0);
  predicates[pred]->SetStatus(tp);
  return pred;
}

// Create and link the complement predicate of predicate `orig_idx`, as
// required for regular model semantics. The complement is named by
// complement_name() and has the same arity and first line. The two
// predicates point at each other through their `complement` fields and
// get dependency-graph edges in both directions. Nothing is done if a
// complement is already linked or complement_name() returns NULL.
void Predicate::DefineComplement(long orig_idx)
{
  // sanity check
  if (predicates[orig_idx]->pred != orig_idx)
    int_error("trying to define a complement of unexisting predicate","");

  Predicate *original = predicates[orig_idx];
  if (original->complement >= 0)
    return;   // already linked

  char *cname = complement_name(original->Name());
  if (cname == NULL)
    return;   // name cannot be complemented

  // complement_name() returns a static buffer, so hand over a private copy
  const long comp_idx = DefinePredicate(strdup(cname), original->Arity(),
					original->first_line);

  original->complement = comp_idx;
  predicates[comp_idx]->complement = orig_idx;

  dependency_graph->AddEdge(comp_idx, orig_idx);
  dependency_graph->AddEdge(orig_idx, comp_idx);
}

// Number of stored ground instances. This is 0 when no instance set
// exists, which is the case until SetArity() is called with a positive
// arity.
long Predicate::Size()
{
  return atoms ? atoms->Size() : 0;
}

// Emit the rules linking every atom with its complement, for regular
// model semantics. For each atom `a` in the atom table for which
// complement_name() returns a name a', write
//     a' :- a.
// and, when sys_data.regular_level is REGULAR_ALL_CONSTRAINTS, also
//     a :- a'.
// With sys_data.emit_text the rules are written as text. Otherwise they
// use the numeric line layout selected by sys_data.output_version, where
// an atom is numbered by its atom-table index + 1. A complement name that
// is missing from the atom table is reported as an internal error.
void Predicate::EmitComplements()
{
  long at = 0;
  // Not named "compl": that is a reserved alternative token in C++
  // (it spells operator ~), so "long compl" does not compile on
  // conforming compilers.
  long comp_at = 0;
  char *at_st = NULL;
  char *comp_st = NULL;

  for (at = 0; at < atom_table->Size(); at++) {
    at_st = atom_table->symbols[at];
    comp_st = complement_name(at_st);

    if (!comp_st)
      continue;   // this atom has no complement to link

    if (sys_data.emit_text) {
      fprintf(sys_data.output_file, "%s :- %s.\n", comp_st, at_st);
      if (sys_data.regular_level == REGULAR_ALL_CONSTRAINTS) {
	fprintf(sys_data.output_file, "%s :- %s.\n", at_st ,
		comp_st);
      }
      continue;
    }

    comp_at = atom_table->Lookup(comp_st);
    if (comp_at < 0) {
      int_error("missing complement", "");
    }

    if (sys_data.output_version >= 2) {
      fprintf(sys_data.output_file, "%d %ld 1 0 %ld\n",
	      BASICRULE, comp_at+1, at+1);
      if (sys_data.regular_level == REGULAR_ALL_CONSTRAINTS) {
	fprintf(sys_data.output_file, "%d %ld 1 0 %ld\n",
		BASICRULE, at+1, comp_at+1);
      }
    } else {
      fprintf(sys_data.output_file, "%ld  1 %ld 0\n",
	      comp_at+1, at +1);
      if (sys_data.regular_level == REGULAR_ALL_CONSTRAINTS) {
	fprintf(sys_data.output_file, "%ld  1 %ld 0\n",
		at +1, comp_at+1);
      }
    }
  }
}


void Predicate::CheckUnsatisfiable()
{
  for (long i = 0; i < predicate_table->Size(); i++) {
    if ((*predicates[i]->Name() != '_') && !predicates[i]->has_rule) {
      warn(WARN_UNSAT, predicates[i]->first_line,
	   "predicate '%s' doesn't occur in any rule head.",
	   predicates[i]->Name());
    }
  }
}
