import sys,os,json,types
import numpy as np
import matplotlib.pyplot as plt
import matplotlib._color_data as mcd
from matplotlib.pyplot import figure

from rms_analysis import calibrate_recording_analysis
from rms_analysis import key_info_dictionary

def plot_by_pitch( inDir, keyInfoD, pitch=None ):
    """Plot measured level (dB) versus pulse length for each recorded pitch.

    Reads ``meas.json`` from *inDir* and, for every pitch in its ``measD``
    section, draws three curves on that pitch's subplot:

    * ``post`` -- the ``db`` value produced by ``calibrate_recording_analysis(inDir)``
    * ``harm`` -- the stored ``hm['db']`` measurement
    * ``td``   -- the stored ``td['db']`` measurement

    Each distinct ``targetDb`` found among the measurements is drawn as a
    horizontal line, with additional lines ``cfg.tolDbPct`` percent below
    (light gray) and above (gray) it.  Per-measurement markers are then drawn
    on the ``td`` curve:

    * black 'x' -- ``hm['durMs']`` is below ``cfg.minMeasDurMs``
    * blue  '+' -- ``skipMeasFl`` is set
    * red   '.' -- ``matchFl`` is set

    Args:
        inDir: directory containing ``meas.json``.
        keyInfoD: mapping from integer MIDI pitch to an object with a ``type``
            attribute (shown in each subplot title).
        pitch: if given, plot only this MIDI pitch, on a single subplot.
    """
    anlD   = calibrate_recording_analysis( inDir )
    jsonFn = os.path.join(inDir, "meas.json" )

    with open(jsonFn,"r") as f:
        r = json.load(f)

    measD = r['measD']
    cfg   = types.SimpleNamespace(**r['cfg'])

    axN = len(measD) if pitch is None else 1

    # squeeze=False makes subplots() always return a 2-D array of Axes.  Without
    # it a single subplot comes back as a bare Axes object and indexing fails
    # whenever only one subplot is created.
    fig, axM = plt.subplots(axN, 1, squeeze=False)
    axL = axM[:, 0]
    fig.set_size_inches(18.5, 10.5*axN)

    # for each pitch
    for axi, (midi_pitch, measL) in enumerate(measD.items()):

        midi_pitch = int(midi_pitch)   # JSON object keys are strings

        if pitch is not None and pitch != midi_pitch:
            continue

        # a single requested pitch always goes on the only subplot
        ax = axL[0] if pitch is not None else axL[axi]

        targetDbS = set()

        # One row per measurement:
        # (pulse_us, post_db, hm_db, td_db, matchFl, hm_durMs, skipMeasFl)
        # Keeping every value for a measurement in one tuple guarantees the
        # columns stay aligned after sorting.
        rowL = []

        # for each measurement on this pitch
        for mi, d in enumerate(measL):
            m = types.SimpleNamespace(**d)

            # the re-analysis result for this measurement index
            ar = next(ad for ad in anlD[midi_pitch] if ad['meas_idx'] == mi)

            rowL.append((m.pulse_us, ar['db'], m.hm['db'], m.td['db'],
                         m.matchFl, m.hm['durMs'], m.skipMeasFl))

            # collect the unique set of targets
            targetDbS.add(m.targetDb)

        ax.set_title("pitch:%i %s" % (midi_pitch, keyInfoD[midi_pitch].type))

        # nothing to plot for a pitch without measurements
        if not rowL:
            continue

        # sort measurements on pulse length (stable, so ties keep measurement order)
        rowL.sort(key=lambda row: row[0])
        pulseL, postDbL, hmDbL, tdDbL, matchFlL, durMsL, skipFlL = zip(*rowL)

        # re-analysis, harmonic and time-domain measurements
        ax.plot(pulseL, postDbL, label="post", marker='.')
        ax.plot(pulseL, hmDbL,   label="harm", marker='.')
        ax.plot(pulseL, tdDbL,   label="td",   marker='.')

        # plot target boundaries
        for targetDb in targetDbS:
            lwr = targetDb * ((100.0 - cfg.tolDbPct)/100.0)
            upr = targetDb * ((100.0 + cfg.tolDbPct)/100.0)

            ax.axhline(targetDb)
            ax.axhline(lwr, color='lightgray')
            ax.axhline(upr, color='gray')

        # plot 'too-short', 'skipped' and 'match' markers
        # NOTE(review): markers are placed on the td curve, but the too-short
        # test uses the harmonic durMs; this pairing is kept from the original
        # code -- confirm it is intended.
        for pulse_us, db, matchFl, durMs, skipFl in zip(pulseL, tdDbL, matchFlL, durMsL, skipFlL):

            if durMs < cfg.minMeasDurMs:
                ax.plot(pulse_us, db, marker='x', color='black', linestyle='None')

            if skipFl:
                ax.plot(pulse_us, db, marker='+', color='blue', linestyle='None')

            if matchFl:
                ax.plot(pulse_us, db, marker='.', color='red', linestyle='None')

        # plt.legend() would label only the current (last) axes; label each subplot
        ax.legend()

    plt.show()

def plot_all_notes( inDir ):
    """Plot the stored RMS dB envelopes of every measurement in *inDir*.

    Reads ``meas.json`` from *inDir* and creates one subplot per measurement,
    across all pitches in ``measD`` (pitch order, then measurement order).
    Each subplot shows ``td['rmsDbV']`` and ``hm['rmsDbV']``, with a red
    vertical line at ``td['pk_idx']`` and a green one at ``hm['pk_idx']``.

    Args:
        inDir: directory containing ``meas.json``.
    """
    jsonFn = os.path.join(inDir, "meas.json" )

    with open(jsonFn,"r") as f:
        r = json.load(f)

    # flatten the per-pitch measurement lists into one list
    allMeasL = [ d for measL in r['measD'].values() for d in measL ]

    axN = len(allMeasL)

    # squeeze=False keeps indexing valid when there is exactly one measurement
    fig, axM = plt.subplots(axN, 1, squeeze=False)
    fig.set_size_inches(18.5, 10.5*axN)

    for ax, d in zip(axM[:, 0], allMeasL):
        ax.plot(d['td']['rmsDbV'])
        ax.plot(d['hm']['rmsDbV'])

        ax.axvline(d['td']['pk_idx'], color='red')
        ax.axvline(d['hm']['pk_idx'], color='green')

    plt.show()



if __name__ == "__main__":

    # usage: <script> <inDir> <yamlCfgFn> [midi_pitch]
    if len(sys.argv) < 3:
        sys.exit("usage: %s <inDir> <yamlCfgFn> [midi_pitch]" % os.path.basename(sys.argv[0]))

    inDir  = sys.argv[1]
    yamlFn = sys.argv[2]

    # optional third argument restricts plotting to a single MIDI pitch
    pitch = int(sys.argv[3]) if len(sys.argv) > 3 else None

    keyInfoD = key_info_dictionary( yamlCfgFn=yamlFn)

    # Alternative entry points (disabled):
    #plot_all_notes( inDir )
    #calibrate_recording_analysis( inDir )
    plot_by_pitch(inDir,keyInfoD,pitch)