import os, os.path, string
import matplotlib.pyplot as plt
import matplotlib.dates as dates
import pdb
import collections as cll

import numpy as np
from read_pbpfile import *
from read_statevec import *
from read_profile import *
from read_ak import *
from read_misc import *
import detfile


# Kept at module level so that pickle can find the record type by name.
# One record describes a single retrieval; the field order below is the
# positional order expected by the constructor.
prf = cll.namedtuple(
    'prf',
    'date duration '
    'column column_apriori '
    'vmr vmr_apriori vmr_sigma '
    'Z Zbar '
    'avk '
    'all_gases all_columns all_dcolumns '
    'snr_theo snr iter nvar nfit')

def _epoch_to_datenum(seconds):
    """Convert Unix epoch seconds to a matplotlib date number.

    Stands in for matplotlib.dates.epoch2num, which has been deprecated
    and later removed in newer matplotlib releases.  The result is expressed
    relative to whatever date epoch the installed matplotlib uses, because
    the offset is obtained from date2num itself.
    """
    import datetime
    unix_origin = dates.date2num(datetime.datetime(1970, 1, 1))
    return unix_origin + seconds / 86400.0


def read_sfit4(direc):
    """Collect retrieval results from every sub-directory of `direc`.

    Each entry of `direc` is treated as one retrieval directory.  An entry
    is used only if it contains all of 'misc.out', 'PRFS.out', 'AK.out'
    and 'detail', and if detfile.read_det_sfit4 does not return -1 for the
    detail file; every other entry is skipped silently.

    direc -- path of the directory holding the retrieval directories

    Returns a list of prf records, one per usable retrieval, in the order
    delivered by os.listdir.  The 'date' field is a matplotlib date number
    derived from the first time value returned by read_misc.
    """
    p = []

    for dd in os.listdir(direc):
        # os.path.join replaces string.join, which only exists in Python 2.
        rundir = os.path.join(direc, dd)

        miscfile = os.path.join(rundir, 'misc.out')
        if not os.path.isfile(miscfile):
            continue
        mt, dt = read_misc(miscfile)
        # Only the first time stamp / duration entry is kept.
        date = mt.pop(0)
        dur = dt.pop(0)

        prffile = os.path.join(rundir, 'PRFS.out')
        if not os.path.isfile(prffile):
            continue
        Z, ZBAR, ap_vmr, rt_vmr, sigma_vmr, ap_col, rt_col = read_prf(prffile)

        avkfile = os.path.join(rundir, 'AK.out')
        if not os.path.isfile(avkfile):
            continue
        avk = read_ak(avkfile)

        dfile = os.path.join(rundir, 'detail')
        if not os.path.isfile(dfile):
            continue
        det = detfile.read_det_sfit4(dfile)
        if det == -1:
            # -1 is the failure marker of read_det_sfit4.
            continue

        date = _epoch_to_datenum(float(date))
        # NOTE(review): ZBAR is stored in field 'Z' and Z in field 'Zbar',
        # exactly as in the original positional call.  This looks like a
        # swap, but plot_vmr reads n.Z, so confirm the intent before
        # changing it.
        p.append(prf(date=date, duration=dur,
                     column=rt_col, column_apriori=ap_col,
                     vmr=rt_vmr, vmr_apriori=ap_vmr, vmr_sigma=sigma_vmr,
                     Z=ZBAR, Zbar=Z,
                     avk=avk,
                     all_gases=det.gas, all_columns=det.col,
                     all_dcolumns=det.dcol,
                     snr_theo=det.SNR_Theo, snr=det.SNR, iter=det.Iter,
                     nvar=det.NVAR, nfit=det.NFIT))

    return p

def get_column(p, ts=None):
    """Return total columns and their dates for the records in `p`.

    p  -- sequence of records with a 'column' (iterable of partial
          columns, summed here) and a 'date' attribute
    ts -- None for one value per record, 'daily' for daily means

    Returns (col, dd) as numpy arrays.  With ts == 'daily' the dates are
    rounded to the nearest integer date number, each record contributes
    only to the day it rounds to, and the result is sorted by day.
    (The previous selection abs(day - date) < 1 spanned two days, so a
    record was averaged into neighbouring days as well.)
    """
    # np.array is used explicitly; a bare 'array' is only available if one
    # of the star imports happens to export it.
    col = np.array([sum(n.column) for n in p])
    dd = np.array([n.date for n in p])

    if ts == 'daily':
        day_of = dd.round()
        days = np.unique(day_of)
        col = np.array([np.mean(col[day_of == day]) for day in days])
        dd = days

    return (col, dd)

def get_avk_column(p):
    """Sum each record's averaging kernel along axis 0.

    p -- sequence of records with 'avk' and 'date' attributes

    Returns (avk_col, dd): two plain lists, the first holding
    np.sum(avk, axis=0) for every record, the second the matching dates.
    """
    summed_kernels = [np.sum(entry.avk, axis=0) for entry in p]
    record_dates = [entry.date for entry in p]
    return (summed_kernels, record_dates)

def get_dofs(p):
    """Return the trace of each record's averaging kernel and its date.

    p -- sequence of records with 'avk' (2-D matrix) and 'date' attributes

    Returns (dofs, dd) as numpy arrays.

    The trace is taken over the main diagonal.  The earlier call
    np.trace(avk, 2) passed 2 as the 'offset' argument and therefore
    summed the second super-diagonal instead.
    """
    dofs = np.array([np.trace(n.avk) for n in p])
    dd = np.array([n.date for n in p])
    return (dofs, dd)

def get_snr(p):
    """Return the 'snr' and 'date' attributes of all records in `p`.

    p -- sequence of records with 'snr' and 'date' attributes

    Returns (snr, dd) as numpy arrays in the order of `p`.
    """
    snr_values = np.array([entry.snr for entry in p])
    date_values = np.array([entry.date for entry in p])
    return (snr_values, date_values)


def plot_column(p, ts=None):
    """Plot total columns against date as 'x' markers on figure 1.

    p  -- sequence of retrieval records, handed unchanged to get_column
    ts -- forwarded to get_column (e.g. 'daily')

    Figure 1 is cleared before plotting.  The x axis gets one major tick
    per year, labelled with the four-digit year, and a minor tick in July.
    """
    columns, datenums = get_column(p, ts)

    year_format = dates.DateFormatter('%Y')

    figure = plt.figure(1)
    figure.clf()
    plt.plot_date(datenums, columns, 'x', xdate=True, ydate=False)

    xaxis = figure.gca().xaxis
    xaxis.set_major_locator(dates.YearLocator())
    xaxis.set_minor_locator(dates.MonthLocator(bymonth=7))
    xaxis.set_major_formatter(year_format)

    figure.show()

def plot_vmr(p):
    """Draw a filled contour plot of the VMR profiles over time on figure 1.

    p -- non-empty sequence of records with 'vmr', 'Z' and 'date'
         attributes; every profile must have the length of p[0].vmr

    The profiles are ordered by date before plotting.  The vertical
    coordinate is the 'Z' grid of the second profile after sorting
    (column index 1), so at least two records are required.  Colorbar
    labels are the colorbar boundaries multiplied by 1e6.
    """
    nlevels = len(p[0].vmr)
    nprofiles = len(p)

    # One column per record; every element is written in the loop below.
    vmr = np.empty((nlevels, nprofiles))
    Z = np.empty((nlevels, nprofiles))
    dd = np.empty(nprofiles)
    for k, entry in enumerate(p):
        vmr[:, k] = np.array(entry.vmr)
        Z[:, k] = np.array(entry.Z)
        dd[k] = entry.date

    # Put all three arrays into chronological order.
    order = dd.argsort()
    dd, vmr, Z = dd[order], vmr[:, order], Z[:, order]

    figure = plt.figure(1)
    figure.clf()
    plt.contourf(dd, Z[:, 1], vmr)

    # Date axis: yearly major ticks labelled '%Y', a minor tick each July.
    axes = figure.gca()
    axes.xaxis_date()
    year_format = dates.DateFormatter('%Y')
    axes.xaxis.set_major_locator(dates.YearLocator())
    axes.xaxis.set_minor_locator(dates.MonthLocator(bymonth=7))
    axes.xaxis.set_major_formatter(year_format)

    cbar = plt.colorbar(format='%0.0f')
    cbar.set_ticklabels(cbar.boundaries*1e6)
    cbar.set_label('VMR [ppmv]')
    figure.show()


