#!/usr/bin/env python
# Author: Kevin Gori
# Date: 18 Aug 2014
from __future__ import print_function
import dendropy as dpy
from dendropy.utility.error import DataParseError
import os
import sys


class SupportValueError(Exception):
    """ Raised when an inner node's support value is missing or is not numeric """


def collapse(tree, threshold, keep_lengths=True, support_key=None, length_threshold=0.0):
    """
    Returns a copy of `tree` with poorly supported or very short inner edges collapsed.

    The input tree is left untouched; all edits are applied to a clone.

    tree             - dendropy tree to process
    threshold        - inner nodes with a support value below this are collapsed
    keep_lengths     - if true, the length of each collapsed edge is added to the
                       edges of its child nodes before the edge is collapsed
    support_key      - if given, the support value is read from the node's annotations
                       under this key (and copied to the node label); otherwise the
                       node label is used as the support value
    length_threshold - inner edges shorter than this are collapsed; leaf edges
                       shorter than this are set to zero length

    Edges without a length (None) are never affected by `length_threshold`.

    Raises SupportValueError if a non-root inner node has a missing or
    non-numeric support value.
    """
    try:
        t = dpy.Tree().clone_from(tree)  # dendropy 3 syntax
    except AttributeError:
        t = tree.clone()  # dendropy 4 syntax

    to_collapse = []
    for node in t.postorder_node_iter():
        length = node.edge_length
        # A missing length cannot be compared to a float on Python 3, so it is
        # treated as "not short" rather than raising TypeError.
        too_short = length is not None and length < length_threshold

        if node.is_leaf():
            # Leaves are never collapsed; a short leaf edge is zeroed instead.
            if too_short:
                node.edge_length = 0
            continue
        if node is t.seed_node:
            continue

        raw_support = None
        try:
            if support_key:
                raw_support = node.annotations.get_value(support_key)
                support = float(raw_support)
                node.label = support
            else:
                raw_support = node.label
                support = float(raw_support)
        except TypeError as e:
            raise SupportValueError(
                'Inner node with length {} has no support value'.format(node.edge_length), e)
        except ValueError as e:
            # Both placeholders need a value, otherwise str.format raises IndexError
            # and the intended SupportValueError never reaches the caller.
            raise SupportValueError(
                'Inner node with length {} has a non-numeric support value {}'.format(
                    node.edge_length, raw_support), e)

        if support < threshold or too_short:
            to_collapse.append(node.edge)

    for edge in to_collapse:
        if keep_lengths and edge.length is not None:
            for child in edge.head_node.child_nodes():
                # A child edge without a length is treated as zero length here.
                child.edge.length = (child.edge.length or 0) + edge.length
        edge.collapse()
    return t


def parse_args():
    """ Builds the command line interface and returns the parsed arguments """
    from argparse import ArgumentParser, SUPPRESS

    cli = ArgumentParser()
    cli.add_argument(
        'tree', type=str,
        help='REQUIRED: path to the tree file')
    cli.add_argument(
        '--threshold', type=float, default=0.5,
        help=('OPTIONAL: support value threshold - '
              'nodes with support below this value will be collapsed. DEFAULT=%(default)s'))
    cli.add_argument(
        '--length_threshold', type=float, default=0.0001,
        help=('OPTIONAL: edge length threshold - '
              'edges shorter than this value will be collapsed. DEFAULT=%(default)f'))
    cli.add_argument(
        '--format', type=str, default='newick', choices=['nexus', 'newick'],
        help='OPTIONAL: file format of the input tree - can be newick or nexus. DEFAULT=%(default)s')
    cli.add_argument(
        '--output_format', type=str,
        help=('OPTIONAL: file format of output tree - newick or nexus. DEFAULT: same as input.'
              ' CAVEAT: nexus format may be incompatible with FigTree'))
    cli.add_argument(
        '--keep_lengths', action='store_true',
        help=('OPTIONAL: if set, add the length of the branch being removed to the child branches,'
              ' so that root-to-tip distances are preserved. DEFAULT=%(default)s'))
    cli.add_argument(
        '--outfile', type=str,
        help='OPTIONAL: write the result to disk. DEFAULT=print to terminal')
    # Hidden switch: runs the quick debug test, not relevant for the user.
    cli.add_argument('--test', action='store_true', help=SUPPRESS)
    cli.add_argument(
        '--support_key', type=str,
        help=('OPTIONAL: the key that retrieves the support value, if the support value'
              ' is given in some extended newick format (i.e. as key-value pair inside a comment)'))
    return cli.parse_args()


def read_tree_file(tree_file, tree_format):
    """
    Reads a tree from `tree_file` using `tree_format`.

    If parsing fails with a DataParseError, the other of the two formats
    ('nexus' / 'newick') is tried instead. Raises IOError if the path does not exist.
    """
    if not os.path.exists(tree_file):
        raise IOError('File not found: {}'.format(tree_file))
    try:
        return dpy.Tree.get_from_path(tree_file, tree_format, extract_comment_metadata=True)
    except DataParseError:
        # Retry with whichever of the two formats was not requested.
        (fallback_format,) = {'nexus', 'newick'}.difference({tree_format})
        return dpy.Tree.get_from_path(tree_file, fallback_format, extract_comment_metadata=True)


def test():
    """ Quick debug run: collapses a small example tree with several settings and prints each result """
    newick = ('((A:1,(B:0.5,C:0.5)80:0.1)30:1,'
              '(D:1.2,(E:0.7,F:0.7)20:0.5)60:0.8)100:0;')
    tree = dpy.Tree.get_from_string(newick, 'newick')
    tree.print_plot(plot_metric='length')
    out = sys.stdout

    # (positional args after the tree, keyword args) for each collapse call, in order.
    scenarios = (
        ((25,), {}),
        ((50,), {}),
        ((65,), {}),
        ((65, False), {}),
        ((25, True), {'length_threshold': 0.11}),
    )
    for positional, keywords in scenarios:
        result = collapse(tree, *positional, **keywords)
        result.print_plot(plot_metric='length')
        result.write(file=out, schema='newick', suppress_rooting=True)


def get_file_handle(outfile):
    """ Returns sys.stdout when `outfile` is None or 'stdout', otherwise `outfile` opened for writing """
    if outfile != 'stdout' and outfile is not None:
        return open(outfile, 'w')
    return sys.stdout


def main():
    """
    Command line entry point: parses the arguments, reads the tree, collapses it
    and writes the result to the requested output (terminal by default).

    With the hidden --test switch the built-in debug test is run instead.
    """
    args = parse_args()
    if args.output_format is None:
        args.output_format = args.format
    if args.test:
        return test()
    tree = read_tree_file(args.tree, args.format)
    collapsed_tree = collapse(tree, args.threshold, args.keep_lengths, args.support_key, args.length_threshold)
    handle = get_file_handle(args.outfile)
    try:
        collapsed_tree.write(file=handle, schema=args.output_format, suppress_rooting=True)
    finally:
        # Using the handle as a context manager would also close sys.stdout,
        # breaking any later output from the same process; only close real files.
        if handle is sys.stdout:
            handle.flush()
        else:
            handle.close()

# Run as a script: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)
