#!/usr/bin/env  python
# encoding: utf-8
import numpy as np
import warnings
from pprint import pprint
from optparse import OptionParser
import pandas as pd
from tabulate import tabulate
from vasprun.IR import IR
from vasprun import vasprun

def parse_float_list(text):
    """Convert a comma-separated string into a list of floats.

    Used for the ``--lim`` and ``--max`` options, e.g. ``"-3,3"`` becomes
    ``[-3.0, 3.0]``.

    Raises:
        ValueError: if any comma-separated field is not a valid float.
    """
    return [float(field) for field in text.split(',')]


def mode_overlap(eigenvectors, first, second):
    """Return the summed element-wise product of two eigenvectors.

    Args:
        eigenvectors: iterable of array-like vectors.
        first: non-negative index of the first vector.
        second: non-negative index of the second vector.

    Returns:
        ``np.sum(vectors[first] * vectors[second])``, or ``None`` when either
        index lies beyond the available vectors (instead of raising
        ``IndexError``).
    """
    modes = list(eigenvectors)
    if max(first, second) >= len(modes):
        return None
    return np.sum(np.array(modes[first]) * np.array(modes[second]))


def print_table(columns):
    """Print a mapping of column name -> values as a psql-formatted table."""
    print(tabulate(pd.DataFrame(columns), headers='keys', tablefmt='psql'))


def build_parser():
    """Create the command-line option parser for this script.

    ``metavar`` is only given for options that take a value; optparse does
    not display it for ``store_true`` flags.
    """
    parser = OptionParser()
    parser.add_option("-i", "--incar", dest="incar", action='store_true',
                      help="export incar file")
    parser.add_option("-p", "--poscar", dest="poscar",
                      help="export poscar file", metavar="poscar file")
    parser.add_option("-k", "--kpoints", dest="kpoints", action='store_true',
                      help="kpoints file")
    parser.add_option("-d", "--dosplot", dest="dosplot", metavar="dos_plot", type=str,
                      help="export dos plot, options: t, spd, a, a-Si, a-1")
    parser.add_option("-b", "--bandplot", dest="bandplot", metavar="band_plot", type=str,
                      help="export band plot, options: normal or projected")
    parser.add_option("-v", "--vasprun", dest="vasprun", default='vasprun.xml',
                      help="path of vasprun.xml file, default: vasprun.xml", metavar="vasprun")
    parser.add_option("-f", "--showforce", dest="force", action='store_true',
                      help="show forces, default: no")
    parser.add_option("-a", "--allparameters", dest="parameters", action='store_true',
                      help="show all parameters")
    parser.add_option("-e", "--eigenvalues", dest="band", action='store_true',
                      help="show eigenvalues in valence/conduction band")
    parser.add_option("-s", "--smear", dest="smear", type='float',
                      help="smearing parameter for dos plot, e.g., 0.1 A", metavar="smearing")
    parser.add_option("-n", "--figname", dest="figname", type=str, default='fig.png',
                      help="dos/band figure name, default: fig.png", metavar="figname")
    parser.add_option("-l", "--lim", dest="lim", default='-3,3',
                      help="dos/band plot lim, default: -3,3", metavar="lim")
    parser.add_option("-m", "--max", dest="max", default='0.0,0.5',
                      help="band plot colorbar, default: 0.0,0.5", metavar="max")
    parser.add_option("--dyn", dest="dyn", action='store_true',
                      help="dynamic matrix analysis, default: false")
    parser.add_option("--dpi", dest="dpi", default=300, type=int,
                      help="figure' dpi, default: 300", metavar="dpi")
    parser.add_option("-S", "--save-bands", dest="saveBands", action='store_true',
                      help="saving of bands contributing to the plot, default: false")
    return parser


def main():
    """Parse a vasprun.xml file, print a summary and run one optional action.

    The summary (formula, energies, gap, band edges for non-metals and the
    pseudopotential table) is always printed; ``--showforce`` adds the
    lattice/stress and position/force tables. After that at most one action
    is executed, chosen in this priority order: incar, kpoints, poscar,
    allparameters, dosplot, bandplot, eigenvalues, dyn.
    """
    options, _ = build_parser().parse_args()

    # --vasprun has a default value, so options.vasprun is never None here.
    calc = vasprun(options.vasprun)

    # standard output
    if calc.values['parameters']['ionic']['NSW'] <= 1:
        print('This is a single point calculation')

    # Keys mapped to None are printed directly; keys mapped to a list have
    # only the listed sub-entries printed.
    summary = {'formula': None,
               'calculation': ['efermi', 'energy', 'energy_per_atom'],
               'metal': None,
               'gap': None,
               "bands": None}
    for tag, subtags in summary.items():
        if subtags is None:
            print(tag, ':  ', calc.values[tag])
        else:
            for subtag in subtags:
                print(subtag, ':  ', calc.values[tag][subtag])

    # show VBM and CBM when it is nonmetal
    if calc.values['metal'] is False:
        cbm = calc.values['cbm']
        vbm = calc.values['vbm']
        print_table({'label': ['CBM', 'VBM'],
                     'kpoint': [cbm['kpoint'], vbm['kpoint']],
                     'values': [cbm['value'], vbm['value']]})

    pseudo = calc.values['pseudo_potential']
    print_table({'valence': calc.values['valence'],
                 'labels': pseudo['labels'],
                 'functional': pseudo['functional']})

    if options.force:
        print_table({'lattice': calc.values['finalpos']['basis'],
                     'stress (kbar)': calc.values['calculation']['stress']})
        print_table({'atom': calc.values['finalpos']['positions'],
                     'force (eV/A)': calc.values['calculation']['force']})

    if options.incar:
        # NOTE(review): -i is a store_true flag, so the value passed as
        # `filename` is the boolean True, not a path. Confirm how
        # export_incar treats a non-string filename; behaviour left as-is.
        calc.export_incar(filename=options.incar)
    elif options.kpoints:
        # NOTE(review): same as above, -k is store_true so filename is True.
        calc.export_kpoints(filename=options.kpoints)
    elif options.poscar:
        calc.export_poscar(filename=options.poscar)
    elif options.parameters:
        pprint(calc.values['parameters'])
    elif options.dosplot:
        calc.plot_dos(styles=options.dosplot, filename=options.figname,
                      xlim=parse_float_list(options.lim), smear=options.smear)
    elif options.bandplot:
        lim = parse_float_list(options.lim)
        pmaxmin = parse_float_list(options.max)
        calc.parse_bandpath()
        calc.plot_band(styles=options.bandplot, filename=options.figname,
                       ylim=lim, plim=pmaxmin, saveBands=options.saveBands,
                       dpi=options.dpi)
    elif options.band:
        vb = calc.values['bands'] - 1
        cb = vb + 1
        calc.show_eigenvalues_by_band([vb, cb])
        cbs = calc.eigenvalues_by_band(cb)
        vbs = calc.eigenvalues_by_band(vb)
        kpoint_list = calc.values['kpoints']['list']
        if len(cbs) == len(kpoint_list):
            # Only compute the gap location once the per-kpoint layout is
            # confirmed; the other case is reported as a spin calculation.
            gap_index = np.argmin(cbs - vbs)
            print("Eigenvalue at CBM: {:8.4f}".format(min(cbs)))
            print("Eigenvalue at VBM: {:8.4f}".format(max(vbs)))
            print("minimum gap at : ", kpoint_list[gap_index])
            print("CB:                {:8.4f}".format(cbs[gap_index]))
            print("VB:                {:8.4f}".format(vbs[gap_index]))
            print("diff:              {:8.4f}".format(cbs[gap_index] - vbs[gap_index]))
        else:
            print("This is spin calculation")
    elif options.dyn:
        results = calc.values['calculation']
        chg = results['born_charges']
        eig = results['normal_modes_eigenvalues']
        eigv = results['normal_modes_eigenvectors']
        vol = np.linalg.det(np.array(calc.values['finalpos']['basis']))

        # One mass entry per atom: the i-th entry of values['mass'] is
        # repeated by the count stored for the i-th composition key.
        composition = calc.values["composition"]
        mass = [calc.values['mass'][i]
                for i, ele in enumerate(composition.keys())
                for _ in range(composition[ele])]
        IR(chg, eig, eigv, mass, vol).show()

        # Overlap of modes 14 and 10; skipped when fewer than 15 modes
        # exist, where the hard-coded indices used to raise IndexError.
        overlap = mode_overlap(eigv, 14, 10)
        if overlap is not None:
            print(overlap)

        eps = np.array(results['epsilon_ion'])
        row_format = "{:25s} {:12.3f} {:12.3f} {:12.3f}"
        for row in range(3):
            print(row_format.format('DFPT', eps[row, 0], eps[row, 1], eps[row, 2]))


if __name__ == "__main__":
    main()
