import os
import math
import json
import types
import wt_util
import calc_sample_atk_dur
import numpy as np
import matplotlib.pyplot as plt
import multiproc as mp

from scipy.interpolate import CubicSpline


def upsample( aV, N, interp_degree ):
    """Upsample a 1-D signal by an integer factor.

    The output has len(aV)*N samples.  Output sample k*N equals aV[k]; the
    N-1 samples between two consecutive input samples are interpolated.

    Args:
        aV:            signal vector (at least two samples).
        N:             upsample factor; converted with int() and must be >= 2.
        interp_degree: "linear" (numpy.interp) or "cubic" (scipy CubicSpline).

    Returns:
        numpy float array of length len(aV)*N.

    Raises:
        ValueError: if N < 2, aV has fewer than two samples, or interp_degree
            is not "linear"/"cubic".
    """
    N = int(N)
    if N < 2:
        raise ValueError(f"upsample factor must be >= 2, got {N}")

    aV = np.asarray(aV)
    aN = len(aV)
    if aN < 2:
        raise ValueError("upsample requires at least two input samples")

    # z is a copy of aV with N-1 zeros after every sample; the zeros are the
    # positions to be interpolated.
    z       = np.zeros((aN, N))
    z[:, 0] = aV
    z       = z.reshape(aN * N)

    # x: positions in z holding the original samples
    x = np.arange(aN) * N

    # xi: positions strictly between two original samples (no extrapolation)
    span = np.arange(x[-1])
    xi   = span[span % N != 0]

    if interp_degree == "linear":
        z[xi] = np.interp(xi, x, aV)
    elif interp_degree == "cubic":
        z[xi] = CubicSpline(x, aV)(xi)
    else:
        raise ValueError(f"unknown interp_degree: {interp_degree!r}")

    # The last N-1 positions lie to the right of the final input sample and
    # would require extrapolation.  Instead each one is set to the mean of the
    # N values preceding it (so earlier fills feed into later ones).
    k = len(z) - N + 1
    for i in range(N - 1):
        z[k + i] = np.mean(z[k + i - N:k + i])

    return z

    

def estimate_pitch_ac( aV, si, hzL, srate, argsD ):
    """Pick the candidate pitch whose period best repeats around aV[si].

    For each candidate frequency hz in hzL the (possibly fractional) period
    P = srate/hz samples is tested.  The reference cycle aV[si:si+floor(P)] is
    compared with the cycles starting k*P samples before and after si, for
    k = 1..cycle_cnt.  Off-grid cycles are read with the `interp_degree`
    interpolator.  Each comparison is a sum of squared differences, and the
    candidate with the lowest mean score wins.  Despite the "ac" naming this
    is a squared-difference match where lower is better.

    Args:
        aV:    audio vector (numpy array); the reference cycle starts at aV[si].
        si:    sample index of the reference cycle.
        hzL:   list of candidate pitches in Hz.
        srate: sample rate of aV.
        argsD: dict with keys
            cycle_cnt        - cycles compared on each side of si (>= 1).
            interp_degree    - "linear" or "cubic": interpolator for the
                               offset cycles.
            up_fact          - integer; if > 1 the region around si is first
                               upsampled by this factor.
            up_interp_degree - "linear" or "cubic": interpolator used for
                               upsampling.

    Returns:
        The element of hzL with the lowest score.

    Raises:
        ValueError: if cycle_cnt < 1, an interpolator name in use is unknown,
            or (when upsampling) the required region extends past either end
            of aV.
    """

    def _auto_corr( aV, si, fsmp_per_cyc, cycle_offset_idx, interp_degree ):
        # Sum of squared differences between the reference cycle at aV[si:]
        # and the cycle starting cycle_offset_idx periods away.
        smp_per_cyc = int(math.floor(fsmp_per_cyc))

        # (fractional) sample positions of the offset cycle
        xi = si + cycle_offset_idx * fsmp_per_cyc + np.arange(smp_per_cyc)

        # Integer positions bracketing xi.  The upper end is inclusive so the
        # last point of xi is interpolated rather than extrapolated.
        x = np.arange(int(math.floor(xi[0])), int(math.ceil(xi[-1])) + 1)
        y = aV[x]

        if interp_degree == "cubic":
            yi = CubicSpline(x, y)(xi)
        elif interp_degree == "linear":
            yi = np.interp(xi, x, y)
        else:
            raise ValueError(f"unknown interp_degree: {interp_degree!r}")

        return np.sum((yi - aV[si:si+smp_per_cyc]) ** 2)

    def auto_corr( aV, si, fsmp_per_cyc, cycle_cnt, interp_degree ):
        # Mean per-cycle score over cycle_cnt cycles on each side of si.
        ac = 0.0
        for k in range(1, cycle_cnt+1):
            ac += _auto_corr(aV, si, fsmp_per_cyc,  k, interp_degree)
            ac += _auto_corr(aV, si, fsmp_per_cyc, -k, interp_degree)
        return ac / (cycle_cnt*2)

    def ac_upsample( aV, si, fsmp_per_cyc, cycle_cnt, up_fact, up_interp_degree ):
        # Upsample only the part of aV needed to test a period of fsmp_per_cyc
        # samples.  Returns the upsampled segment and the new index of aV[si].

        # extra samples on each side so the interpolators have neighbours
        if up_interp_degree == "cubic":
            pad = 2
        elif up_interp_degree == "linear":
            pad = 1
        else:
            raise ValueError(f"unknown up_interp_degree: {up_interp_degree!r}")

        # beg/end of the segment to upsample
        bi = si - math.ceil(fsmp_per_cyc * cycle_cnt) - pad
        ei = si + math.ceil(fsmp_per_cyc * (cycle_cnt + 1)) + pad
        if bi < 0 or ei > len(aV):
            raise ValueError(f"not enough signal around index {si} to upsample {cycle_cnt} cycles on each side")

        up_aV = upsample(aV[bi:ei], up_fact, up_interp_degree)
        u_si  = (si - bi) * up_fact

        # upsampling keeps the original samples at multiples of up_fact
        assert aV[si] == up_aV[u_si]

        return up_aV, u_si

    args = types.SimpleNamespace(**argsD)

    if args.cycle_cnt < 1:
        raise ValueError(f"cycle_cnt must be >= 1, got {args.cycle_cnt}")

    up_fact = int(args.up_fact)
    if up_fact > 1:
        # Size the upsampled region for the longest candidate period, so that
        # every candidate can be tested against the same upsampled segment.
        max_fsmp_per_cyc = srate / min(hzL)
        aV, si = ac_upsample(aV, si, max_fsmp_per_cyc, args.cycle_cnt, up_fact, args.up_interp_degree)
        srate  = srate * up_fact

    # score every candidate frequency
    acL = [ auto_corr(aV, si, srate / hz, args.cycle_cnt, args.interp_degree) for hz in hzL ]

    # winning candidate is the one with the lowest score
    return hzL[int(np.argmin(acL))]
    
# Pitch tracking runs at a finer time resolution than wavetable generation:
# candidate wavetables are visited every pitch_track_interval_secs, but only
# every wt_interval_down_sample_fact-th one is considered for storage, so the
# wavetable period is pitch_track_interval_secs * wt_interval_down_sample_fact.
def gen_wave_table_list( audio_fname,
                         mark_tsv_fname, gateL,
                         midi_pitch,
                         pitch_track_interval_secs,
                         wt_interval_down_sample_fact,
                         min_wt_db,
                         dom_ch_idx,
                         est_hz_argD,
                         ac_argD ):
    """Locate the sustain wavetables of every marked note in one audio file
    and estimate the file's actual pitch.

    For every marker (note) and channel, a step is taken every
    pitch_track_interval_secs, starting at the note's end-of-attack sample
    (gateL[i][1]).  At each step a wavetable begins at the first zero crossing
    after the step and spans two cycles of the nominal pitch.  Steps stop when
    no zero crossing is found or the wavetable would pass the note's end
    marker.

    - Every wt_interval_down_sample_fact-th step, the wavetable location and
      RMS are stored if its level is above min_wt_db.  The first two
      wavetables of a channel are always stored.
    - On the dominant channel, steps min_wt_idx..max_wt_idx of notes with
      vel >= min_vel also get a pitch estimate from estimate_pitch_ac.  The
      candidates are +/- 1 semitone around the nominal pitch in 1-cent steps.

    Args:
        audio_fname:    audio file read with wt_util.parse_audio_file.
        mark_tsv_fname: marker file read with wt_util.parse_marker_file; each
                        marker is (beg_sec, end_sec, vel_label).
        gateL:          one entry per marker; gateL[i][1] is used as the
                        end-of-attack sample index.
        midi_pitch:     nominal MIDI pitch of the file.
        pitch_track_interval_secs:    step between candidate wavetables.
        wt_interval_down_sample_fact: keep every N-th step as a wavetable.
        min_wt_db:      level threshold for storing a wavetable.
        dom_ch_idx:     channel used for pitch estimation.
        est_hz_argD:    dict with keys min_vel, min_wt_idx, max_wt_idx.
        ac_argD:        argsD passed through to estimate_pitch_ac.

    Returns:
        dict with keys midi_pitch, srate, est_hz_mean, est_hz_err_cents,
        est_hz_std_cents, wt_interval_secs, dominant_ch_idx, audio_fname,
        mark_tsv_fname and velL.  velL has one {"vel","bsi","chL"} entry per
        marker, and chL[ch] lists {"wtbi","wtei","rms","est_hz"} records.
        The est_hz_* fields stay None when no wavetable qualified for pitch
        estimation.

    Raises:
        ValueError: on marker/gate count mismatch, or when the pitch-track
            interval rounds to < 1 sample or the down-sample factor is < 1.
    """
    est_hz_args = types.SimpleNamespace(**est_hz_argD)

    aM,srate        = wt_util.parse_audio_file(audio_fname)
    markL           = wt_util.parse_marker_file(mark_tsv_fname)
    ch_cnt          = aM.shape[1]
    pt_interval_smp = int(round(pitch_track_interval_secs*srate))
    wt_interval_fact= int(wt_interval_down_sample_fact)
    cents_per_semi  = 100

    if len(markL) != len(gateL):
        raise ValueError(f"marker count ({len(markL)}) != gate count ({len(gateL)})")
    if pt_interval_smp < 1:
        # a zero step would revisit the same wavetable forever
        raise ValueError("pitch_track_interval_secs is shorter than one sample")
    if wt_interval_fact < 1:
        raise ValueError("wt_interval_down_sample_fact must be >= 1")

    # nominal pitch and the pitches one semitone below / above it
    hz_min = wt_util.midi_pitch_to_hz(midi_pitch-1)
    hz_ctr = wt_util.midi_pitch_to_hz(midi_pitch)
    hz_max = wt_util.midi_pitch_to_hz(midi_pitch+1)

    # each wavetable spans two cycles of the nominal pitch
    smp_per_wt = int(math.floor((srate/hz_ctr) * 2))

    # Candidate pitches: +/- 1 semitone in 1-cent steps.  Within each semitone
    # a cent is approximated as an equal Hz increment.
    lo_cent_hz = (hz_ctr - hz_min) / cents_per_semi
    hi_cent_hz = (hz_max - hz_ctr) / cents_per_semi
    hzCandL = [ hz_min + k*(hz_ctr-hz_min)/cents_per_semi for k in range(cents_per_semi) ] + \
              [ hz_ctr + k*(hz_max-hz_ctr)/cents_per_semi for k in range(cents_per_semi) ]

    # setup the return data structure
    pitchD = { "midi_pitch":midi_pitch,
               "srate":srate,
               "est_hz_mean":None,
               "est_hz_err_cents":None,
               "est_hz_std_cents":None,
               "wt_interval_secs":pitch_track_interval_secs * wt_interval_fact,
               "dominant_ch_idx":int(dom_ch_idx),
               "audio_fname":audio_fname,
               "mark_tsv_fname":mark_tsv_fname,
               "velL":[]
              }

    hzL = []
    for mark_idx,(beg_sec,end_sec,vel_label) in enumerate(markL):
        bsi = int(round(beg_sec*srate))
        esi = int(round(end_sec*srate))
        vel = int(vel_label)
        eai = gateL[mark_idx][1]  # end of attack

        velD = { "vel":vel, "bsi":bsi, "chL":[ [] for _ in range(ch_cnt)] }

        for ch_idx in range(ch_cnt):
            wtL    = velD['chL'][ch_idx]
            wt_idx = 0
            while True:
                wt_smp_idx = eai + wt_idx*pt_interval_smp

                # the first zero crossing after this step starts the wavetable
                wtbi = wt_util.find_zero_crossing(aM[:,ch_idx],wt_smp_idx,1)
                if wtbi is None:
                    break

                wtei = wtbi + smp_per_wt
                if wtei > esi:
                    break

                # Estimate the pitch near wavetables that are on the dominant
                # channel, above a certain velocity and not too far into the decay.
                if (ch_idx == dom_ch_idx
                        and est_hz_args.min_wt_idx <= wt_idx <= est_hz_args.max_wt_idx
                        and vel >= est_hz_args.min_vel):
                    hzL.append( estimate_pitch_ac( aM[:,dom_ch_idx],wtbi,hzCandL,srate,ac_argD) )

                if wt_idx % wt_interval_fact == 0:
                    wt_rms = float(np.sqrt(np.mean(aM[wtbi:wtei,ch_idx] ** 2)))

                    # log10(0) is undefined: treat digital silence as -inf dB
                    wt_db = 20*math.log10(wt_rms) if wt_rms > 0 else -math.inf

                    # filter out quiet wavetables but guarantee at least two per channel
                    if wt_db > min_wt_db or len(wtL) < 2:
                        wtL.append({"wtbi":int(wtbi),"wtei":int(wtei),"rms":wt_rms, "est_hz":0})

                wt_idx += 1

        pitchD['velL'].append(velD)

    if not hzL:
        # no wavetable qualified for pitch estimation: leave the est_hz_* fields as None
        print(f"{midi_pitch} est pitch: no estimates (check est_hz min_vel/min_wt_idx/max_wt_idx)")
        return pitchD

    # summarize the per-wavetable pitch estimates (stored in pitchD only)
    est_hz     = float(np.mean(hzL))
    est_hz_std = float(np.std(hzL))

    # Nearest 1-cent candidate.  hzCandL[cents_per_semi] is hz_ctr, so the
    # index offset from it is the tuning error in cents.
    est_hz_idx       = int(np.argmin(np.abs(np.array(hzCandL) - est_hz)))
    est_hz_err_cents = est_hz_idx - cents_per_semi

    # convert the Hz std dev to cents using the cent width of the semitone containing the estimate
    cent_hz          = lo_cent_hz if est_hz < hz_ctr else hi_cent_hz
    est_hz_std_cents = est_hz_std / cent_hz

    print(f"{midi_pitch} est pitch:{est_hz}(hz) err:{est_hz_err_cents}(cents)" )

    pitchD["est_hz_mean"]      = est_hz
    pitchD["est_hz_err_cents"] = float(est_hz_err_cents)
    pitchD["est_hz_std_cents"] = float(est_hz_std_cents)

    return pitchD

def _gen_wave_table_bank( src_dir, midi_pitch, argD ):
    """Build the wavetable/pitch record for a single MIDI pitch.

    Reads the audio file and the marker file for the pitch from src_dir.
    Attack gates come from calc_sample_atk_dur.generate_gate_db.  The channel
    with the largest ch_avgRmsL value is chosen as the dominant channel.  The
    remaining work is delegated to gen_wave_table_list.

    Args:
        src_dir:    directory holding the `wav/` sample files and marker files.
        midi_pitch: MIDI pitch, used zero-padded to three digits in file names.
        argD:       analysis parameters; must provide rms_wnd_ms, rms_hop_ms,
                    atk_min_dur_ms, atk_end_thresh_db,
                    pitch_track_interval_secs, wt_interval_down_sample_fact,
                    min_wt_db, est_hz and ac.

    Returns:
        The dict produced by gen_wave_table_list.
    """
    cfg = types.SimpleNamespace(**argD)

    audio_fname    = os.path.join(src_dir,f"wav/{midi_pitch:03}_samples.wav")
    mark_tsv_fname = os.path.join(src_dir,f"{midi_pitch:03}_marker.txt")

    # NOTE: calc_sample_atk_dur.generate_gate_pct is an alternative gate
    # generator that can be substituted here.
    gateL, ch_avgRmsL = calc_sample_atk_dur.generate_gate_db(
        audio_fname,
        mark_tsv_fname,
        cfg.rms_wnd_ms,
        cfg.rms_hop_ms,
        cfg.atk_min_dur_ms,
        cfg.atk_end_thresh_db,
    )

    return gen_wave_table_list(
        audio_fname,
        mark_tsv_fname,
        gateL,
        midi_pitch,
        cfg.pitch_track_interval_secs,
        cfg.wt_interval_down_sample_fact,
        cfg.min_wt_db,
        np.argmax(ch_avgRmsL),      # dominant channel
        cfg.est_hz,
        cfg.ac,
    )



        
        
def gen_wave_table_bank_mp( processN, src_dir, midi_pitchL, out_fname, argD ):
    """Analyze every MIDI pitch in midi_pitchL and write the results as JSON.

    Args:
        processN:    worker process count, capped at len(midi_pitchL).  When the
                     capped value is <= 0 the pitches are processed serially in
                     this process.
        src_dir:     directory holding the per-pitch audio and marker files.
        midi_pitchL: MIDI pitches to analyze.
        out_fname:   output JSON path.
        argD:        analysis parameters passed to _gen_wave_table_bank; also
                     stored in the output.

    Writes {"pitchL": [...sorted by midi_pitch], "instr": "piano", "argD": argD}
    to out_fname.
    """

    # NOTE(review): nested function; if multiproc pickles the callable (e.g. a
    # "spawn" start method) this may fail -- confirm against multiproc.
    def _multi_proc_func( procId, procArgsD, taskArgsD ):

        return _gen_wave_table_bank( procArgsD['src_dir'],
                                     taskArgsD['midi_pitch'],
                                     procArgsD['argD'] )

    procArgsD = {
        "src_dir":src_dir,
        "argD": argD
    }

    taskArgsL = [ { 'midi_pitch':midi_pitch } for midi_pitch in midi_pitchL ]

    processN = min(processN,len(taskArgsL))

    if processN > 0:
        pitchL = mp.local_distribute_main( processN,_multi_proc_func,procArgsD,taskArgsL )
    else:
        # serial fallback: iterate the task records themselves
        pitchL = [ _gen_wave_table_bank( src_dir, t['midi_pitch'], argD ) for t in taskArgsL ]

    pitchL = sorted(pitchL,key=lambda x:x['midi_pitch'])

    with open(out_fname,"w") as f:
        json.dump({"pitchL":pitchL, "instr":"piano", "argD":argD},f)


        
def plot_rms( wtb_json_fname ):
    """Plot the level (dB) of every stored wavetable in a wavetable-bank JSON file.

    One figure is drawn per pitch, in ascending MIDI order.  Each figure has
    one line per (velocity, channel) pair, read from pitchD['velL'][*]['chL'][*].
    """
    with open(wtb_json_fname) as f:
        pitchL = json.load(f)['pitchL']

    for pitchD in sorted(pitchL,key=lambda x:x['midi_pitch']):
        _,ax = plt.subplots(1,1)
        for velD in pitchD['velL']:
            for chL in velD['chL']:
                # log10(0) is undefined: plot digitally silent wavetables as -inf dB
                dbL = [ 20*math.log10(wt['rms']) if wt['rms'] > 0 else -math.inf for wt in chL ]
                ax.plot(dbL)

        ax.set_title(f"{pitchD['midi_pitch']}")
        plt.show()

def plot_atk_dur( wtb_json_fname ):
    """For each pitch in a wavetable-bank JSON file, plot velocity against the
    time (seconds) from the note start (bsi) to the first channel-0 wavetable.

    Notes whose channel 0 has no stored wavetable have no such time and are
    skipped.
    """
    with open(wtb_json_fname) as f:
        pitchL = json.load(f)['pitchL']

    for pitchD in sorted(pitchL,key=lambda x:x['midi_pitch']):
        srate = pitchD['srate']

        velL = []
        secL = []
        for velD in pitchD['velL']:
            ch0L = velD['chL'][0]
            if not ch0L:
                continue
            velL.append(velD['vel'])
            secL.append((ch0L[0]['wtbi'] - velD['bsi']) / srate)

        _,ax = plt.subplots(1,1)
        ax.plot(velL,secL,marker=".")

        ax.set_title(f"{pitchD['midi_pitch']}")
        plt.show()


def plot_hz( wtb_json_fname ):
    """Plot the per-pitch pitch-estimate summary from a wavetable-bank JSON file.

    The figure has three stacked axes:
      0: est_hz_mean vs. MIDI pitch
      1: est_hz_std_cents vs. est_hz_mean
      2: est_hz_err_cents vs. MIDI pitch, with red reference lines at 0, 10
         and 20 across the MIDI range
    """
    with open(wtb_json_fname) as f:
        records = sorted(json.load(f)['pitchL'], key=lambda rec: rec['midi_pitch'])

    _, axes = plt.subplots(3,1)

    def column( key ):
        return [ rec[key] for rec in records ]

    midiL  = column('midi_pitch')
    hzL    = column("est_hz_mean")
    hzStdL = column("est_hz_std_cents")
    hzErrL = column("est_hz_err_cents")

    hz_ax, std_ax, err_ax = axes

    hz_ax.plot(midiL, hzL)
    std_ax.plot(hzL, hzStdL)
    err_ax.hlines([0,10,20], midiL[0], midiL[-1], color="red")
    err_ax.plot(midiL, hzErrL)

    plt.show()
    


if __name__ == "__main__":

    # MIDI 21 (A0) .. 108 (C8)
    midi_pitchL = list(range(21,109))
    out_fname   = "/home/kevin/temp/temp_5.json"
    src_dir     = "/home/kevin/temp/wt6"

    argD = dict(
        rms_wnd_ms                   = 50,
        rms_hop_ms                   = 10,
        atk_min_dur_ms               = 1000,
        atk_end_thresh_db            = -43.0,
        min_wt_db                    = -80.0,
        pitch_track_interval_secs    = 0.25,
        # wt_interval_secs = pitch_track_interval_secs * wt_interval_down_sample_fact
        wt_interval_down_sample_fact = 8.0,
        est_hz = dict(
            min_vel    = 50,
            min_wt_idx = 2,
            max_wt_idx = 4,
        ),
        ac = dict(
            cycle_cnt        = 8,        # cycles used on each side for pitch matching
            interp_degree    = "cubic",
            up_fact          = 2,
            up_interp_degree = "cubic",
        ),
    )

    worker_cnt = 20
    gen_wave_table_bank_mp(worker_cnt, src_dir, midi_pitchL, out_fname, argD )

    # other views of the same output: plot_rms(out_fname), plot_hz(out_fname)
    plot_atk_dur(out_fname)