#!/usr/bin/env python

"""
A helper script for many matgendb functions.
"""

__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "1.2"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyue@mit.edu"
__date__ = "Dec 1, 2012"

import datetime
import logging
import multiprocessing
import json
import sys
import argparse
import six

from pymongo import MongoClient, ASCENDING

from pymatgen.apps.borg.queen import BorgQueen

from matgendb import SETTINGS
from matgendb.query_engine import QueryEngine
from matgendb.creator import VaspToDbTaskDrone
from matgendb.dbconfig import DBConfig
from matgendb.util import get_settings, DEFAULT_SETTINGS, MongoJSONEncoder

_log = logging.getLogger("mg")  # parent


def init_db(args):
    """Interactively build a database configuration file.

    Prompts on standard input for every name in ``DBConfig.ALL_SETTINGS``
    and then for the admin / read-only credentials. An empty answer selects
    the value from ``DEFAULT_SETTINGS`` (or an empty string when there is
    none). The result is written as sorted, indented JSON.

    :param args: Parsed command-line arguments; only ``args.config_file``
                 (path of the file to write) is used.
    """
    setting_names = DBConfig.ALL_SETTINGS
    fallback = dict(DEFAULT_SETTINGS)
    credential_names = ("admin_user", "admin_password", "readonly_user",
                        "readonly_password")

    def prompt(name):
        # Ask for one value; an empty reply falls back to the default.
        default = fallback.get(name, '')
        answer = six.moves.input(
            "Enter {} (default: {}) : ".format(name, default))
        return answer or default

    print("Please supply the following configuration values")
    print("(press Enter if you want to accept the defaults)\n")

    config = {}
    for name in setting_names:
        config[name] = prompt(name)
    # The port is typed in as text but must be stored as an integer.
    # NOTE(review): assumes "port" is one of DBConfig.ALL_SETTINGS - confirm.
    config["port"] = int(config["port"])
    for name in credential_names:
        config[name] = prompt(name)

    with open(args.config_file, "w") as out:
        json.dump(config, out, indent=4, sort_keys=True)
    print("\nConfiguration written to {}!".format(args.config_file))


def update_db(args):
    """Assimilate the runs found under ``args.directory`` into the database.

    Connection settings are read with ``get_settings(args.config_file)``; a
    ``VaspToDbTaskDrone`` is driven in parallel by a ``BorgQueen``.

    :param args: Parsed command-line arguments. Uses ``logfile``,
                 ``config_file``, ``author``, ``tag``, ``parse_dos``,
                 ``force_update_dupes``, ``ncpus`` and ``directory``.
    """
    log_format = "%(relativeCreated)d msecs : %(message)s"

    # ``--logfile`` yields a plain string, so indexing it with [0] would log
    # to a file named after the first character of the path. A one-element
    # list is still accepted in case a caller passes one.
    logfile = args.logfile
    if logfile and not isinstance(logfile, six.string_types):
        logfile = logfile[0]

    if logfile:
        logging.basicConfig(level=logging.INFO, format=log_format,
                            filename=logfile)
    else:
        logging.basicConfig(level=logging.INFO, format=log_format)

    d = get_settings(args.config_file)

    _log.info("Db insertion started at {}.".format(datetime.datetime.now()))
    # NOTE(review): with the current ``--author`` option (nargs=1) the author
    # is stored as a one-element list - confirm that this is intended.
    additional_fields = {"author": args.author, "tags": args.tag}
    drone = VaspToDbTaskDrone(
        host=d["host"], port=d["port"], database=d["database"],
        user=d["admin_user"], password=d["admin_password"],
        parse_dos=args.parse_dos,
        collection=d["collection"], update_duplicates=args.force_update_dupes,
        additional_fields=additional_fields, mapi_key=d.get("mapi_key", None))

    # Fall back to every detected CPU when no explicit count was requested.
    ncpus = args.ncpus or multiprocessing.cpu_count()
    _log.info("Using {} cpus...".format(ncpus))
    queen = BorgQueen(drone, number_of_drones=ncpus)
    queen.parallel_assimilate(args.directory)

    # Falsy entries are dropped before counting the reported task ids.
    tids = [int(tid) for tid in queen.get_data() if tid]
    _log.info("Db upate completed at {}.".format(datetime.datetime.now()))
    _log.info("{} new task ids inserted.".format(len(tids)))


def optimize_indexes(args):
    """Drop and rebuild the indexes of the configured collection.

    Connects with the admin credentials from ``get_settings``, drops all
    existing indexes, then creates a unique index on ``task_id``, one
    single-key index per commonly queried field, and a compound index on
    ``(nelements, elements)``.

    :param args: Parsed command-line arguments; only ``args.config_file``
                 is used.
    """
    d = get_settings(args.config_file)
    c = MongoClient(d["host"], d["port"])
    try:
        db = c[d["database"]]
        db.authenticate(d["admin_user"], d["admin_password"])
        coll = db[d["collection"]]
        coll.drop_indexes()
        coll.ensure_index('task_id', unique=True)
        for key in ('unit_cell_formula', 'reduced_cell_formula', 'chemsys',
                    'nsites', 'pretty_formula', 'analysis.e_above_hull',
                    "icsd_ids"):
            print("Building {} index".format(key))
            coll.ensure_index(key)
        print("Building nelements and elements compound index")
        # Built once only; the original requested this same index twice.
        coll.ensure_index([('nelements', ASCENDING), ('elements', ASCENDING)])
    finally:
        # Release the connection even when an index build fails.
        c.close()


def query_db(args):
    """Run a query and print the results as a table or as JSON lines.

    ``args.properties`` is either a list of property names or a single
    item of the form ``:filename``. In the latter case the property names
    are read from that file, one per line, skipping blank lines.

    :param args: Parsed command-line arguments. Uses ``config_file``,
                 ``criteria`` (JSON string or None), ``properties`` and
                 ``dump_json``.
    :raises SystemExit: With status -1 when ``args.criteria`` is not valid
                        JSON.
    """
    d = get_settings(args.config_file)
    qe = QueryEngine(host=d["host"], port=d["port"], database=d["database"],
                     user=d["readonly_user"], password=d["readonly_password"],
                     collection=d["collection"],
                     aliases_config=d.get("aliases_config", None))
    criteria = None
    if args.criteria:
        try:
            criteria = json.loads(args.criteria)
        except ValueError:
            print("Criteria {} is not a valid JSON string!".format(
                args.criteria))
            sys.exit(-1)

    props = args.properties
    if len(props) == 1 and props[0].startswith(':'):
        # Text mode, so the names are str rather than bytes on Python 3.
        with open(props[0][1:], 'r') as f:
            props = [line.strip() for line in f if line.strip()]

    if args.dump_json:
        for r in qe.query(properties=props, criteria=criteria):
            print(json.dumps(r, cls=MongoJSONEncoder))
    else:
        # Imported here so that --dump works without tabulate installed.
        from tabulate import tabulate
        rows = [[r[p] for p in props]
                for r in qe.query(properties=props, criteria=criteria)]
        print(tabulate(rows, headers=props))


class StatsCommands(object):
    """Encapsulate statistics commands.

    Each statistic is a method named ``stat_<name>`` that returns formatted
    text; :meth:`stat` dispatches to it by name.
    """
    #: Constants for names of format options
    FORMAT_YAML, FORMAT_CLEAN = "yaml", "simple"

    def __init__(self, query_engine, output_format=None, latest_prop=None, **extra_kw):
        """Constructor.

        :param query_engine: Connection to Mongo
        :type query_engine: matgendb.query_engine.QueryEngine
        :param output_format: One of FORMAT_YAML or FORMAT_CLEAN; None
                              selects FORMAT_YAML.
        :param latest_prop: Field used by :meth:`stat_latest`.
        :param extra_kw: Ignored; lets callers pass a whole option dict.
        :raises ValueError: If `output_format` is not a known format.
        """
        self._qe = query_engine
        self._db, self._coll = query_engine.db, query_engine.collection
        self.indent = ' ' * 4
        # Keywords
        self._latest = latest_prop
        # Without a fallback, omitting the format left _format and
        # _show_header unset and every later call raised AttributeError.
        if output_format is None:
            output_format = self.FORMAT_YAML
        if output_format == self.FORMAT_YAML:
            self._format = self.yaml_format
            self._show_header = True
        elif output_format == self.FORMAT_CLEAN:
            self._format = self.clean_format
            self._show_header = False
        else:
            raise ValueError("Unknown output format: {!r}".format(output_format))

    def header(self):
        """Return the database/collection header, or "" when headers are off."""
        if not self._show_header:
            return ""
        s = self._format((("Database", self._db.name),
                          ("Collection", self._coll.name),
                          ("Values", "")))
        return s

    def stat(self, name):
        """Run the statistic called `name` via its ``stat_<name>`` method.

        :raises AttributeError: If no such method exists.
        """
        func = 'stat_{}'.format(name)
        return getattr(self, func)()

    def stat_latest(self):
        """Return the largest value of the 'latest' field, formatted.

        Logs an error and returns "" when no value can be found.
        """
        cur = self._coll.find({}, [self._latest])\
                        .sort(self._latest, -1)\
                        .limit(1)
        try:
            value = "{}".format(cur[0][self._latest])
        except (IndexError, KeyError):
            # IndexError: no document was returned. KeyError: the returned
            # document does not contain the field.
            _log.error("Cannot get latest value: no records with field '{}'"
                       .format(self._latest))
            return ""
        return self._format([("Latest", value)], depth=1)

    def stat_count(self):
        """Return the number of records in the collection, formatted."""
        value = "{:d}".format(self._coll.count())
        return self._format([("Count", value)], depth=1)

    def yaml_format(self, data, depth=0):
        """Format (key, value) pairs as ``key: value`` lines.

        Each line is indented by `depth` levels; the text ends in a newline.
        """
        ind = self.indent * depth
        rows = ["{i}{k}: {v}".format(i=ind, k=k, v=v) for k, v in data]
        return '\n'.join(rows) + "\n"

    def clean_format(self, data, depth=0):
        """Format (key, value) pairs as ``key value`` lines, key lowercased.

        `depth` is accepted for signature compatibility and is ignored.
        """
        rows = ["{} {}".format(d[0].lower(), d[1]) for d in data]
        return '\n'.join(rows) + "\n"


def db_stats(args):
    """Database and/or collection statistics.

    Attributes of `args` whose names start with ``st_`` are on/off switches
    for individual statistics; ``st_showall`` turns all of them on. Every
    other attribute is forwarded to ``StatsCommands`` as a keyword.

    :param args: Parsed command-line arguments.
    """
    # Init db.
    cfg = get_settings(args.config_file)
    normalize_userpass(cfg)
    qe = QueryEngine(**cfg)

    # Copy stat args and keywords.
    # items() rather than iteritems(): the latter exists only on Python 2.
    stats, kw = {}, {}
    for k, v in vars(args).items():
        if k.startswith("st_"):
            stats[k[3:]] = v
        else:
            kw[k] = v

    # Process 'show all'; the switch itself is not a statistic.
    if stats.pop('showall'):
        stats = dict.fromkeys(stats, True)

    # Output destination.
    w = sys.stdout

    # Run commands, in name order so the output is deterministic.
    cmd = StatsCommands(qe, **kw)
    w.write(cmd.header())
    for k in sorted(stats):
        if stats[k]:
            w.write(cmd.stat(k))


def normalize_userpass(cfg):
    """Collapse prefixed credentials in a DB connection config, in place.

    For each of the 'readonly' and 'admin' prefixes for which both the
    ``<prefix>_user`` and ``<prefix>_password`` keys are present, the values
    are copied to ``QueryEngine.USER_KEY`` / ``QueryEngine.PASSWORD_KEY``
    and the prefixed keys are removed. 'admin' is handled last, so it wins
    when both pairs are present.

    :param cfg: Configuration dictionary; modified in place.
    """
    # Lowest priority first: a later prefix overwrites an earlier one.
    for prefix in ('readonly', 'admin'):
        user_field = prefix + '_user'
        password_field = prefix + '_password'
        if user_field not in cfg or password_field not in cfg:
            continue
        cfg[QueryEngine.USER_KEY] = cfg[user_field]
        cfg[QueryEngine.PASSWORD_KEY] = cfg[password_field]
        for stale in (user_field, password_field):
            del cfg[stale]


def init_logging(args):
    """Attach a stream handler to the "mg" logger and set its level.

    The level is CRITICAL when ``args.quiet`` is set. Otherwise it is chosen
    by the ``args.vb`` count: WARN by default, INFO for 1, DEBUG for 2 or
    more. Either attribute may be absent from `args`.

    :param args: Parsed command-line arguments.
    """
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter(
        "[%(levelname)-6s] %(asctime)s %(name)s :: %(message)s"))
    # Records are emitted by this handler only, not by ancestor loggers.
    _log.propagate = False
    _log.addHandler(console)

    if 'quiet' in args and args.quiet:
        _log.setLevel(logging.CRITICAL)
        return

    verbosity = args.vb if 'vb' in args else 0
    # Index:  0 (default)   1 (-v)        2+ (-vv)
    levels = (logging.WARN, logging.INFO, logging.DEBUG)
    _log.setLevel(levels[min(verbosity, 2)])


if __name__ == "__main__":
    # Command-line entry point: build one sub-parser per subcommand, each of
    # which stores its handler in ``func``, then dispatch to that handler.
    db_file = SETTINGS.get("PMGDB_DB_FILE")
    parser = argparse.ArgumentParser(description="""
    mgdb is a complete command line db management script for pymatgen-db. It
    provides the facility to insert vasp runs, perform queries, and run a web
    server for exploring databases that you create. Type mgdb -h to see the
    various options.

    Author: Shyue Ping Ong
    Version: 3.0
    Last updated: Mar 23 2013""")

    # Parents for all subparsers.
    parent_vb = argparse.ArgumentParser(add_help=False)
    parent_vb.add_argument('--quiet', '-q', dest='quiet', action="store_true",
                           default=False,
                           help="Minimal verbosity.")
    parent_vb.add_argument('--verbose', '-v', dest='vb', action="count",
                           default=0,
                           help="Print more verbose messages to standard error. "
                                "Repeatable. (default=WARNING)")
    # This default is taken before the "db.json" fallback below, so it may
    # be None.
    parent_cfg = argparse.ArgumentParser(add_help=False)
    parent_cfg.add_argument("-c", "--config", dest="config_file", type=str,
                            default=db_file,
                            help="Config file to use. Generate one using mgdb "
                            "init --config filename.json if necessary. "
                            "Otherwise, the code searches for a db.json. If "
                            "none is found, an no-authentication "
                            "localhost:27017/vasp database and tasks "
                            "collection is assumed.")

    # change db_file to the default "db.json" if it does not exist
    db_file = db_file or "db.json"

    # Init for all subparsers.
    subparsers = parser.add_subparsers()

    # The 'init' subcommand.
    # const=db_file: a bare "-c" selects the default instead of None.
    pinit = subparsers.add_parser("init", help="Initialization tools.",
                                  parents=[parent_vb])
    pinit.add_argument("-c", "--config", dest="config_file", type=str,
                       nargs='?', default=db_file, const=db_file,
                       help="Creates an db config file for the database. "
                            "Default filename is db.json.")
    pinit.set_defaults(func=init_db)

    # The 'optimize' subcommand.
    popt = subparsers.add_parser("optimize", help="Optimization tools.")

    popt.add_argument("-c", "--config", dest="config_file", type=str,
                      nargs='?', default=db_file, const=db_file,
                      help="Config file of the database whose indexes are "
                           "rebuilt. Default filename is db.json.")
    popt.set_defaults(func=optimize_indexes)

    # The 'insert' subcommand.
    pinsert = subparsers.add_parser("insert", help="Insert vasp runs.",
                                    parents=[parent_vb, parent_cfg])
    # nargs='?' is what makes the "." default apply; without it the
    # positional is mandatory and the default is never used.
    pinsert.add_argument("directory", metavar="directory", type=str,
                         nargs='?', default=".",
                         help="Root directory for runs.")
    pinsert.add_argument("-l", "--logfile", dest="logfile", type=str,
                         help="File to log db insertion. Defaults to stdout.")
    pinsert.add_argument("-t", "--tag", dest="tag", type=str, nargs="+",
                         default=[],
                         help="Tag your runs for easier search."
                              " Accepts multiple space-separated tags. E.g.,"
                              " '--tag metal energy'")
    pinsert.add_argument("-f", "--force", dest="force_update_dupes",
                         action="store_true",
                         help="Force update duplicates. This forces the "
                              "analyzer to reanalyze already inserted data.")
    pinsert.add_argument("-d", "--parse_dos", dest="parse_dos",
                         action="store_true",
                         help="Whether to parse the dos.")
    pinsert.add_argument("-a", "--author", dest="author", type=str, nargs=1,
                         default=None,
                         help="Enter a *unique* author field so that you can "
                              "trace back what you ran.")
    pinsert.add_argument("-n", "--ncpus", dest="ncpus", type=int,
                         default=None,
                         help="Number of CPUs to use in inserting. If "
                              "not specified, multiprocessing will use "
                              "the number of cpus detected.")
    pinsert.set_defaults(func=update_db)

    # The 'query' subcommand.
    pquery = subparsers.add_parser("query",
                                   help="Query tools. Requires the "
                                        "use of tabulate.",
                                   parents=[parent_vb, parent_cfg])
    pquery.add_argument("--crit", dest="criteria", type=str, default=None,
                        help="Query criteria in typical json format. E.g., "
                             "{\"task_id\": 1}.")
    pquery.add_argument("--props", dest="properties", type=str, default=[],
                        nargs='+', required=True,
                        help="Desired properties. Repeatable. E.g., pretty_formula, "
                             "task_id, energy...")
    pquery.add_argument("--dump", dest="dump_json", action='store_true', default=False,
                        help="Simply dump results to JSON instead of a tabular view")
    pquery.set_defaults(func=query_db)

    # The 'stats' subcommand.
    pstats = subparsers.add_parser("stats", help="Database and/or collection statistics.",
                                   parents=[parent_vb, parent_cfg])
    pstats.add_argument("--count", help="Show count of records", dest="st_count", action="store_true")
    pstats.add_argument("--latest", help="Show time of latest record", dest="st_latest", action="store_true")
    #pstats.add_argument("--size", help="Show min/max/avg/total record sizes", dest="", action="store_true")
    pstats.add_argument("--all", help="Show all stats", dest="st_showall", action="store_true")
    pstats.add_argument("--latest-prop", help="Property to use for latest record (default='updated_at')",
                        default="updated_at", dest="latest_prop")
    pstats.add_argument("-f", "--format", help="Output format", choices=[StatsCommands.FORMAT_YAML,
                                                                         StatsCommands.FORMAT_CLEAN],
                        default="yaml", dest="output_format")
    pstats.set_defaults(func=db_stats)

    # Parse args
    args = parser.parse_args()

    # On Python 3 a subcommand is optional, so ``func`` may be missing;
    # report a usage error instead of failing with AttributeError.
    if not hasattr(args, "func"):
        parser.error("a subcommand is required (use -h to list them)")

    init_logging(args)

    # Run appropriate subparser function.
    args.func(args)
