// Copyright 1998 by Patrik Simons
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston,
// MA 02111-1307, USA.
//
// Patrik.Simons@hut.fi
#include <string.h>
#include <float.h>
#include <limits.h>
#include <assert.h>
#include "atomrule.h"
#include "program.h"
#include "tree.h"
#include "api.h"

// Creates an empty list with initial room for 32 (atom, weight) pairs;
// push() enlarges the storage on demand.
Api::list::list ()
{
  size = 32;
  atoms = new Atom *[size];
  weights = new Weight[size];
  top = 0;
}

// Releases the two parallel arrays owned by the list.
Api::list::~list ()
{
  delete[] weights;
  delete[] atoms;
}

// Appends the pair (a, w) at position top.  When the list is full, the
// capacity is doubled first (see grow()).
void
Api::list::push (Atom *a, Weight w)
{
  if (top == size)
    grow ();
  const long slot = top++;
  atoms[slot] = a;
  weights[slot] = w;
}

void
Api::list::reset ()
{
  top = 0;
}

void
Api::list::grow ()
{
  long sz = size*2;
  Atom **atom_array = new Atom *[sz];
  Weight *weight_array = new Weight[sz];
  for (int i = 0; i < size; i++)
    {
      atom_array[i] = atoms[i];
      weight_array[i] = weights[i];
    }
  size = sz;
  delete[] atoms;
  atoms = atom_array;
  delete[] weights;
  weights = weight_array;
}

// Wraps an existing Program.  Atoms created by new_atom() and rules built
// by end_rule() are added to it.  Name remembering starts disabled;
// remember() turns it on.
Api::Api (Program *p)
  : program (p)
{
  pointer_to_tree = 0;
  tree = 0;
}

// Frees the name tree, if remember() ever created one.  The Program and
// its atoms and rules are not deleted here.
Api::~Api ()
{
  if (pointer_to_tree != 0)
    delete pointer_to_tree;
}

// Number of entries currently stored in list l.
inline long
Api::size (list &l)
{
  const long count = l.top;
  return count;
}

// Creates a fresh atom for the program, appends it to the program's atom
// list and increments number_of_atoms.  Returns the new atom.
Atom *
Api::new_atom ()
{
  Atom *fresh = new Atom (program);
  program->atoms.push (fresh);
  ++program->number_of_atoms;
  return fresh;
}

// Marks atom a in the compute statement.  Sets a->computeTrue when pos
// is true, otherwise a->computeFalse.  a must not be null.
void
Api::set_compute (Atom *a, bool pos)
{
  assert (a);
  if (!pos)
    {
      a->computeFalse = true;
      return;
    }
  a->computeTrue = true;
}

// Undoes set_compute().  Clears a->computeTrue when pos is true,
// otherwise a->computeFalse.  a must not be null.
void
Api::reset_compute (Atom *a, bool pos)
{
  assert (a);
  if (!pos)
    {
      a->computeFalse = false;
      return;
    }
  a->computeTrue = false;
}

// Gives atom a a private copy of name s; a null s clears the name.
// While remembering is on (tree != 0), the atom is removed from the name
// tree under its old name and re-inserted under the new one.
//
// The copy is made before the old name is freed.  s may alias a->name
// (e.g. set_name (a, a->name)); freeing first would make strcpy read
// released memory.  Allocating first also leaves the atom unchanged if
// the allocation fails.
void
Api::set_name (Atom *a, const char *s)
{
  assert (a);
  char *copy = 0;
  if (s)
    copy = strcpy (new char[strlen (s)+1], s);
  // Remove under the old name, before that name is released.
  if (a->name && tree)
    tree->remove (a);
  delete[] a->name;
  a->name = copy;
  if (copy && tree)
    tree->insert (a);
}

// Turns on name remembering.  Names given via set_name() are then entered
// into the name tree.  The tree is created on first use and kept across
// forget()/remember() cycles.
void
Api::remember ()
{
  if (!pointer_to_tree)
    pointer_to_tree = new Tree;
  tree = pointer_to_tree;
}

void
Api::forget ()
{
  tree = 0;
}

// Looks up an atom by name in the name tree.  Returns 0 if no tree was
// ever created (remember() never called); otherwise returns whatever
// Tree::find reports (presumably 0 for an unknown name; confirm).
Atom *
Api::get_atom (const char *name)
{
  if (!pointer_to_tree)
    return 0;
  return pointer_to_tree->find (name);
}

// Starts a new rule of type t and restores the per-rule parameters
// (bounds and the maximize flag) to their defaults.  Atoms are then
// collected with add_head()/add_body(), and end_rule() builds the rule.
void
Api::begin_rule (RuleType t)
{
  maximize = false;
  atleast_head = 1;
  atleast_body = 0;
  atleast_weight = WEIGHT_MIN;
  type = t;
}

// Builds the rule announced by begin_rule() from the atoms collected with
// add_head()/add_body() and appends it to program->rules.  It updates:
//   - the program counters number_of_rules, number_of_heads and
//     size_of_program;
//   - the per-atom occurrence counters headof, posScore and negScore,
//     which done() later sums to size each atom's Follows array.
// The head and body lists are cleared afterwards in every case.  An
// unrecognized rule type builds nothing, but the lists are still cleared.
void
Api::end_rule ()
{
  long i, n;
  switch (type)
    {
     case BASICRULE:
      {
	// Single head.  The body is one array: negative literals
	// [nbody, nend) followed by positive literals [pbody, pend).
	assert (size (head) == 1);
	BasicRule *r = new BasicRule ();
	program->number_of_heads++;
	r->type = BASICRULE;
	program->rules.push (r);
	program->number_of_rules++;
	Atom *a = head.atoms[0];
	a->headof++;
	r->head = a;
	n = size (nbody) + size (pbody);
	if (n != 0)
	  r->nbody = new Atom *[n];
	else
	  r->nbody = 0;
	r->end = r->nbody+n;
	r->pend = r->end;
	// Body literals plus the head.
	program->size_of_program += n+1;
	r->lit = n;
	for (i = 0; i < size (nbody); i++)
	  {
	    r->nbody[i] = nbody.atoms[i];
	    r->nbody[i]->negScore++;
	    r->nbody[i]->isnant = true;
	  }
	r->nend = r->nbody + size (nbody);
	r->pbody = r->nend;
	r->upper = size (pbody);
	for (i = 0; i < size (pbody); i++)
	  {
	    r->pbody[i] = pbody.atoms[i];
	    r->pbody[i]->posScore++;
	  }
	break;
      }
    case CONSTRAINTRULE:
      {
	// Same layout as BASICRULE.  lit, inactive and upper are derived
	// from the bound set via set_atleast_body().
	assert (size (head) == 1);
	program->number_of_heads++;
	ConstraintRule *r = new ConstraintRule ();
	r->type = CONSTRAINTRULE;
	program->rules.push (r);
	program->number_of_rules++;
	Atom *a = head.atoms[0];
	a->headof++;
	r->head = a;
	n = size (nbody) + size (pbody);
	if (n != 0)
	  r->nbody = new Atom *[n];
	else
	  r->nbody = 0;
	r->end = r->nbody+n;
	r->pend = r->end;
	program->size_of_program += n+1;
	r->lit = atleast_body;
	r->inactive = atleast_body - n;
	for (i = 0; i < size (nbody); i++)
	  {
	    r->nbody[i] = nbody.atoms[i];
	    r->nbody[i]->negScore++;
	    r->nbody[i]->isnant = true;
	  }
	r->nend = r->nbody + size (nbody);
	r->pbody = r->nend;
	r->upper = atleast_body - size (nbody);
	for (i = 0; i < size (pbody); i++)
	  {
	    r->pbody[i] = pbody.atoms[i];
	    r->pbody[i]->posScore++;
	  }
	break;
      }
    case GENERATERULE:
      {
	// At least two heads.  Heads [head, hend) and the positive body
	// [pbody, pend) share one array.  Each head atom is also counted
	// as a negative occurrence (negScore, isnant).  Atoms collected in
	// nbody are not used by this case and are discarded by the reset
	// below.  NOTE(review): confirm dropping them is intended.
	assert (size (head) >= 2);
	GenerateRule *r = new GenerateRule ();
	r->type = GENERATERULE;
	program->rules.push (r);
	program->number_of_rules++;
	long heads = size (head);
	long body = size (pbody);
	program->number_of_heads += heads;
	r->head = new Atom *[heads+body];
	r->hend = r->head+heads;
	r->pbody = r->hend;
	program->size_of_program += 2*heads;   // heads + nbody
	// neg/inactiveNeg are derived from the bound set via
	// set_atleast_head().
	r->neg = heads - atleast_head;
	r->inactiveNeg = -atleast_head;
	for (i = 0; i < size (head); i++)
	  {
	    r->head[i] = head.atoms[i];
	    r->head[i]->headof++;
	    r->head[i]->negScore++;
	    r->head[i]->isnant = true;
	  }
	r->pend = r->pbody + body;
	r->end = r->pend;
	program->size_of_program += body;
	r->pos = body;
	r->upper = body;
	r->inactivePos = 0;
	for (i = 0; i < size (pbody); i++)
	  {
	    r->pbody[i] = pbody.atoms[i];
	    r->pbody[i]->posScore++;
	  }
	break;
      }
    case CHOICERULE:
      {
	// One or more heads.  A single array holds heads [head, hend),
	// then negative body [nbody, nend), then positive body [pbody, pend).
	assert (size (head) >= 1);
	ChoiceRule *r = new ChoiceRule ();
	r->type = CHOICERULE;
	program->rules.push (r);
	program->number_of_rules++;
	long heads = size (head);
	program->number_of_heads += heads;
	n = heads + size (nbody) + size (pbody);
	program->size_of_program += n;
	r->head = new Atom *[n];
	r->hend = r->head+heads;
	r->end = r->head+n;
	r->pend = r->end;
	for (i = 0; i < heads; i++)
	  {
	    r->head[i] = head.atoms[i];
	    r->head[i]->headof++;
	    r->head[i]->isnant = true;  // Implicit
	  }
	r->nbody = r->hend;
	r->lit = size (nbody) + size (pbody);
	for (i = 0; i < size (nbody); i++)
	  {
	    r->nbody[i] = nbody.atoms[i];
	    r->nbody[i]->negScore++;
	    r->nbody[i]->isnant = true;
	  }
	r->nend = r->nbody + size (nbody);
	r->pbody = r->nend;
	r->upper = size (pbody);
	for (i = 0; i < size (pbody); i++)
	  {
	    r->pbody[i] = pbody.atoms[i];
	    r->pbody[i]->posScore++;
	  }
	break;
      }
    case WEIGHTRULE:
      {
	// Single head.  Body literals are stored in `body`, with parallel
	// arrays weight, positive and reverse: negative literals first,
	// then positive ones.  The bound comes from set_atleast_weight().
	assert (size (head) == 1);
	program->number_of_heads++;
	WeightRule *r = new WeightRule ();
	r->type = WEIGHTRULE;
	program->rules.push (r);
	program->number_of_rules++;
	Atom *a = head.atoms[0];
	a->headof++;
	r->head = a;
	r->atleast = atleast_weight;
	r->maxweight = 0;
	r->minweight = 0;
	n = size (nbody) + size (pbody);
	if (n != 0)
	  {
	    r->body = new Atom *[n];
	    r->weight = new Weight[n];
	    r->positive = new bool[n];
	    r->reverse = new Follows *[n];
	  }
	else
	  {
	    r->body = 0;
	    r->weight = 0;
	    r->positive = 0;
	    r->reverse = 0;
	  }
	r->bend = r->body+n;
	r->end = r->bend;
	r->max = r->body;
	r->min = r->body;
	r->max_shadow = r->body;
	r->min_shadow = r->body;
	program->size_of_program += n+1;
	for (i = 0; i < size (nbody); i++)
	  {
	    r->body[i] = nbody.atoms[i];
	    r->body[i]->negScore++;
	    r->body[i]->isnant = true;
	    r->weight[i] = nbody.weights[i];
	    r->positive[i] = false;
	    r->reverse[i] = 0;
	  }
	// i continues from size (nbody), so positive literals are stored
	// after the negative ones.
	for (long j = 0; j < size (pbody); j++, i++)
	  {
	    r->body[i] = pbody.atoms[j];
	    r->body[i]->posScore++;
	    r->weight[i] = pbody.weights[j];
	    r->positive[i] = true;
	    r->reverse[i] = 0;
	  }
	break;
      }
    case OPTIMIZERULE:
      {
	// No head atoms.  The rule is prepended to the program->optimize
	// list.  Negative literals are stored in [nbody, nend) and positive
	// ones in [pbody, pend), with weights in the parallel weight array.
	// NOTE(review): number_of_heads is incremented although no head is
	// stored, and negative atoms here are not marked isnant as in the
	// other cases.  Confirm both are intended.
	program->number_of_heads++;
	OptimizeRule *r = new OptimizeRule ();
	r->type = OPTIMIZERULE;
	program->rules.push (r);
	program->number_of_rules++;
	r->next = program->optimize;
	program->optimize = r;
	r->maxweight = 0;
	r->minweight = 0;
	r->maxoptimum = WEIGHT_MIN;
	r->minoptimum = WEIGHT_MAX;
	if (maximize)
	  r->maximize = true;
	else
	  r->maximize = false;
	n = size (nbody) + size (pbody);
	if (n != 0)
	  {
	    r->nbody = new Atom *[n];
	    r->weight = new Weight[n];
	  }
	else
	  {
	    r->nbody = 0;
	    r->weight = 0;
	  }
	r->end = r->nbody+n;
	program->size_of_program += n;
	for (i = 0; i < size (nbody); i++)
	  {
	    r->nbody[i] = nbody.atoms[i];
	    r->nbody[i]->negScore++;
	    r->weight[i] = nbody.weights[i];
	  }
	r->nend = r->nbody + size (nbody);
	r->pbody = r->nend;
	r->pend = r->pbody + size (pbody);
	// Atoms are indexed by j relative to pbody.  Weights are indexed by
	// i, which continues from size (nbody).
	for (long j = 0; j < size (pbody); j++, i++)
	  {
	    r->pbody[j] = pbody.atoms[j];
	    r->pbody[j]->posScore++;
	    r->weight[i] = pbody.weights[j];
	  }
	break;
      }
    default:
      // Unknown rule type: nothing is built; the lists are still cleared.
      break;
    }
  pbody.reset ();
  nbody.reset ();
  head.reset ();
}

// Appends atom a (non-null) to the head of the rule being built.
void
Api::add_head (Atom *a)
{
  assert (a);
  this->head.push (a);
}

// Appends atom a (non-null) to the body of the rule being built: to the
// positive body when pos is true, otherwise to the negative body.  The
// weight stored is push()'s default.
void
Api::add_body (Atom *a, bool pos)
{
  assert (a);
  (pos ? pbody : nbody).push (a);
}

// Weighted variant of add_body.  Appends (a, w) to the positive body when
// pos is true, otherwise to the negative body.  Negative weights are
// accepted only while building an OPTIMIZERULE.
void
Api::add_body (Atom *a, bool pos, Weight w)
{
  assert (a && (type == OPTIMIZERULE || w >= 0));
  (pos ? pbody : nbody).push (a, w);
}

// Overwrites the weight of the i-th positive (pos true) or negative body
// literal already added to the rule being built.  i must be in range.
// Negative weights are accepted only for an OPTIMIZERULE.
void
Api::change_body (long i, bool pos, Weight w)
{
  assert (type == OPTIMIZERULE || w >= 0);
  if (!pos)
    {
      assert (0 <= i && i < size (nbody));
      nbody.weights[i] = w;
      return;
    }
  assert (0 <= i && i < size (pbody));
  pbody.weights[i] = w;
}

// Sets the weight bound that end_rule() copies into a WEIGHTRULE's
// `atleast`.  begin_rule() resets it to WEIGHT_MIN.
void
Api::set_atleast_weight (Weight w)
{
  this->atleast_weight = w;
}

void
Api::set_atleast_body (long n)
{
  atleast_body = n;
}

void
Api::set_atleast_head (long n)
{
  atleast_head = n;
}

void Api::maximize_rule (bool pos)
{
  maximize = pos;
}

void
Api::done ()
{
  // Set up atoms
  for (Node *n = program->atoms.head (); n; n = n->next)
    {
      Atom *a = n->atom;
      if (a->headof + a->posScore + a->negScore)
	{
	  a->head = new Follows[a->headof + a->posScore + a->negScore];
	  a->endHead = a->head + a->headof;
	  a->pos = a->endHead;
	  a->endPos = a->pos + a->posScore;
	  a->endUpper = a->endPos;
	  a->neg = a->endPos;
	  a->endNeg = a->neg + a->negScore;
	  a->end = a->endNeg;
	}
    }
  for (Node *n = program->rules.head (); n; n = n->next)
    n->rule->setup ();
  for (Node *n = program->atoms.head (); n; n = n->next)
    n->atom->head -= n->atom->headof;
}
