/* -*- Mode: C; indent-tabs-mode: t; c-basic-offset: 2; tab-width: 2 -*-  */
/*
 * relalgebra.cc
 * Copyright (C) 2016 Shahab Tasharrofi <shahab@tasharrofi.net>
 *
 * grounder-generator is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * grounder-generator is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along
 * with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "relalgebra.h"

// Global table keyed by compact_string.
// NOTE(review): not referenced in this part of the file; presumably maps
// relation names to their tables — confirm against callers.
unordered_map<compact_string, rel_algebra_node, compact_string_hash> database;

// Hash-consing table: maps a node's (variable, value, then index, else index)
// data to its node index, so getNode() never stores the same node twice.
unordered_map<NodeData, NodeIndexType, NodeDataHash> rel_algebra::nodeToIndex;
// Storage for all distinct nodes; getNode() gives the node stored at position p
// the index p << 1.
vector<NodeData> rel_algebra::nodeRepository;
// Memoized join() results, keyed by the ordered pair from makeJoinIndexPair().
unordered_map<JoinIndexType, NodeIndexType, JoinIndexTypeHash> rel_algebra::joinCache;
// Consulted by unwindNegation(): keyed by complement(t).index for a negated t,
// value is the index of the unwound table.
unordered_map<NodeIndexType, NodeIndexType> rel_algebra::complementationCache;

/*
 * Builds the joinCache key for a pair of tables.
 * The smaller index always comes first, so (t1, t2) and (t2, t1) map to the
 * same key.
 */
JoinIndexType rel_algebra::makeJoinIndexPair(const Table t1, const Table t2)
{
	const bool alreadyOrdered = t1.index < t2.index;
	const Table &smaller = alreadyOrdered ? t1 : t2;
	const Table &larger = alreadyOrdered ? t2 : t1;
	return JoinIndexType(smaller.index, larger.index);
}

/*
 * Returns t unchanged when it is not negated. Otherwise it rebuilds t with
 * canonical(), using its complemented and recursively unwound branches.
 *
 * Results are memoized in complementationCache, keyed by complement(t).index.
 * The original code only read this cache and never populated it, so the
 * recursive rebuild was repeated on every call.
 */
Table rel_algebra::unwindNegation(const Table t)
{
	if (!isNegated(t))
		return t;

	const NodeIndexType cacheKey = complement(t).index;
	auto it = complementationCache.find(cacheKey);
	if (it != complementationCache.end())
		return Table(it->second);

	Table result = canonical(getVariable(t), getValue(t), unwindNegation(complement(getThenBranch(t))), unwindNegation(complement(getElseBranch(t))));
	// Insert only after the recursive calls: they may themselves insert into
	// (and rehash) complementationCache, so no iterator is held across them.
	complementationCache[cacheKey] = result.index;
	return result;
}

/*
 * Returns the unique (hash-consed) node for
 * (variableIndex, value, thenNode, elseNode).
 *
 * Stored nodes never have a negated else-branch. If elseNode is negated, both
 * branches are complemented before the lookup, and the returned handle is
 * complemented back. A node that is not yet known is appended to
 * nodeRepository and gets the index nodeRepository.size() << 1.
 */
Table rel_algebra::getNode(VariableType variableIndex, ValueType value, Table thenNode, Table elseNode)
{
	const bool flipResult = isNegated(elseNode);
	if (flipResult)
	{
		thenNode = complement(thenNode);
		elseNode = complement(elseNode);
	}

	NodeData key(variableIndex, value, thenNode.index, elseNode.index);

	NodeIndexType nodeIndex;
	auto found = nodeToIndex.find(key);
	if (found == nodeToIndex.end())
	{
		nodeIndex = nodeRepository.size() << 1;
		nodeRepository.push_back(key);
		nodeToIndex[key] = nodeIndex;
	}
	else
		nodeIndex = found->second;

	Table node(nodeIndex);
	if (flipResult)
		return complement(node);
	return node;
}

/*
 * Builds the reduced node "if variableIndex <= value then thenNode else elseNode".
 *
 * Nodes are read as threshold tests: the then-branch covers values <= value,
 * the else-branch values > value. This is the only reading under which the
 * redundancy rules below are sound, and it matches how divide() walks
 * keptVariables.
 *
 * Reductions applied:
 *  - identical branches make the test redundant;
 *  - a test on the same variable inside the then-branch is either always true
 *    there (its value >= value) or is hoisted above this test (its value < value);
 *  - if the else-branch tests the same variable with a larger value and has
 *    the same then-branch, this test is absorbed into it.
 * Anything left is interned through getNode().
 */
Table rel_algebra::canonical(VariableType variableIndex, ValueType value, Table thenNode, Table elseNode)
{
	while (true)
	{
		if (thenNode == elseNode)
			return thenNode;

		if (isInternalNode(thenNode))
		{
			if (getVariable(thenNode) == variableIndex)
			{
				const ValueType innerValue = getValue(thenNode);
				if (innerValue < value)
				{
					// The region x <= value is split again by the inner test:
					//   x <= innerValue          -> inner then-branch
					//   innerValue < x <= value  -> inner else-branch
					//   x > value                -> elseNode
					// i.e. node(x, innerValue, innerThen, node(x, value, innerElse, elseNode)).
					// The original code instead replaced the then-branch with the rebuilt
					// inner node and kept elseNode, which dropped innerThen and mapped
					// (innerValue, value] to elseNode.
					Table innerThen = getThenBranch(thenNode);
					Table innerElse = getElseBranch(thenNode);
					elseNode = canonical(variableIndex, value, innerElse, elseNode);
					value = innerValue;
					thenNode = innerThen;
				}
				else
					// innerValue >= value: the inner test always succeeds when x <= value.
					thenNode = getThenBranch(thenNode);

				continue;
			}

			assert(getVariable(thenNode) > variableIndex);
		}

		if (isInternalNode(elseNode))
		{
			if (getVariable(elseNode) == variableIndex)
			{
				assert(getValue(elseNode) > value);

				// "x <= value -> T, else (x <= v2 -> T, else E)" equals "x <= v2 -> T, else E".
				if (getThenBranch(elseNode) == thenNode)
					return elseNode;
			}
		}

		return getNode(variableIndex, value, thenNode, elseNode);
	}
}

/*
 * Conjunction (natural join) of two tables.
 *
 * Trivial cases are answered directly:
 *  - false absorbs and true is the identity;
 *  - identical operands return themselves;
 *  - operands whose indices differ only in the lowest bit give false. These are
 *    presumably a node and its complement, judging by getNode() handing out even
 *    indices — confirm against complement().
 *
 * Otherwise t1/t2 are ordered so that t1 carries the smallest
 * (variable, value) test. The chain of tests on that variable ("currentVar")
 * is walked iteratively. Each step pushes three entries onto
 * global_table_stack:
 *  - the join of t1's then-branch with the part of t2 that applies there
 *    (t2's then-branch if t2 also tests currentVar, otherwise t2 itself);
 *  - t1;
 *  - t2.
 * It then advances to the else-region. Afterwards the stack is unwound from
 * the top. Each triple is folded into the result with canonical(), and each
 * (t1, t2) pair's result is memoized in joinCache. Stack entries at or below
 * "currentStackSize" belong to callers and are not touched.
 */
Table rel_algebra::join(Table t1, Table t2)
{
	if (t1 == falseNode())
		return falseNode();
	if (t2 == falseNode())
		return falseNode();
	if (t1 == trueNode())
		return t2;
	if (t2 == trueNode())
		return t1;
	NodeIndexType xorIndex = t1.index ^ t2.index;
	if (xorIndex == 1)
		return falseNode(); // indices differ only in the lowest bit
	if (xorIndex == 0)
		return t1; // same node

	// Order the operands so that t1 has the smaller (variable, value) test.
	if ((getVariable(t1) > getVariable(t2)) || ((getVariable(t1) == getVariable(t2)) && (getValue(t1) > getValue(t2))))
	{
		Table temp = t1;
		t1 = t2;
		t2 = temp;
	}

	VariableType currentVar = getVariable(t1);
	// Stack height on entry; everything pushed above it is consumed below.
	size_t currentStackSize = global_table_stack::size();

	while (true)
	{
		// A cached join for the current pair stands in for the rest of the chain.
		JoinIndexType joinIndices = makeJoinIndexPair(t1, t2);
		auto it = joinCache.find(joinIndices);
		if (it != joinCache.end())
		{
			global_table_stack::push(Table(it->second));
			break;
		}

		// Join of the region covered by t1's then-branch.
		Table t1ThenBranch = getThenBranch(t1);
		if (getVariable(t2) == currentVar)
		{
			Table t2ThenBranch = getThenBranch(t2);
			global_table_stack::push(join(t1ThenBranch, t2ThenBranch));
		}
		else
			global_table_stack::push(join(t1ThenBranch, t2));
		// Remember the pair so the unwinding phase knows the threshold (value of t1)
		// and the joinCache key.
		global_table_stack::push(t1);
		global_table_stack::push(t2);

		// On equal thresholds both operands move on; otherwise t2 still applies
		// unchanged above t1's threshold.
		if ((getVariable(t2) == currentVar) && (getValue(t2) == getValue(t1)))
			t2 = getElseBranch(t2);

		t1 = getElseBranch(t1);
		// If either side became a leaf, stop with the leaf in t1.
		if (isLeafNode(t1))
			break;
		if (isLeafNode(t2))
		{
			Table temp = t1;
			t1 = t2;
			t2 = temp;
			break;
		}
		// Keep t1 as the operand with the smaller test.
		if ((getVariable(t1) > getVariable(t2)) || ((getVariable(t1) == getVariable(t2)) && (getValue(t1) > getValue(t2))))
		{
			Table temp = t1;
			t1 = t2;
			t2 = temp;
		}
		// The chain on currentVar is exhausted on both sides.
		if (getVariable(t1) != currentVar)
			break;
	}

	// Unless the loop ended on a cache hit (t1 still tests currentVar), push the
	// join of what remains above the last threshold.
	if (isLeafNode(t1) || (getVariable(t1) != currentVar))
		global_table_stack::push(join(t1, t2));

	// Fold the pushed triples back into a single table, innermost (largest
	// threshold) first.
	Table tempResult = global_table_stack::pop();
	while (global_table_stack::size() > currentStackSize)
	{
		t2 = global_table_stack::pop();
		t1 = global_table_stack::pop();
		tempResult = canonical(currentVar, getValue(t1), global_table_stack::pop(), tempResult);
		joinCache[makeJoinIndexPair(t1, t2)] = tempResult.index;
	}

	return tempResult;
}

/*
 * Universally quantifies away ("divides out") every variable of t that
 * keptVariables does not keep.
 *
 * keptVariables is a table over variable 0 whose values are variable indices.
 * Its then-branch applies to indices <= the node's value and its else-branch
 * to larger ones. The leaf reached for a variable of t says whether that
 * variable is kept (true) or dropped (false).
 *
 * For a kept variable, the chain of tests on it is rebuilt with canonical()
 * over divided sub-tables. For a dropped variable, the divided sub-tables of
 * every region are conjoined with join(), so the result must hold for every
 * value of that variable.
 *
 * Leaves are returned unchanged.
 */
Table rel_algebra::divide(Table keptVariables, Table t)
{
	if (isLeafNode(t))
		return t;

	// Advance keptVariables to the part that covers the top variable of t.
	while (true)
	{
		if (isLeafNode(keptVariables))
			break;
		assert(getVariable(keptVariables) == 0);
		if (getVariable(t) <= getValue(keptVariables))
			break;
		keptVariables = getElseBranch(keptVariables);
	}

	if (keptVariables == trueNode())
		return t; // Since we want to keep all variables, "t" is already good.
	if (keptVariables == falseNode())
		return falseNode(); // Since we want to get rid of all variables, the result is the
												// conjunction of all the leaf nodes of "t". Since t is not a leaf
												// node, it certainly contains at least one false leaf node. Hence,
												// the conjunction is always false.

	// true: the top variable of t is kept; false: it is dropped.
	Table keptThenBranch = getThenBranch(keptVariables);
	assert(isLeafNode(keptThenBranch));

	VariableType currentVar = getVariable(t);
	size_t currentStackSize = global_table_stack::size();

	// Walk the chain of tests on currentVar, dividing each then-region.
	while (true)
	{
		if (isLeafNode(t))
			break;
		if (getVariable(t) != currentVar)
			break;

		Table tThenBranch = getThenBranch(t);

		global_table_stack::push(divide(keptVariables, tThenBranch));
		if (keptThenBranch == trueNode())
			global_table_stack::push(t); // needed to recover the threshold value

		t = getElseBranch(t);
	}

	// The region above the last threshold must be divided exactly like the
	// then-regions. The original code called project() here, which, by its
	// relational-algebra name, presumably quantifies existentially; that is
	// inconsistent with the universal quantification (join) used for the other
	// regions.
	Table tempResult = divide(keptVariables, t);
	while (global_table_stack::size() > currentStackSize)
	{
		if (keptThenBranch == trueNode())
		{
			t = global_table_stack::pop();
			tempResult = canonical(currentVar, getValue(t), global_table_stack::pop(), tempResult);
		}
		else
			tempResult = join(global_table_stack::pop(), tempResult);
	}

	return tempResult;
}

// Factory: wraps t in a SingleVariableUnitIterator.
SingleVariableUnitIterator rel_algebra::getSingleVariableUnitIterator(Table t)
{
	SingleVariableUnitIterator iterator(t);
	return iterator;
}

// Factory: builds a UnitIterator over t, using the constructor's default for
// the remaining argument.
UnitIterator rel_algebra::getUnitIterator(Table t)
{
	UnitIterator iterator(t);
	return iterator;
}

/*
 * Binary search over "iterators" for "variable". This assumes the iterators
 * are ordered by their current variable — TODO confirm with init().
 *
 * Returns:
 *  - the position of an iterator on exactly that variable, if one is probed;
 *  - otherwise the position of the last iterator whose variable is smaller;
 *  - otherwise -1.
 */
int UnitIterator::findVarPosition(VariableType variable)
{
	int lo = -1;
	int hi = iterators.size() - 1;
	while (hi > lo)
	{
		// Round up so that "lo = probe" always makes progress.
		const int probe = (lo + hi + 1) / 2;
		const VariableType probedVar = iterators[probe].getCurrentVariable();
		if (probedVar == variable)
			return probe;
		if (probedVar < variable)
			lo = probe;
		else
			hi = probe - 1;
	}
	return lo;
}

void UnitIterator::init()
{
	if (iterators.size() == 0)
		return;
	while (true)
	{
		Table t = iterators[iterators.size() - 1].getCurrentNode();
		assert(!rel_algebra::isEmpty(t));
		if (rel_algebra::isFull(t))
			break;
		if (rel_algebra::getVariable(t) > maxVariable)
			break;
		iterators.push_back(SingleVariableUnitIterator(t));
	}
}

/*
 * Advances to the next unit.
 *
 * The deepest iterator is incremented. Any iterator that runs off its end is
 * discarded, and the increment carries into the next shallower one. The stack
 * is then deepened again with init().
 */
void UnitIterator::next()
{
	while (!iterators.empty())
	{
		auto &deepest = iterators.back();
		++deepest;
		if (!deepest.atEnd())
			break;
		iterators.pop_back();
	}
	init();
}

/*
 * Returns the current value bound to "variable".
 * Prints an error and terminates the process with exit(-1) if no iterator is
 * positioned on that variable, or if that iterator is unbounded.
 */
ValueType UnitIterator::getCurrentValue(VariableType variable)
{
	const int position = findVarPosition(variable);
	const bool isBound = (position >= 0)
		&& (iterators[position].getCurrentVariable() == variable)
		&& !iterators[position].isUnbounded();
	if (!isBound)
	{
		cerr << "Tried to access an unbounded variable!" << endl;
		exit(-1);
	}

	return iterators[position].getCurrentValue();
}

/*
 * True when "variable" has no iterator of its own in the stack, or when its
 * iterator reports itself as unbounded.
 */
bool UnitIterator::isVariableUnbounded(VariableType variable)
{
	const int position = findVarPosition(variable);
	return (position < 0)
		|| (iterators[position].getCurrentVariable() != variable)
		|| iterators[position].isUnbounded();
}

/*
 * Returns getCurrentNode() restricted, via rel_algebra::filterEqualTo, to the
 * current (variable, value) of every iterator in the stack. Filters are
 * applied from the deepest iterator back to the first.
 * NOTE(review): this also filters on iterators that report isUnbounded();
 * confirm what their getCurrentValue() yields.
 */
Table UnitIterator::operator*()
{
	Table filtered = getCurrentNode();
	for (size_t level = iterators.size(); level > 0; --level)
	{
		auto &iter = iterators[level - 1];
		filtered = rel_algebra::filterEqualTo(iter.getCurrentVariable(), iter.getCurrentValue(), filtered);
	}
	return filtered;
}

/*
 * Creates an iterator over the units of t, descending no further than
 * maxVar.
 * A full table is rejected: an error is printed and the process exits with -1.
 * An empty table leaves the iterator stack empty.
 */
UnitIterator::UnitIterator(Table t, VariableType maxVar) : maxVariable(maxVar)
{
	if (rel_algebra::isFull(t))
	{
		cerr << "Cannot iterate over a full table!" << endl;
		exit(-1);
	}

	if (rel_algebra::isEmpty(t))
		return;

	iterators.push_back(SingleVariableUnitIterator(t));
	init();
}

