# coding: utf-8

"""Encode Newick trees into ASP rules for generating quartets.

Copyright (C) 2015 Laura Koponen

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

The script is written for Python 2.7.3.

It belongs to the software package published on http://research.ics.aalto.fi/software/asp/ and related to the paper
Koponen, L., Oikarinen, E., Janhunen, T., Säilä, L.:
'Optimizing Phylogenetic Supertrees Using Answer Set Programming'
Theory and Practice of Logic Programming, 15(4-5):604-619.
"""

from __future__ import print_function

import argparse
import os
import re
import sys


class TreeException(Exception):
	"""Raised when an operation would violate the tree structure.

	Derives from Exception (not BaseException) so that ordinary
	``except Exception`` handlers -- such as the one in this script's
	command-line entry point -- catch it. BaseException is reserved for
	interpreter-level events like SystemExit and KeyboardInterrupt.
	"""


class Tree(object):
	"""A rooted, directed tree.

	Nodes are numbered 1..n in the order they are added; reindex()
	renumbers them so that leaf nodes receive the smallest indices.
	"""

	def __init__(self, root):
		"""Create a one-node tree whose only node, `root`, is also its only leaf."""
		self.__root = root
		root.tree = self
		root.index = 1

		self.__nodes = [root,]
		# Mirror of __nodes as a set so membership tests in add_child are O(1)
		# rather than a linear scan (which made building a tree O(n^2)).
		self.__node_set = set(self.__nodes)
		self.__edges = []
		self.__leaves = set(self.__nodes)

	def add_child(self, parent, child, edge_length=None):
		"""Attach `child` below `parent` via an edge of length `edge_length`.

		The child receives the next free index (len(nodes)), and the parent
		stops being a leaf.

		Raises
		------
		TreeException
			If `parent` is not in this tree, or `child` already is.
		"""
		if parent not in self.__node_set:
			raise TreeException("Cannot add child node to node not in this tree")
		if child in self.__node_set:
			raise TreeException("Node may not have multiple parents in a tree")
		edge = Edge(parent, child, edge_length)
		self.__edges.append(edge)
		self.__nodes.append(child)
		self.__node_set.add(child)

		# Keep every node's back-reference consistent with the root's.
		child.tree = self
		child.parent = parent
		child.index = len(self.__nodes)

		parent.children.append(child)

		# The parent is no longer a leaf (discard is a no-op if it already wasn't)
		self.__leaves.discard(parent)
		self.__leaves.add(child)

	@property
	def edges(self):
		"""All edges, in the order they were added."""
		return tuple(self.__edges)

	@property
	def nodes(self):
		"""All nodes, in the order they were added (root first)."""
		return tuple(self.__nodes)

	@property
	def leaves(self):
		"""Leaf nodes in the order they were added to the tree.

		Derived from the node list rather than by iterating the leaf set, whose
		order depends on object addresses and would vary between runs.
		"""
		return tuple(node for node in self.__nodes if node in self.__leaves)

	@property
	def root(self):
		"""The root node passed to the constructor."""
		return self.__root

	def reindex(self):
		"""Sorts the node indices so that leaf nodes have the smallest indices.

		Inner nodes are numbered downwards from n in breadth-first order
		starting at the root (so the root gets n). Leaves then continue
		downwards in the order they were met during that traversal, ending at 1.
		"""
		from collections import deque

		next_index = len(self.__nodes)

		# deque.popleft() is O(1); list.pop(0) is O(n) per call.
		innerqueue = deque([self.__root])
		leaves_found = []

		while innerqueue:
			parent = innerqueue.popleft()
			parent.index = next_index
			next_index -= 1
			for child in parent.children:
				if child in self.__leaves:
					leaves_found.append(child)
				else:
					innerqueue.append(child)

		for leaf in leaves_found:
			leaf.index = next_index
			next_index -= 1


class Node:
	"""One node of a Tree, optionally labelled with a name.

	The structural attributes start out empty and are filled in by Tree:
	`index` (the node's number within the tree), `tree` (the owning Tree; the
	Tree constructor sets it on the root), `parent`, and `children` (a list
	that Tree.add_child appends to).
	"""

	def __init__(self, name=None):
		# Structural links, populated once the node joins a Tree
		self.children = list()
		self.parent = None
		self.tree = None
		self.index = None
		# Label taken from the Newick string, or None for an unnamed node
		self.name = name


class Edge:
	"""A directed edge from a parent node (`source`) to a child node (`dest`).

	`weight` stores the edge length. It defaults to 1.0, but callers may pass
	None explicitly when the length is unknown.
	"""

	def __init__(self, source, dest, edge_length=1.0):
		self.source, self.dest = source, dest
		self.weight = edge_length


def read_tree(tree_str):
	"""Parses a Newick-format tree string and initialises a corresponding tree object.

	Parameters
	----------
	tree_str : string
		A single Newick tree. Line breaks anywhere in the string are ignored,
		as is whitespace before the tree and around its optional trailing ';'.

	Returns
	-------
	tree : Tree object
		A rooted, directed tree, reindexed (see Tree.reindex) so that leaf
		nodes have the smallest indices.

	Raises
	------
	ValueError
		If the string contains nothing but whitespace. Errors raised by
		read_node while parsing (e.g. ValueError for an unparsable edge
		length) propagate unchanged.
	"""

	# Remove line breaks (a tree may be wrapped over several lines) and any
	# surrounding whitespace, such as a file's trailing newline. Leading
	# whitespace would otherwise hide the opening '(' from read_node.
	# str.replace works on both Python 2 and 3; translate(None, ...) is 2-only.
	tree_str = tree_str.replace("\n", "").replace("\r", "").strip()
	if not tree_str:
		raise ValueError("Empty tree string")

	# The last character should be a ';', but this isn't always the case.
	# Drop it (and any whitespace just before it) if there is one.
	if tree_str.endswith(";"):
		tree_str = tree_str[:-1].rstrip()

	# Read the root node. The last return value is the length of an edge
	# leading up to the root; the root has no parent, so it is ignored.
	root_neighbours, root_name, _ = read_node(tree_str)

	# Initialise a one-node tree
	root = Node(root_name)
	tree = Tree(root)

	# Work list of (parent node, Newick text of one child) pairs still to be
	# parsed. pop() takes from the end, so the tree is built depth-first.
	queue = [(root, neighbour_str) for neighbour_str in root_neighbours]

	while queue:
		parent, child_str = queue.pop()

		grandchildren, child_name, edge_length = read_node(child_str)

		child = Node(child_name)
		tree.add_child(parent, child, edge_length)

		queue.extend((child, grandchild) for grandchild in grandchildren)

	tree.reindex()

	return tree


def read_node(node_str):
	"""Parse a string referring to a single node.

	Examples.
	>>> read_node("(A:5.0,C:3.0,E:4.0):5.0")
	(['A:5.0', 'C:3.0', 'E:4.0'], None, 5.0)
	>>> read_node("(,(,,),)")
	(['', '(,,)', ''], None, None)
	>>> read_node("dog:25.46154")
	([], 'dog', 25.46154)

	Parameters
	----------
	node_str : str
		Newick text for one node: an optional parenthesised child list,
		then an optional name, then an optional ":length".

	Returns
	-------
	neighbours : list of str
		Strings corresponding to the children of this node ([] if there are no children)
	name : string
		Name of this node (None if not given)
	edge_length : float
		Length of edge leading up to this node (None if not given)

	Raises
	------
	ValueError
		If the node opens a child list with "(" that is never closed by ")",
		or if the text after ":" is not a valid number (this assumes that a
		lone ":" with no number following it is not allowed).
	"""
	neighbours = []

	if node_str.startswith("("):
		# The child list ends at the last ")". The name and length that may
		# follow it cannot contain parentheses in unquoted Newick.
		close = node_str.rfind(")")
		if close == -1:
			raise ValueError("Unbalanced parentheses in node string: %r" % node_str)

		# The children are the text between the opening and closing parenthesis
		neighbours = split_neighbours(node_str[1:close])

		# Keep only what follows the parenthesised part
		node_str = node_str[close + 1:]

	# What remains has the form "name", "name:length", ":length" or "".
	name, colon, length_str = node_str.partition(":")
	name = name or None
	edge_length = float(length_str) if colon else None

	return neighbours, name, edge_length


def split_neighbours(neighbour_str):
	"""Split a neighbour string into strings representing the child nodes.

	Only commas at parenthesis depth zero separate neighbours; commas inside
	a nested subtree belong to that subtree's string.

	Examples.
	>>> split_neighbours("raccoon:19.19959,bear:6.80041")
	['raccoon:19.19959', 'bear:6.80041']
	>>> split_neighbours("B,(A,D),C")
	['B', '(A,D)', 'C']
	>>> split_neighbours("")
	['']

	Returns
	-------
	neighbours : list of str
		Always at least one element; an empty input yields [''].
	"""

	neighbours = []

	# Characters of the neighbour currently being read
	current = []
	# Parenthesis nesting depth; non-zero means we are still inside a subtree
	depth = 0

	for char in neighbour_str:
		if char == "," and not depth:
			# Completed reading the previous neighbour
			neighbours.append("".join(current))
			current = []
			continue
		if char == "(":
			depth += 1
		elif char == ")":
			depth -= 1
		current.append(char)

	# The last neighbour is not followed by a comma, so add it here
	neighbours.append("".join(current))

	return neighbours



def encode(filename, to_stdout=True):
	"""Encodes a Newick-format tree as a set of ASP rules.

	Note that the encoded trees are not in the canonical format like
	the output trees in the paper.

	With T the identifier derived from the filename, the output contains
	node(1..n,T), leaf(1..l,T), an innernode/2 rule, one edge(P,C1;C2;...,T)
	fact per node P that has children, and one name(I, label, T) fact per
	leaf I with its label lowercased. Rules are emitted in a deterministic
	order.

	Parameters
	----------
	filename : Name of Newick-format file containing exactly one tree.
		The filename is used as an identifier, so for all files
		input together, the output of os.path.basename(filename)
		should be distinct -- otherwise the rules will overlap and
		cause errors later.
	to_stdout : boolean
		Decides whether the output should be printed out
		immediately. If no, the output is just returned as a
		string.

	Returns
	-------
	rules_str : str
		The rules, one per line.

	Raises
	------
	ValueError
		If no identifier can be derived from the filename, or if a leaf
		has no name (the name/3 fact needs one).
	"""

	with open(filename) as f:
		# This assumes there is just one tree in the entire file
		tree_str = f.read()

	# Derive tree name from filename
	tree_name = os.path.splitext(os.path.basename(filename))[0].lower()
	tree_name = re.sub(r'-', '_', tree_name)
	tree_name = re.sub(r'\W+', '', tree_name)
	if not tree_name:
		# An empty identifier would produce invalid rules such as "node(1..n,)."
		raise ValueError("Cannot derive a tree identifier from filename %r" % filename)

	# Try to read the tree
	tree = read_tree(tree_str)

	num_nodes = len(tree.nodes)
	num_leaves = len(tree.leaves)

	# This is where all the ASP rules will be saved
	rules_list = []

	# Nodes -- both leaves and inner nodes. read_tree() calls Tree.reindex(),
	# which numbers the leaves 1..num_leaves, so the leaf range below is exact.
	rules_list.append("node(1..%d,%s)." % (num_nodes, tree_name))
	rules_list.append("leaf(1..%d,%s)." % (num_leaves, tree_name))
	rules_list.append("innernode(X, %s) :- node(X, %s), not leaf(X, %s)." % (tree_name, tree_name, tree_name))

	# Edges: the ';'-separated child list is a pool that expands to one
	# edge/3 fact per child.
	for node in tree.nodes:
		neighbours = [str(child.index) for child in node.children]
		if neighbours:
			rules_list.append("edge(%d,%s,%s)." % (node.index, ";".join(neighbours), tree_name))

	# Match nodes with names.
	# Remember that variable names are capital letters and constants are lowercase!
	# Sorting by index keeps the output stable from run to run.
	# NOTE(review): labels containing characters that are not valid in an ASP
	# constant (e.g. '.', '-') are emitted as-is -- confirm inputs avoid them.
	for leaf in sorted(tree.leaves, key=lambda node: node.index):
		if not leaf.name:
			raise ValueError("Leaf %d in %s has no name" % (leaf.index, filename))
		rules_list.append("name(%d, %s, %s)." % (leaf.index, leaf.name.lower(), tree_name))

	rules_str = "\n".join(rules_list)

	# If desired, print here
	if to_stdout:
		print(rules_str)

	return rules_str


if __name__ == "__main__":

	parser = argparse.ArgumentParser(description="Encode Newick trees in ASP rules.")
	# The parser will collect a list of zero or more filenames.
	# (The function 'encode' uses the filename for tagging the rules as originating from a particular tree.)
	parser.add_argument('file', metavar='FILE', type=str, nargs='*')

	args = parser.parse_args()

	# Encode every tree before printing anything, so that an error in one
	# file never leaves a partial rule set on stdout.
	rules_list = []
	filename = None

	try:
		for filename in args.file:
			rules_list.append(encode(filename, to_stdout=False))

	# TreeException is named explicitly in case it derives from BaseException,
	# which "except Exception" alone would not catch.
	except (Exception, TreeException) as e:
		# Report on stderr so the message cannot be mistaken for rules, and
		# exit non-zero so that calling scripts notice the failure.
		print("Error while encoding %s: %s" % (filename, e), file=sys.stderr)
		sys.exit(1)

	else:
		print("\n".join(rules_list))

