import numpy as np
import scipy as sp
import scipy.optimize as opt
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pylab
import os, time
import multiprocessing
import scipy.sparse.linalg as sparselin
import cmath
import matplotlib.tri as tri
import matplotlib.mlab as griddata
import matplotlib.colors as mcolors
import scipy.linalg as linalg
import argparse
import matplotlib.mlab as griddata
import matplotlib.transforms as mtransforms
import matplotlib.gridspec as gridspec
from   matplotlib import rc
from   matplotlib.animation import FFMpegWriter
from   mpl_toolkits.axes_grid1 import make_axes_locatable
from   pylab import *
from   scipy.signal import savgol_filter
from   scipy.interpolate import make_interp_spline
from   matplotlib import image as mpimg
from   mpl_toolkits.axes_grid1 import make_axes_locatable
from   PIL import Image
from   itertools import groupby 
import os.path
#----------------------------- rc Parameters-----------------------------------
# Style knobs kept as module-level globals (fs, w, h, ... are reused below).
w, h = 3.375, 1.35             # per-panel width / height, multiplied by np_h / np_v for figure.figsize
fs, lw, ax_lw = 14, 1.5, 0.5   # base font size, line width, axes-spine width
tick_M, tick_m = 1.5, 0.8      # major / minor tick lengths
tex = 1                        # NOTE: not consulted here -- text.usetex is hard-coded False below
minor_x, minor_y, minor = 0, 0, 0
np_h, np_v = 4, 3              # panel counts used to scale the default figure size
mks, mew = 1, 0.1              # marker size / marker edge width
FontFamily = "Times New Roman"
if minor:
    # A single switch that turns on minor ticks for both axes.
    minor_x = minor_y = 1

# Collect every rc override in one mapping, then apply it in a single update.
_style = {
    # Text / fonts
    'text.usetex': False,
    'font.family': FontFamily,
    # Figure
    'figure.figsize': [w*np_h, h*np_v],
    'figure.dpi': 150,
    'legend.fontsize': fs,
    # Axes
    'axes.labelsize': fs,
    'axes.linewidth': ax_lw,
    # Ticks
    'xtick.labelsize': fs-2,
    'ytick.labelsize': fs-2,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.width': ax_lw,
    'ytick.major.width': ax_lw,
    'xtick.major.pad': 3,
    'ytick.major.pad': 3,
    'xtick.major.size': tick_M,
    'ytick.major.size': tick_M,
    'xtick.top': False,
    'xtick.labeltop': False,
    'xtick.bottom': True,
    'xtick.labelbottom': True,
    # Lines
    'lines.markersize': mks,
    'lines.linewidth': lw,
    'lines.markeredgewidth': mew,
}
if minor_x:
    _style['xtick.minor.visible'] = True
    _style['xtick.minor.size'] = tick_m
    _style['xtick.minor.width'] = ax_lw
if minor_y:
    _style['ytick.minor.visible'] = True
    _style['ytick.minor.size'] = tick_m
    _style['ytick.minor.width'] = ax_lw
plt.rcParams.update(_style)
del _style
#------------------------------------------------------------------------------

#------------------------------- Plot Function---------------------------------
def PlotFun(ax, X, Y, NP, PCol, PLab, PLSty, XLab, YLab, Tit, TitFS, SetZ, SetminY, FSize, Legloc, LegFS):
    """Plot one or several curves on ``ax`` and apply title, labels, limits and legend.

    Parameters
    ----------
    ax : axes object to draw on (a matplotlib Axes in this script).
    X : x values shared by every curve; also sets the x range [min(X), max(X)].
    Y : y values. When ``NP > 1``, curve ``i`` is the column ``Y[:, i]``;
        otherwise ``Y`` is plotted as a single curve.
    NP : number of curves to draw.
    PCol, PLab : colour(s) and label(s). Indexed per curve when ``NP > 1``,
        passed through unchanged otherwise.
    PLSty : line style applied to every curve.
    XLab, YLab, FSize : axis labels and their font size.
    Tit, TitFS : title text and its font size.
    SetZ : if truthy, draw a very thin horizontal line at y = 0.
    SetminY : if truthy, set the y range to [0, 1.1 * max(Y)].
    Legloc, LegFS : legend location and font size. A legend is drawn only when
        ``PLab`` is non-empty and its first element/character is not ''.
    """
    if NP > 1:
        for iY in range(NP):
            ax.plot(X, Y[:, iY], label=PLab[iY], linestyle=PLSty, color=PCol[iY])
    else:
        ax.plot(X, Y, label=PLab, linestyle=PLSty, color=PCol)
    ax.set_title(Tit, fontsize=TitFS)
    ax.set_xlabel(XLab, fontsize=FSize)
    ax.set_ylabel(YLab, fontsize=FSize, labelpad=2)
    ax.set_xlim(np.min(X), np.max(X))
    if SetminY:
        # np.max reduces over all elements. The builtin max() iterated the rows
        # of a 2-D Y (NP > 1) and raised "truth value ... is ambiguous".
        ax.set_ylim(bottom=0.0, top=np.max(Y) * 1.1)
    if len(PLab) != 0 and PLab[0] != '':
        ax.legend(loc=Legloc, framealpha=0.0, fontsize=LegFS)
    if SetZ:
        ax.axhline(0.0, lw=0.01, color='k')
#------------------------------------------------------------------------------

#--------------------------------- Main ---------------------------------------
# Energy window (eV) applied to the band-structure y axis via ylim.
Ymin, Ymax = -3.0, 3.0


# Locate the OpenMX input file (*.in) in the working directory. Sorting makes
# the choice deterministic when several are present (os.listdir order is not).
_in_files = sorted(F for F in os.listdir() if F.endswith(".in"))
if not _in_files:
    raise FileNotFoundError("No input file (*.in) found in " + os.getcwd())
if len(_in_files) > 1:
    print("Warning: several *.in files found, using " + _in_files[-1])
FN = _in_files[-1]

# Read the keywords this script needs:
#   System.Name                  -> SysNam (prefix of the .BANDDAT files)
#   scf.SpinPolarization         -> BN (2 when the value is "on", else 1)
#   <Band.kpath ... Band.kpath>  -> per-segment k-point counts and end labels
SysNam = None
BN = 1       # used when scf.SpinPolarization is absent (OpenMX default is off -- confirm)
HSP  = []    # two labels per path segment: start label, end label
HSPN = [0]   # running total of k-points at each segment boundary
fname = FN
with open(fname, 'r') as f:
    for line in f:
        tokens = line.split()
        if not tokens or tokens[0].startswith('#'):
            continue  # blank or commented-out line
        key = tokens[0]
        if key == "System.Name":
            SysNam = str(tokens[1])
        elif key == "scf.SpinPolarization":
            BN = 2 if tokens[1].lower() == 'on' else 1
        elif key.startswith("<Band.kpath"):
            # Consume the block from the same file iterator until its closing tag.
            for Nline in f:
                if "Band.kpath>" in Nline:
                    break
                cols = Nline.split()
                if not cols or cols[0].startswith('#'):
                    continue
                # cols[0] = k-point count of the segment; cols[7]/cols[8] =
                # start/end labels (cols[1:7] are not read -- presumably k coordinates).
                HSPN.append(int(cols[0]) + HSPN[-1])
                HSP.append(str(cols[7]))
                HSP.append(str(cols[8]))
            else:
                raise ValueError("Unterminated <Band.kpath block in " + fname)

if SysNam is None:
    raise ValueError("System.Name not found in " + fname)
if not HSP:
    raise ValueError("No <Band.kpath ... Band.kpath> block found in " + fname)

# Tick labels at segment boundaries: the first start label, then each segment's
# end label (written "end,start" when the next segment starts elsewhere), then
# the final end label.
HSPL = [HSP[0]]
for seg_end, next_start in zip(HSP[1:-1:2], HSP[2:-1:2]):
    HSPL.append(seg_end if seg_end == next_start else seg_end + ',' + next_start)
HSPL.append(HSP[-1])

HSP  = np.array(HSP)
HSPL = np.array(HSPL)
# Running totals -> 0-based k-point indices of the boundaries (first stays 0).
HSPN = np.array(HSPN); HSPN[1:] = HSPN[1:] - 1

print(FN)
print(HSPN)
print(HSPL)
print(BN)
print(SysNam)
# Total number of k-points along the path (HSPN[-1] is the 0-based index of the last one).
NB = int(HSPN[-1] + 1)

# Count the header lines of the first BANDDAT file. Data is taken to start at
# the first line whose first field is the literal string "0.000000".
NL = None
with open(str(SysNam) + ".BANDDAT1", 'r') as f:
    for lineno, line in enumerate(f):
        fields = line.split()
        if fields and fields[0] == "0.000000":
            NL = lineno
            break
if NL is None:
    # The original readline() loop spun forever at EOF in this case.
    raise ValueError("No data line starting with 0.000000 found in "
                     + str(SysNam) + ".BANDDAT1")
print(NL)


def _band_gap(*bands):
    """Return (lowest energy > 0) - (highest energy < 0), pooled over all arrays.

    0 eV is treated as the reference level (presumably the Fermi level of the
    BANDDAT output -- confirm). Pooling means a spin-polarised gap is measured
    across both spin channels. Returns NaN when no energy lies strictly above
    or strictly below 0, instead of crashing on an empty min/max.
    """
    energies = np.concatenate([np.ravel(b) for b in bands])
    above = energies[energies > 0]
    below = energies[energies < 0]
    if above.size == 0 or below.size == 0:
        return float('nan')
    return above.min() - below.max()


# Load the bands. Column 0 is the path coordinate and column 1 the energy;
# each band is reshaped to one row of NB k-points.
if BN == 2:
    # Spin-polarised run: one BANDDAT file per spin channel.
    Data  = np.loadtxt(str(SysNam)+".BANDDAT1", skiprows=NL)
    BandU = np.reshape(Data[:,1], (-1,NB))
    Data  = np.loadtxt(str(SysNam)+".BANDDAT2", skiprows=NL)
    BandD = np.reshape(Data[:,1], (-1,NB))
    # Gap over both spins (previously only BandU was considered).
    BG    = _band_gap(BandU, BandD)
    _series = [(BandU, 'b', 'Spin-Up'), (BandD, 'r', 'Spin-Down')]
else:
    Data  = np.loadtxt(str(SysNam)+".BANDDAT1", skiprows=NL)
    Band  = np.reshape(Data[:,1], (-1,NB))
    BG    = _band_gap(Band)
    _series = [(Band, 'b', '')]

Kp   = np.reshape(Data[:,0], (-1,NB))[0,:]
HSPK = Kp[HSPN]   # path coordinates of the high-symmetry points

fig = plt.figure(figsize=(8,6)); ax1 = fig.add_subplot(111)
for jj in range(_series[0][0].shape[0]):
    for _bands, _col, _lab in _series:
        # Label only the first band of each channel so the legend has one entry per spin.
        PlotFun(ax=ax1, X=Kp, Y=_bands[jj,:], NP=1, PCol=_col, PLab=_lab if jj == 0 else '', PLSty='-', XLab='', YLab='E (eV)', Tit='', TitFS=fs, SetZ=0, SetminY=0, FSize=fs, Legloc='best', LegFS=fs)
ax1.axhline(0, lw=1, color='k')
for k in HSPK:
    ax1.axvline(k, lw=0.5, color='k')
ax1.set_xticks(HSPK)
ax1.set_xticklabels(HSPL)
# plt.rc(...) after the axes exist does not reliably restyle their ticks and
# leaks into later figures; tick_params applies the sizes to this axes only.
ax1.tick_params(axis='x', labelsize=10)
ax1.tick_params(axis='y', labelsize=20)
ax1.set_ylabel('E (eV)')
ax1.set_ylim(Ymin, Ymax)
ax1.set_xlim(0, Kp.max())
ax1.set_title('Band Structure,  BG = {:3.2f} eV'.format(BG), fontsize=14)
fig.tight_layout()
fig.savefig('Band.png', bbox_inches='tight', dpi=200)