# Python source, 448 lines, 15 KiB (header left over from the file viewer)
import os
|
|
import math
|
|
import json
|
|
import types
|
|
import wt_util
|
|
import calc_sample_atk_dur
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import multiproc as mp
|
|
|
|
from scipy.interpolate import CubicSpline
|
|
|
|
|
|
def upsample( aV, N, interp_degree ):
    """Upsample a signal by an integer factor using interpolation.

    Args:
        aV: 1-D signal vector.
        N: Upsample factor.  Converted with int() and must be >= 2.
        interp_degree: "linear" (np.interp) or "cubic" (scipy CubicSpline).

    Returns:
        A 1-D float array of length len(aV)*N.  Element i*N holds aV[i] and
        the positions between two input samples hold interpolated values.
        The last N-1 positions have no input sample to their right; each is
        set to the mean of the (up to) N values that precede it.

    Raises:
        AssertionError: if N < 2 or interp_degree is not recognized.
        IndexError: if aV is empty.
    """
    N = int(N)
    assert N >= 2, "upsample factor must be an integer >= 2"
    assert interp_degree in ("linear", "cubic"), f"unknown interp_degree: {interp_degree!r}"

    aN = len(aV)

    # z is a copy of aV with zeros in the positions to be interpolated
    z = np.zeros(aN * N)
    z[::N] = aV

    # x contains the indexes into z which hold values from aV
    x = np.arange(aN) * N

    # xi contains the indexes into z, left of the last input sample, which hold zeros
    xi = np.arange(x[-1])
    xi = xi[xi % N != 0]

    # calc values for the zeros in z (nothing to do for a single input sample)
    if xi.size > 0:
        if interp_degree == "linear":
            z[xi] = np.interp(xi, x, aV)
        else:
            z[xi] = CubicSpline(x, aV)(xi)

    # The last N-1 values are not set because they would require extrapolation
    # (they have no value to their right). Instead we set these values
    # as the mean of the preceding N values.  The window start is clamped to 0
    # so that a single-sample input does not average an empty slice.
    k = (len(z) - N) + 1
    for i in range(N - 1):
        z[k + i] = np.mean(z[max(0, k + i - N):k + i])

    return z
|
|
|
|
|
|
|
|
def estimate_pitch_ac( aV, si, hzL, srate, argsD ):
    """Pick, from a list of candidate pitches, the one whose period best matches aV around si.

    For each candidate the cycle starting at aV[si] (the reference cycle) is
    compared, by sum of squared differences, with the cycles located 1..cycle_cnt
    candidate periods before and after it.  The candidate with the lowest
    average score is returned.

    Args:
        aV: audio vector containing a wavetable that starts at aV[si].  It is
            indexed with an index array, so it must be a numpy array.
        si: index into aV of the first sample of the reference cycle.
        hzL: list of candidate pitches in Hz.
        srate: sample rate of aV.
        argsD: dict with the following keys:
            cycle_cnt - count of cycles to compare on either side of the
                reference cycle (1 = one cycle before and one after,
                2 = two before and two after, ...).
            interp_degree - interpolator used to read the offset cycles at
                fractional sample positions: "linear" or "cubic".
            up_fact - set to an integer greater than 1 to upsample the signal
                prior to estimating the pitch.
            up_interp_degree - upsampling interpolator: "linear" or "cubic".

    Returns:
        The element of hzL with the lowest score.
    """

    def _auto_corr( aV, si, fsmp_per_cyc, cycle_offset_idx, interp_degree ):
        # Sum of squared differences between the reference cycle at aV[si:]
        # and the cycle 'cycle_offset_idx' (possibly fractional) periods away.

        smp_per_cyc = int(math.floor(fsmp_per_cyc))

        # fractional sample positions of the offset cycle
        xi = si + (cycle_offset_idx * fsmp_per_cyc) + np.arange(smp_per_cyc)

        # integer sample positions which bracket xi[] on both sides.
        # x_max itself must be included, otherwise the last element of xi[]
        # lies beyond the last knot and is extrapolated instead of interpolated.
        x_min = int(math.floor(xi[0]))
        x_max = int(math.ceil(xi[-1]))
        x = np.arange(x_min, x_max + 1)
        y = aV[x]

        if interp_degree == "cubic":
            yi = CubicSpline(x, y)(xi)
        elif interp_degree == "linear":
            yi = np.interp(xi, x, y)
        else:
            assert False, f"unknown interp_degree: {interp_degree!r}"

        # calc the sum of squared differences between the reference cycle and the 'offset' cycle
        # (np.power is used because np.pow only exists in NumPy >= 2.0)
        return np.sum(np.power(yi - aV[si:si + smp_per_cyc], 2.0))

    def auto_corr( aV, si, fsmp_per_cyc, cycle_cnt, interp_degree ):
        # Average score over the cycle_cnt cycles before and after the reference cycle.

        ac = 0
        for i in range(1, cycle_cnt + 1):
            # accumulate (not overwrite) so that every cycle pair contributes
            ac += _auto_corr(aV, si, fsmp_per_cyc, i, interp_degree)
            ac += _auto_corr(aV, si, fsmp_per_cyc, -i, interp_degree)

        # return the average sum of squared diff's per cycle
        return ac / (cycle_cnt * 2)

    def ac_upsample( aV, si, fsmp_per_cyc, cycle_cnt, up_fact, up_interp_degree ):
        # Upsample only the segment of aV needed to score a period of
        # fsmp_per_cyc samples and return it with the re-based index of aV[si].

        # count of leading/trailing pad positions to allow for interpolation
        if up_interp_degree == "cubic":
            pad = 2
        elif up_interp_degree == "linear":
            pad = 1
        else:
            assert False, f"unknown up_interp_degree: {up_interp_degree!r}"

        # calc the beg/end of the signal segment to upsample
        # NOTE(review): bi is not checked against 0, nor ei against len(aV) -
        # presumably callers keep si far enough from both ends; confirm.
        bi = si - math.ceil(fsmp_per_cyc * cycle_cnt) - pad
        ei = si + math.ceil(fsmp_per_cyc * (cycle_cnt + 1)) + pad

        up_aV = upsample(aV[bi:ei], up_fact, up_interp_degree)

        # calc. index of the center signal value
        u_si = (si - bi) * up_fact

        # the center value should not change after upsampling
        assert aV[si] == up_aV[u_si]

        return up_aV, u_si

    args = types.SimpleNamespace(**argsD)

    # if upsampling was requested
    if args.up_fact > 1:
        # Select the freq candidate with the longest period, because we want to
        # upsample just enough of the signal to test for all possible candidates.
        hz_min = min(hzL)
        max_fsmp_per_cyc = srate / hz_min
        aV, si = ac_upsample( aV, si, max_fsmp_per_cyc, args.cycle_cnt, args.up_fact, args.up_interp_degree )
        srate = srate * args.up_fact

    # calc. the auto-correlation score for every possible candidate frequency
    acL = [ auto_corr(aV, si, srate / hz, args.cycle_cnt, args.interp_degree) for hz in hzL ]

    # debugging aid: flip to True to plot the score of every candidate
    if False:
        _, ax = plt.subplots(1, 1)
        ax.plot(hzL, acL)
        plt.show()

    # winning candidate is the one with the lowest AC score
    cand_hz_idx = int(np.argmin(acL))

    return hzL[cand_hz_idx]
|
|
|
|
# Note that we want a higher rate of pitch tracking than wave table generation - thus
# we downsample the pitch tracking interval by some integer factor to arrive at
# the wave table generation period.
def gen_wave_table_list( audio_fname,
                         mark_tsv_fname, gateL,
                         midi_pitch,
                         pitch_track_interval_secs,
                         wt_interval_down_sample_fact,
                         min_wt_db,
                         dom_ch_idx,
                         est_hz_argD,
                         ac_argD ):
    """Locate the sustain wavetables of every marked note and estimate the recorded pitch.

    Args:
        audio_fname: audio file, read with wt_util.parse_audio_file().  The
            returned sample matrix is indexed [frame, channel].
        mark_tsv_fname: marker file, read with wt_util.parse_marker_file().
            Each marker unpacks to (beg_sec, end_sec, vel_label).
        gateL: one record per marker; element [1] of each record is used as the
            sample index of the end of the attack.
        midi_pitch: nominal MIDI pitch of the recorded notes.
        pitch_track_interval_secs: spacing of the candidate wavetable locations.
        wt_interval_down_sample_fact: every N'th candidate location (N is this
            value converted with int()) is considered for storage.
        min_wt_db: wavetables whose RMS is not above this level (dB) are
            dropped, except that the first two per channel are always kept.
        dom_ch_idx: index of the 'dominant' channel, the only one used for
            pitch estimation.
        est_hz_argD: dict with 'min_wt_idx', 'max_wt_idx' and 'min_vel'; these
            select the locations at which the pitch is estimated.
        ac_argD: argument dict passed through to estimate_pitch_ac().

    Returns:
        A dict describing the pitch: file names, sample rate, the pitch
        estimate statistics ('est_hz_mean', 'est_hz_err_cents',
        'est_hz_std_cents'; all NaN when no estimate could be made) and 'velL',
        a list with one record per marker:
        {"vel", "bsi", "chL":[ per-channel list of {"wtbi","wtei","rms","est_hz"} ]}.

    Raises:
        ValueError: if pitch_track_interval_secs rounds to less than one sample.
        AssertionError: if the marker and gate lists differ in length.
    """

    est_hz_args = types.SimpleNamespace(**est_hz_argD)

    aM,srate = wt_util.parse_audio_file(audio_fname)
    markL = wt_util.parse_marker_file(mark_tsv_fname)
    ch_cnt = aM.shape[1]
    pt_interval_smp = int(round(pitch_track_interval_secs*srate))
    wt_interval_fact = int(wt_interval_down_sample_fact)

    # a zero step would never advance the search below and loop forever
    if pt_interval_smp < 1:
        raise ValueError("pitch_track_interval_secs is shorter than one sample period")

    # calc. the range of possible pitch estimates
    hz_min = wt_util.midi_pitch_to_hz(midi_pitch-1)
    hz_ctr = wt_util.midi_pitch_to_hz(midi_pitch)
    hz_max = wt_util.midi_pitch_to_hz(midi_pitch+1)
    cents_per_semi = 100

    # each wavetable spans two cycles of the nominal pitch
    smp_per_wt = int(math.floor((srate/hz_ctr) * 2))

    # Width in Hz of one candidate step below/above the nominal pitch.
    # (Steps are linear in Hz, so they only approximate true cents.)
    lo_hz_per_cent = (hz_ctr-hz_min)/float(cents_per_semi)
    hi_hz_per_cent = (hz_max-hz_ctr)/float(cents_per_semi)

    # hzCandL is a list of candidate pitches with a range of +/- 1 semitone and a resolution of 1 cent.
    # Index 'cents_per_semi' is the nominal pitch.
    hzCandL = [ hz_min + i*(hz_ctr-hz_min)/100.0 for i in range(cents_per_semi) ] + [ hz_ctr + i*(hz_max-hz_ctr)/100.0 for i in range(cents_per_semi) ]

    assert len(markL) == len(gateL)

    # setup the return data structure
    pitchD = { "midi_pitch":midi_pitch,
               "srate":srate,
               "est_hz_mean":None,
               "est_hz_err_cents":None,
               "est_hz_std_cents":None,
               "wt_interval_secs":pitch_track_interval_secs * wt_interval_fact,
               "dominant_ch_idx":int(dom_ch_idx),
               "audio_fname":audio_fname,
               "mark_tsv_fname":mark_tsv_fname,
               "velL":[]
              }

    hzL = []
    for mark_idx,(beg_sec,end_sec,vel_label) in enumerate(markL):
        bsi = int(round(beg_sec*srate))
        esi = int(round(end_sec*srate))
        vel = int(vel_label)
        eai = gateL[mark_idx][1]  # end of attack

        velD = { "vel":vel, "bsi":bsi, "chL":[ [] for _ in range(ch_cnt)] }

        for ch_idx in range(ch_cnt):

            wt_idx = 0
            while True:

                wt_smp_idx = eai + wt_idx*pt_interval_smp

                # select the first zero crossing after the end of the attack
                # as the start of the first sustain wavetable
                wtbi = wt_util.find_zero_crossing(aM[:,ch_idx],wt_smp_idx,1)

                # no further zero crossing was found
                if wtbi is None:
                    break

                wtei = wtbi + smp_per_wt

                # the wavetable would run past the end of the note
                if wtei > esi:
                    break

                # estimate the pitch near wave tables which are: on the 'dominant' channel,
                # above a certain velocity and not too far into the decay
                if ch_idx==dom_ch_idx and est_hz_args.min_wt_idx <= wt_idx <= est_hz_args.max_wt_idx and vel >= est_hz_args.min_vel:
                    est_hz = estimate_pitch_ac( aM[:,dom_ch_idx],wtbi,hzCandL,srate,ac_argD)
                    hzL.append( est_hz )

                if wt_idx % wt_interval_fact == 0:
                    # measure the RMS of the wavetable
                    # (np.power is used because np.pow only exists in NumPy >= 2.0)
                    wt_rms = float(np.power(np.mean(np.power(aM[wtbi:wtei,ch_idx],2.0)),0.5))

                    # a digitally silent wavetable has no finite dB level
                    wt_db = 20*math.log10(wt_rms) if wt_rms > 0 else -math.inf

                    # filter out quiet wavetable but guarantee that there are always at least two wt's.
                    if wt_db > min_wt_db or len(velD['chL'][ch_idx]) < 2:

                        # store the location and RMS of the wavetable
                        # NOTE: the per-record 'est_hz' is a placeholder which is never updated
                        velD['chL'][ch_idx].append({"wtbi":int(wtbi),"wtei":int(wtei),"rms":float(wt_rms), "est_hz":0})

                wt_idx += 1

        pitchD['velL'].append(velD)

    # summarize the pitch estimates
    if hzL:
        est_hz = np.mean(hzL)
        est_hz_idx = int(np.argmin(np.abs(np.array(hzCandL) - est_hz)))
        est_hz_std = np.std(hzL)

        # Express the std. dev. in cents using the step width on the side of the
        # nominal pitch where the estimate lies (candidates below index
        # cents_per_semi are flat of the nominal pitch).
        if est_hz_idx < cents_per_semi:
            est_hz_std_cents = est_hz_std / lo_hz_per_cent
        else:
            est_hz_std_cents = est_hz_std / hi_hz_per_cent

        est_hz_err_cents = est_hz_idx - cents_per_semi
    else:
        # no location qualified for pitch estimation: report NaN rather than
        # a misleading error of -100 cents
        est_hz = est_hz_std_cents = est_hz_err_cents = float("nan")

    print(f"{midi_pitch} est pitch:{est_hz}(hz) err:{est_hz_err_cents}(cents)" )

    pitchD["est_hz_mean"] = float(est_hz)
    pitchD["est_hz_err_cents"] = float(est_hz_err_cents)
    pitchD["est_hz_std_cents"] = float(est_hz_std_cents)

    return pitchD
|
|
|
|
def _gen_wave_table_bank( src_dir, midi_pitch, argD ):
    """Generate the wavetable description of a single MIDI pitch.

    The audio file is expected at <src_dir>/wav/NNN_samples.wav and the marker
    file at <src_dir>/NNN_marker.txt, where NNN is the zero padded MIDI pitch.

    Args:
        src_dir: directory holding the audio and marker files.
        midi_pitch: MIDI pitch to process.
        argD: configuration dict (gate detection parameters, wavetable
            parameters and the nested 'est_hz' and 'ac' dicts).

    Returns:
        The dict returned by gen_wave_table_list().
    """
    cfg = types.SimpleNamespace(**argD)

    wav_path = os.path.join(src_dir, f"wav/{midi_pitch:03}_samples.wav")
    marker_path = os.path.join(src_dir, f"{midi_pitch:03}_marker.txt")

    # The end of the attack is located with the dB threshold detector.
    # The percentage based detector is kept as a disabled alternative.
    use_db_gate = True

    if use_db_gate:
        gates, ch_avg_rms = calc_sample_atk_dur.generate_gate_db( wav_path,
                                                                  marker_path,
                                                                  cfg.rms_wnd_ms,
                                                                  cfg.rms_hop_ms,
                                                                  cfg.atk_min_dur_ms,
                                                                  cfg.atk_end_thresh_db )
    else:
        gates, ch_avg_rms = calc_sample_atk_dur.generate_gate_pct( wav_path,
                                                                   marker_path,
                                                                   cfg.rms_wnd_ms,
                                                                   cfg.rms_hop_ms,
                                                                   cfg.atk_min_dur_ms,
                                                                   0.1 )

    # the dominant channel is the one with the largest average RMS
    dominant_ch = np.argmax(ch_avg_rms)

    return gen_wave_table_list( wav_path,
                                marker_path,
                                gates,
                                midi_pitch,
                                cfg.pitch_track_interval_secs,
                                cfg.wt_interval_down_sample_fact,
                                cfg.min_wt_db,
                                dominant_ch,
                                cfg.est_hz,
                                cfg.ac )
|
|
|
|
|
|
|
|
|
|
|
|
def gen_wave_table_bank_mp( processN, src_dir, midi_pitchL, out_fname, argD ):
    """Generate the wavetable description of a list of pitches and write it to a JSON file.

    Args:
        processN: maximum count of worker processes.  It is limited to the
            count of pitches; when the result is 0 or less the pitches are
            processed serially in the calling process.
        src_dir: directory holding the per-pitch audio and marker files.
        midi_pitchL: list of MIDI pitches to process.
        out_fname: name of the JSON file to write.
        argD: configuration dict passed to _gen_wave_table_bank() and stored
            in the output file.

    The output file holds {"pitchL":[...], "instr":"piano", "argD":argD}, where
    pitchL is sorted by ascending 'midi_pitch'.
    """

    def _multi_proc_func( procId, procArgsD, taskArgsD ):
        # worker entry point: process the single pitch named by the task
        return _gen_wave_table_bank( procArgsD['src_dir'],
                                     taskArgsD['midi_pitch'],
                                     procArgsD['argD'] )

    # arguments shared by all tasks
    procArgsD = {
        "src_dir":src_dir,
        "argD": argD
    }

    # one task per pitch
    taskArgsL = [ { 'midi_pitch':midi_pitch } for midi_pitch in midi_pitchL ]

    processN = min(processN,len(taskArgsL))

    if processN > 0:
        pitchL = mp.local_distribute_main( processN,_multi_proc_func,procArgsD,taskArgsL )
    else:
        # serial fallback: iterate over the task records themselves
        pitchL = [ _gen_wave_table_bank( src_dir, r['midi_pitch'], argD ) for r in taskArgsL ]

    pitchL = sorted(pitchL,key=lambda x:x['midi_pitch'])

    with open(out_fname,"w") as f:
        json.dump({"pitchL":pitchL, "instr":"piano", "argD":argD},f)
|
|
|
|
|
|
|
|
def plot_rms( wtb_json_fname ):
    """Plot the RMS level (dB) of the stored wavetables, one figure per pitch.

    Each figure contains one line per (velocity, channel) pair.  The records
    are read with the same 'velL' / 'chL' layout that gen_wave_table_list()
    produces.

    Args:
        wtb_json_fname: JSON file written by gen_wave_table_bank_mp().
    """

    with open(wtb_json_fname) as f:
        pitchL = json.load(f)['pitchL']

    pitchL = sorted(pitchL,key=lambda x:x['midi_pitch'])

    for pitchD in pitchL:
        _,ax = plt.subplots(1,1)

        for velD in pitchD['velL']:
            for chWtL in velD['chL']:
                # a wavetable with zero RMS has no finite dB level; NaN leaves a gap in the line
                rmsL = [ 20*math.log10(wt['rms']) if wt['rms'] > 0 else float("nan") for wt in chWtL ]
                ax.plot(rmsL)

        plt.title(f"{pitchD['midi_pitch']}")
        plt.show()
|
|
|
|
def plot_atk_dur( wtb_json_fname ):
    """Plot, one figure per pitch, the attack duration as a function of velocity.

    The attack duration of a note is the time in seconds between the beginning
    of the note ('bsi') and the beginning of the first wavetable stored for
    channel 0.

    Args:
        wtb_json_fname: JSON file written by gen_wave_table_bank_mp().
    """

    with open(wtb_json_fname) as fh:
        bank = json.load(fh)['pitchL']

    for rec in sorted(bank, key=lambda p: p['midi_pitch']):
        _, axes = plt.subplots(1,1)

        # seconds from the note start to the first wavetable of channel 0
        atk_secL = []
        for v in rec['velL']:
            first_wt = v['chL'][0][0]
            atk_secL.append( (first_wt['wtbi'] - v['bsi']) / rec['srate'] )

        velocityL = [ v['vel'] for v in rec['velL'] ]

        axes.plot(velocityL, atk_secL, marker=".")

        plt.title(f"{rec['midi_pitch']}")
        plt.show()
|
|
|
|
|
|
def plot_hz( wtb_json_fname ):
    """Plot the pitch estimation results of all pitches in three stacked axes.

    Axis 0: estimated mean pitch (Hz) versus MIDI pitch.
    Axis 1: std. deviation of the estimate (cents) versus estimated pitch (Hz).
    Axis 2: estimation error (cents) versus MIDI pitch, with red reference
            lines at 0, 10 and 20 cents.

    Args:
        wtb_json_fname: JSON file written by gen_wave_table_bank_mp().
    """

    with open(wtb_json_fname) as fh:
        recL = json.load(fh)['pitchL']

    recL = sorted(recL, key=lambda r: r['midi_pitch'])

    _, axes = plt.subplots(3,1)
    hz_ax, std_ax, err_ax = axes[0], axes[1], axes[2]

    # pull one column per plotted quantity out of the pitch records
    midi_col = [ r['midi_pitch'] for r in recL ]
    hz_col = [ r["est_hz_mean"] for r in recL ]
    std_col = [ r["est_hz_std_cents"] for r in recL ]
    err_col = [ r["est_hz_err_cents"] for r in recL ]

    hz_ax.plot(midi_col, hz_col)
    std_ax.plot(hz_col, std_col)
    err_ax.hlines([0,10,20], midi_col[0], midi_col[-1], color="red")
    err_ax.plot(midi_col, err_col)

    plt.show()
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # process MIDI pitches 21 through 108 inclusive
    midi_pitchL = list(range(21,109))
    #midi_pitchL = [60 ]

    out_fname = "/home/kevin/temp/temp_5.json"
    src_dir = "/home/kevin/temp/wt6"

    # selects the wavetable locations at which the pitch is estimated
    est_hz_cfg = {
        'min_vel':50,
        'min_wt_idx':2,
        'max_wt_idx':4
    }

    # auto-correlation pitch detector settings
    ac_cfg = {
        'cycle_cnt':8,            # count of cycles to use for auto-corr. pitch detection
        'interp_degree':"cubic",
        'up_fact':2,
        'up_interp_degree':"cubic"
    }

    argD = {
        'rms_wnd_ms':50,
        'rms_hop_ms':10,
        'atk_min_dur_ms':1000,
        'atk_end_thresh_db':-43.0,
        'min_wt_db':-80.0,
        'pitch_track_interval_secs':0.25,
        # wt_interval_secs = pitch_track_interval_secs * wt_interval_down_sample_fact
        'wt_interval_down_sample_fact':8.0,
        'est_hz':est_hz_cfg,
        'ac':ac_cfg
    }

    gen_wave_table_bank_mp(20, src_dir, midi_pitchL, out_fname, argD )

    #plot_rms(out_fname)
    #plot_hz(out_fname)
    plot_atk_dur(out_fname)
|
|
|