428 lines
14 KiB
Python
428 lines
14 KiB
Python
|
import os,types,wave,json,array
|
||
|
import numpy as np
|
||
|
from rms_analysis import rms_analyze_one_note
|
||
|
|
||
|
class Calibrate:
    """Automatic pulse-length calibration for a set of pitches.

    For every pitch in cfg.pitchL and every target level in cfg.targetDbL the
    note is played repeatedly (driven by tick()), the recorded audio of each
    note is analyzed and the pulse length (microseconds) is adjusted until
    cfg.minMatchN notes land within cfg.tolDbPct percent of the target, or
    cfg.maxAttemptN valid measurements have been made.  Measurements are
    accumulated in self.measD and written below cfg.outDir/cfg.outLabel/<n>/
    when a calibration run is stopped.
    """

    def __init__( self, cfg, audio, midi, api ):
        """
        cfg   -- dict of settings, converted to a SimpleNamespace.  Keys read by this
                 class: pitchL, targetDbL, initPulseUs, minPulseUs, maxPulseUs,
                 noteOnDurMs, noteOffDurMs, holdDutyPctD, dbSrcLabel, minMeasDurMs,
                 tolDbPct, minMatchN, maxAttemptN, analysisD, outDir, outLabel.
        audio -- audio device providing record_enable(), linear_buffer() and srate.
        midi  -- MIDI output (send_note_on/send_note_off/send_all_notes_off) or None,
                 in which case notes are played through 'api'.
        api   -- device API providing note_on_us(), note_off() and set_pwm_duty().
        """
        self.cfg   = types.SimpleNamespace(**cfg)
        self.audio = audio
        self.midi  = midi
        self.api   = api

        self.state             = "stopped" # stopped | started | note_on | note_off
        self.playOnlyFl        = False     # True while play() is replaying measD
        self.startMs           = None      # time of start(); annotations are relative to it
        self.nextStateChangeMs = None

        self.curHoldDutyCyclePctD = None   # not read by this class
        self.curDutyPctD          = {}     # { pitch:dutyPct } last duty cycle sent to 'api'
        self.noteAnnotationL      = []     # [ {beg_ms, end_ms, midi_pitch, pulse_us} ]

        self.measD = None                  # { midi_pitch: [ analysis dict, ... ] }

        self.curNoteStartMs = None
        self.curPitchIdx    = None
        self.curTargetDbIdx = None
        self.successN       = None
        self.failN          = None

        # per (pitch,targetDb) search state - reset by _start_new_db_target()
        self.curTargetDb        = None
        self.curPulseUs         = None
        self.curMatchN          = None
        self.curAttemptN        = None
        self.lastAudiblePulseUs = None
        self.maxTooShortPulseUs = None
        self.pulseDbL           = None     # [ (pulseUs, db) ] valid measurements
        self.deltaUpMult        = None
        self.deltaDnMult        = None
        self.skipMeasFl         = None     # True when the next measurement must be ignored

    def start(self,ms):
        """Begin a new calibration run at time 'ms' (milliseconds).

        Any active run is stopped (and its results saved) first.
        """
        self.stop(ms)

        self.state             = 'started'
        self.playOnlyFl        = False
        self.nextStateChangeMs = ms + 500
        self.startMs           = ms

        self.curPitchIdx        = 0
        self.curPulseUs         = self.cfg.initPulseUs
        self.lastAudiblePulseUs = None
        self.maxTooShortPulseUs = None
        self.pulseDbL           = []
        self.deltaUpMult        = 1
        self.deltaDnMult        = 1
        self.curTargetDbIdx     = -1
        self._start_new_db_target()

        self.curDutyPctD = {}
        self.skipMeasFl  = False
        self.measD       = {}

        # annotations are relative to self.startMs so they must not carry over from an earlier run
        self.noteAnnotationL = []

        self.successN = 0
        self.failN    = 0

        self.audio.record_enable(True)

    def stop(self,ms):
        """Silence all notes and end the current run.

        If a calibration run (not playback) was active, recording is disabled and
        the results are saved.  Afterwards the state is 'stopped', so tick() no
        longer plays notes.
        """
        activeFl = self.state != 'stopped'

        if self.midi is not None:
            self.midi.send_all_notes_off()
        elif self.api is not None:
            for pitch in self.cfg.pitchL:
                self.api.note_off( pitch )

        if not self.playOnlyFl:
            self.audio.record_enable(False)

            # a run that already reached 'stopped' was saved at that time - don't save a duplicate
            if activeFl:
                self._save_results()

        self.state = 'stopped'

    def play(self,ms):
        """Replay the matched pulse lengths in self.measD, starting at time 'ms'.

        Pitches are stepped for each target level in turn; playback ends at the
        first (pitch, targetDb) pair that has no matched measurement.
        """
        if not self.measD:
            print("Nothing to play.")
            return

        self.state             = 'started'
        self.playOnlyFl        = True
        self.nextStateChangeMs = ms + 500
        self.curPitchIdx       = -1
        self.curTargetDbIdx    = 0

        # don't start playback with an invalid pulse length (-1)
        if not self._do_play_update():
            self.state = 'stopped'

    def tick(self,ms):
        """Advance the note-on / note-off / analysis state machine.

        Call periodically with the current time in milliseconds.
        """
        if self.nextStateChangeMs is not None and ms > self.nextStateChangeMs:

            if self.state == 'stopped':
                pass

            elif self.state == 'started':
                self._do_note_on(ms)
                self.nextStateChangeMs += self.cfg.noteOnDurMs
                self.state = 'note_on'

            elif self.state == 'note_on':
                self._do_note_off(ms)
                self.nextStateChangeMs += self.cfg.noteOffDurMs
                self.state = 'note_off'

            elif self.state == 'note_off':
                if self.playOnlyFl:
                    if not self._do_play_update():
                        self.state = 'stopped'
                else:
                    if self._do_analysis(ms):
                        if not self._start_new_db_target():
                            self.stop(ms)
                            self.state = 'stopped'
                            print("DONE!")

                # if the state was not changed to 'stopped' play the next note
                if self.state == 'note_off':
                    self.state = 'started'

    def _do_play_update( self ):
        """Select the next (pitch, targetDb) to replay and set self.curPulseUs.

        Returns False when all pairs have been played or the next pair has no
        matched measurement; otherwise True.
        """
        self.curPitchIdx += 1
        if self.curPitchIdx >= len(self.cfg.pitchL):
            self.curPitchIdx = 0
            self.curTargetDbIdx += 1
            if self.curTargetDbIdx >= len(self.cfg.targetDbL):
                return False

        pitch    = self.cfg.pitchL[ self.curPitchIdx ]
        targetDb = self.cfg.targetDbL[ self.curTargetDbIdx ]

        self.curPulseUs = -1
        for d in self.measD.get( pitch, [] ):   # a pitch may have no measurements at all
            if d['targetDb'] == targetDb and d['matchFl']:
                self.curPulseUs = d['pulse_us']
                break

        if self.curPulseUs == -1:
            print("Pitch:%i TargetDb:%f not found." % (pitch,targetDb))
            return False

        print("Target db: %4.1f" % (targetDb))

        return True

    def _get_duty_cycle( self, pitch, pulseUsec ):
        """Return the hold duty cycle (percent) for 'pitch' at pulse length 'pulseUsec'.

        cfg.holdDutyPctD maps pitch -> [ (refUsec, dutyPct), ... ] (presumably sorted
        by refUsec).  The duty of the last pair before the first refUsec greater than
        pulseUsec is used, or the first pair's duty if pulseUsec is below all of
        them.  Pitches without an entry use 50.
        """
        dutyPct = 50

        if pitch in self.cfg.holdDutyPctD:

            dutyPct = self.cfg.holdDutyPctD[pitch][0][1]
            for refUsec,refDuty in self.cfg.holdDutyPctD[pitch]:
                if pulseUsec < refUsec:
                    break
                dutyPct = refDuty

        return dutyPct

    def _set_duty_cycle( self, pitch, pulseUsec ):
        """Send the hold duty cycle for 'pitch' to the api if it changed; return it.

        A change sets self.skipMeasFl, so the next measurement is ignored.
        """
        dutyPct = self._get_duty_cycle( pitch, pulseUsec )

        if pitch not in self.curDutyPctD or self.curDutyPctD[pitch] != dutyPct:
            self.curDutyPctD[pitch] = dutyPct
            self.api.set_pwm_duty( pitch, dutyPct )
            print("Hold Duty Set:",dutyPct)
            self.skipMeasFl = True

        return dutyPct

    def _do_note_on(self,ms):
        """Start the current pitch at time 'ms' using self.curPulseUs."""
        self.curNoteStartMs = ms

        pitch = self.cfg.pitchL[ self.curPitchIdx]

        if self.midi is not None:
            self.midi.send_note_on( pitch, 60 )
        else:
            self._set_duty_cycle( pitch, self.curPulseUs )
            self.api.note_on_us( pitch, self.curPulseUs )

        print("note-on: ",pitch," ",self.curPulseUs," us")

    def _do_note_off(self,ms):
        """Record the annotation of the note that just ended and turn it off."""
        self.noteAnnotationL.append( { 'beg_ms':self.curNoteStartMs-self.startMs, 'end_ms':ms-self.startMs, 'midi_pitch':self.cfg.pitchL[ self.curPitchIdx], 'pulse_us':self.curPulseUs } )

        if self.midi is not None:
            self.midi.send_note_off( self.cfg.pitchL[ self.curPitchIdx] )
        else:
            for pitch in self.cfg.pitchL:
                self.api.note_off( pitch )

    def _calc_next_pulse_us( self, targetDb ):
        """Return the next pulse length (us) to try in order to reach 'targetDb'.

        Uses the (pulseUs, db) pairs in self.pulseDbL, which must be non-empty.  If
        the target is outside the measured db range, the pulse is stepped past the
        loudest/quietest measurement by a growing multiple of 500 us.  Otherwise it
        is linearly interpolated.  The result is clamped to
        [cfg.minPulseUs, cfg.maxPulseUs].
        """
        # sort pulseDb ascending on db (np.interp() requires increasing x values)
        self.pulseDbL = sorted( self.pulseDbL, key=lambda x: x[1] )

        pulseL,dbL = zip(*self.pulseDbL)

        max_i = np.argmax(dbL)
        min_i = np.argmin(dbL)

        if targetDb > dbL[max_i]:
            pu = pulseL[max_i] + self.deltaUpMult * 500
            self.deltaUpMult += 1

        elif targetDb < dbL[min_i]:
            pu = pulseL[min_i] - self.deltaDnMult * 500
            self.deltaDnMult += 1
            if self.maxTooShortPulseUs is not None and pu < self.maxTooShortPulseUs:
                # BUG: this is a problem if pulseL[min_i] is <= maxTooShortPulseUs
                # the abs() covers the problem to prevent decreasing from maxTooShortPulseUs
                pu = self.maxTooShortPulseUs + (abs(pulseL[min_i] - self.maxTooShortPulseUs))/2
                self.deltaDnMult = 1
        else:
            self.deltaUpMult = 1
            self.deltaDnMult = 1
            # scalar in -> scalar out (a 1-element array would leak into int() below)
            pu = float(np.interp(targetDb,dbL,pulseL))

        return max(min(pu,self.cfg.maxPulseUs),self.cfg.minPulseUs)

    def _do_analysis(self,ms):
        """Analyze the note that was just played and choose the next pulse length.

        Returns True when the current (pitch, targetDb) is finished: either
        cfg.minMatchN in-tolerance notes were measured (success), or
        cfg.maxAttemptN valid measurements were made without reaching that
        (failure).
        """
        analysisDoneFl = False
        midi_pitch     = self.cfg.pitchL[self.curPitchIdx]
        pulse_us       = self.curPulseUs

        measD = self._meas_note(midi_pitch,pulse_us)

        # nothing could be measured - replay the same pulse length
        if measD is None:
            self.skipMeasFl = False
            return False

        # if the 'skip' flag is set then don't analyze this note
        if self.skipMeasFl:
            self.skipMeasFl = False
            print("SKIP")

        else:
            db    = measD[self.cfg.dbSrcLabel]['db']
            # NOTE(review): duration always comes from 'hm' even if dbSrcLabel selects another source - confirm
            durMs = measD['hm']['durMs']

            # if this note is shorter than the minimum allowable duration
            if durMs < self.cfg.minMeasDurMs:

                print("SHORT!")

                if self.maxTooShortPulseUs is None or self.curPulseUs > self.maxTooShortPulseUs:
                    self.maxTooShortPulseUs = self.curPulseUs

                # NOTE(review): short notes are not counted as attempts.  If cfg.initPulseUs is
                # itself too short and no audible pulse has been found yet, the same pulse repeats forever.
                if self.lastAudiblePulseUs is not None and self.curPulseUs < self.lastAudiblePulseUs:
                    self.curPulseUs = self.lastAudiblePulseUs
                else:
                    self.curPulseUs = self.cfg.initPulseUs

            else:

                # this is a valid measurement store it to the pulse-db table
                self.pulseDbL.append( (self.curPulseUs,db) )

                # track the most recent audible note - to return to if a successive note is too short
                self.lastAudiblePulseUs = self.curPulseUs

                # calc the upper and lower bounds db range
                lwr_db = self.curTargetDb * ((100.0 - self.cfg.tolDbPct)/100.0)
                upr_db = self.curTargetDb * ((100.0 + self.cfg.tolDbPct)/100.0)

                # if this note is inside the db range then set the 'match' flag
                if lwr_db <= db <= upr_db:
                    self.curMatchN += 1
                    measD['matchFl'] = True
                    print("MATCH!")

                self.curPulseUs = int(self._calc_next_pulse_us(self.curTargetDb))

                self.curAttemptN += 1

                # if at least minMatchN matches have been made on this pitch/targetDb
                if self.curMatchN >= self.cfg.minMatchN:
                    analysisDoneFl = True
                    self.successN += 1
                    print("Anysis Done: Success")

                # if maxAttemptN attempts have been made without success
                elif self.curAttemptN >= self.cfg.maxAttemptN:
                    analysisDoneFl = True
                    self.failN += 1
                    print("Analysis Done: Fail")

        if midi_pitch not in self.measD:
            self.measD[ midi_pitch ] = []

        self.measD[ midi_pitch ].append( measD )

        return analysisDoneFl

    def _calc_analysis_window( self, annD, smpN ):
        """Return (bi, ei, begMs, endMs) locating the note 'annD' in a smpN-sample buffer.

        bi:ei covers the annotated note widened by half the note-off duration on
        each side, clipped to [0, smpN].  begMs and endMs are the annotated note
        begin and end in milliseconds relative to sample bi.
        """
        srate = self.audio.srate

        # annotated begin/end of the note as sample indexes
        noteBegIdx = int(round(annD['beg_ms'] * srate / 1000))
        noteEndIdx = int(round(annD['end_ms'] * srate / 1000))

        # half the note-off duration in samples
        padN = int(round(self.cfg.noteOffDurMs/2 * srate / 1000))

        # widen the analysis space padN samples pre/post the annotated begin/end of the note
        bi = max(0, noteBegIdx - padN)
        ei = min(noteEndIdx + padN, smpN)

        # shift the annotated begin/end to be relative to bi (bi may have been clipped at 0)
        begMs = (noteBegIdx - bi) * 1000 / srate
        endMs = begMs + (annD['end_ms'] - annD['beg_ms'])

        return bi, ei, begMs, endMs

    def _meas_note(self,midi_pitch,pulse_us):
        """Analyze the most recently annotated note in the recorded audio.

        Returns the analysis dict produced by rms_analyze_one_note() extended with
        the note's pulse, pitch, timing, target and match fields.  Returns None if
        no audio buffer is available.
        """
        # get the annotation information for the last note
        annD = self.noteAnnotationL[-1]

        buf_result = self.audio.linear_buffer()

        if not buf_result:
            print("No audio buffer available.")
            return None

        sigV = buf_result.value

        bi, ei, begMs, endMs = self._calc_analysis_window( annD, sigV.shape[0] )

        ar = types.SimpleNamespace(**self.cfg.analysisD)

        # analyze the note
        # NOTE(review): assumes the imported rms_analyze_one_note() accepts these keyword arguments - confirm
        resD = rms_analyze_one_note( sigV[bi:ei], self.audio.srate, begMs, endMs, midi_pitch, rmsWndMs=ar.rmsWndMs, rmsHopMs=ar.rmsHopMs, dbRefWndMs=ar.dbRefWndMs, harmCandN=ar.harmCandN, harmN=ar.harmN, durDecayPct=ar.durDecayPct )

        resD["pulse_us"]   = pulse_us
        resD["midi_pitch"] = midi_pitch
        resD["beg_ms"]     = annD['beg_ms']
        resD['end_ms']     = annD['end_ms']
        resD['skipMeasFl'] = self.skipMeasFl
        resD['matchFl']    = False
        resD['targetDb']   = self.curTargetDb
        resD['annIdx']     = len(self.noteAnnotationL)-1

        print( "%4.1f hm:%4.1f (%4.1f) %4i td:%4.1f (%4.1f) %4i" % (self.curTargetDb,resD['hm']['db'], resD['hm']['db']-self.curTargetDb, resD['hm']['durMs'], resD['td']['db'], resD['td']['db']-self.curTargetDb, resD['td']['durMs']))

        return resD

    def _start_new_db_target(self):
        """Advance to the next target level, wrapping to the next pitch after the last level.

        Resets the per-target search state; self.curPulseUs carries over.  Returns
        False when all pitches have been calibrated.
        """
        self.curTargetDbIdx += 1

        # if all db targets have been queried then advance to the next pitch
        if self.curTargetDbIdx >= len(self.cfg.targetDbL):

            self.curTargetDbIdx = 0
            self.curPitchIdx += 1

            # if all pitches have been queried then we are done
            if self.curPitchIdx >= len(self.cfg.pitchL):
                return False

        self.curTargetDb        = self.cfg.targetDbL[ self.curTargetDbIdx ]
        self.curMatchN          = 0
        self.curAttemptN        = 0
        self.lastAudiblePulseUs = None
        self.maxTooShortPulseUs = None
        self.pulseDbL           = []
        self.deltaUpMult        = 1
        self.deltaDnMult        = 1
        return True

    def _write_16_bit_wav_file( self, fn ):
        """Write the recorded audio buffer to 'fn' as a 16-bit PCM WAV file.

        Samples are scaled by 0x7fff, rounded and clipped to the int16 range.  A
        2-D buffer is treated as (samples, channels).
        """
        srate = int(self.audio.srate)

        buf_result = self.audio.linear_buffer()

        if not buf_result:
            print("No audio buffer available.")
            return

        sigV = np.asarray(buf_result.value)
        chN  = 1 if sigV.ndim == 1 else sigV.shape[1]

        # clip so that out-of-range samples saturate instead of overflowing int16;
        # row-major flattening interleaves the channels; WAV PCM data is little-endian
        pcmV = np.clip(np.round(sigV.reshape(-1) * 0x7fff), -0x8000, 0x7fff).astype('<i2')

        with wave.open( fn, "wb") as f:

            bits          = 16
            bits_per_byte = 8
            f.setparams((chN, bits//bits_per_byte, srate, 0, 'NONE', 'not compressed'))

            f.writeframes(pcmV.tobytes())

    def _save_results( self ):
        """Save the audio and measurements to a new cfg.outDir/cfg.outLabel/<n>/ directory.

        Writes 'audio.wav' and 'meas.json' (cfg, measD and note annotations).  Does
        nothing when there are no measurements.
        """
        if self.measD is None or len(self.measD) == 0:
            return

        outDir = os.path.join( os.path.expanduser( self.cfg.outDir ), self.cfg.outLabel )
        os.makedirs( outDir, exist_ok=True )

        # pick the first unused numbered sub-directory
        i = 0
        while os.path.isdir( os.path.join(outDir,"%i" % i ) ):
            i += 1

        outDir = os.path.join( outDir, "%i" % i )
        os.mkdir(outDir)

        self._write_16_bit_wav_file( os.path.join(outDir,"audio.wav"))

        d = {'cfg':self.cfg.__dict__, 'measD': self.measD, 'annoteL':self.noteAnnotationL }

        with open( os.path.join(outDir,"meas.json"), "w") as f:
            json.dump(d,f)