picadae calibration programs
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

calibrate.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. ##| Copyright: (C) 2019-2020 Kevin Larke <contact AT larke DOT org>
  2. ##| License: GNU GPL version 3.0 or above. See the accompanying LICENSE file.
  3. import os,types,wave,json,array
  4. import numpy as np
  5. from rms_analysis import rms_analyze_one_rt_note
  6. from plot_seq_1 import get_merged_pulse_db_measurements
  7. class Calibrate:
  8. def __init__( self, cfg, audio, midi, api ):
  9. self.cfg = types.SimpleNamespace(**cfg)
  10. self.audio = audio
  11. self.midi = midi
  12. self.api = api
  13. self.state = "stopped" # stopped | started | note_on | note_off | analyzing
  14. self.playOnlyFl = False
  15. self.startMs = None
  16. self.nextStateChangeMs = None
  17. self.curHoldDutyCyclePctD = None # { pitch:dutyPct}
  18. self.noteAnnotationL = [] # (noteOnMs,noteOffMs,pitch,pulseUs)
  19. self.measD = None # { midi_pitch: [ {pulseUs, db, durMs, targetDb } ] }
  20. self.initPulseDbListD = self._get_init_pulseDbD()
  21. self.curNoteStartMs = None
  22. self.curPitchIdx = None
  23. self.curTargetDbIdx = None
  24. self.successN = None
  25. self.failN = None
  26. self.curTargetDb = None
  27. self.curPulseUs = None
  28. self.curMatchN = None
  29. self.curAttemptN = None
  30. self.lastAudiblePulseUs = None
  31. self.maxTooShortPulseUs = None
  32. self.pulseDbL = None
  33. self.deltaUpMult = None
  34. self.deltaDnMult = None
  35. self.skipMeasFl = None
  36. def start(self,ms):
  37. self.stop(ms)
  38. self.state = 'started'
  39. self.playOnlyFl = False
  40. self.nextStateChangeMs = ms + 500
  41. self.startMs = ms
  42. self.curPitchIdx = 0
  43. self.curPulseUs = self.cfg.initPulseUs
  44. self.lastAudiblePulseUs = None
  45. self.maxTooShortPulseUs = None
  46. self.pulseDbL = []
  47. self.pulseDbL = self.initPulseDbListD[ self.cfg.pitchL[ self.curPitchIdx ] ]
  48. self.deltaUpMult = 1
  49. self.deltaDnMult = 1
  50. self.curTargetDbIdx = -1
  51. self._start_new_db_target()
  52. self.curDutyPctD = {}
  53. self.skipMeasFl = False
  54. self.measD = {}
  55. self.successN = 0
  56. self.failN = 0
  57. self.audio.record_enable(True)
  58. def stop(self,ms):
  59. if self.midi is not None:
  60. self.midi.send_all_notes_off()
  61. self.audio.record_enable(False)
  62. if not self.playOnlyFl:
  63. self._save_results()
  64. def play(self,ms):
  65. if self.measD is None or len(self.measD) == 0:
  66. print("Nothing to play.")
  67. else:
  68. self.startMs = ms
  69. self.state = 'started'
  70. self.playOnlyFl = True
  71. self.nextStateChangeMs = ms + 500
  72. self.curPitchIdx = -1
  73. self.curTargetDbIdx = 0
  74. self.audio.record_enable(True)
  75. self._do_play_only_update()
  76. def tick(self,ms):
  77. if self.nextStateChangeMs is not None and ms > self.nextStateChangeMs:
  78. if self.state == 'stopped':
  79. pass
  80. elif self.state == 'started':
  81. self._do_note_on(ms)
  82. self.nextStateChangeMs += self.cfg.noteOnDurMs
  83. self.state = 'note_on'
  84. elif self.state == 'note_on':
  85. self._do_note_off(ms)
  86. self.nextStateChangeMs += self.cfg.noteOffDurMs
  87. self.state = 'note_off'
  88. elif self.state == 'note_off':
  89. if self.playOnlyFl:
  90. if not self._do_play_only_update():
  91. self.stop(ms)
  92. self.state = 'stopped'
  93. else:
  94. if self._do_analysis(ms):
  95. if not self._start_new_db_target():
  96. self.stop(ms)
  97. self.state = 'stopped'
  98. print("DONE!")
  99. # if the state was not changed to 'stopped'
  100. if self.state == 'note_off':
  101. self.state = 'started'
  102. def _calc_play_only_pulse_us( self, pitch, targetDb ):
  103. pulseDbL = []
  104. for d in self.measD[ pitch ]:
  105. if d['targetDb'] == targetDb and d['matchFl']==True:
  106. pulseDbL.append( ( d['pulse_us'], d[self.cfg.dbSrcLabel]['db']) )
  107. if len(pulseDbL) == 0:
  108. return -1
  109. pulseL,dbL = zip(*pulseDbL)
  110. # TODO: make a weighted average based on db error
  111. return np.mean(pulseL)
  112. def _do_play_only_update( self ):
  113. if self.curPitchIdx >= 0:
  114. self._meas_note( self.cfg.pitchL[self.curPitchIdx], self.curPulseUs )
  115. self.curPitchIdx +=1
  116. if self.curPitchIdx >= len(self.cfg.pitchL):
  117. self.curPitchIdx = 0
  118. self.curTargetDbIdx += 1
  119. if self.curTargetDbIdx >= len(self.cfg.targetDbL):
  120. return False
  121. pitch = self.cfg.pitchL[ self.curPitchIdx ]
  122. targetDb = self.cfg.targetDbL[ self.curTargetDbIdx ]
  123. self.curPulseUs = self._calc_play_only_pulse_us( pitch, targetDb )
  124. self.curTargetDb = targetDb
  125. if self.curPulseUs == -1:
  126. print("Pitch:%i TargetDb:%f not found." % (pitch,targetDb))
  127. return False
  128. print("Target db: %4.1f" % (targetDb))
  129. return True
  130. def _get_init_pulseDbD( self ):
  131. initPulseDbListD = {}
  132. print("Calculating initial calibration search us/db lists ...")
  133. if self.cfg.inDir is not None:
  134. for pitch in self.cfg.pitchL:
  135. print(pitch)
  136. inDir = os.path.expanduser( self.cfg.inDir )
  137. usL,dbL,_,_,_ = get_merged_pulse_db_measurements( inDir, pitch, self.cfg.analysisD )
  138. initPulseDbListD[pitch] = [ (us,db) for us,db in zip(usL,dbL) ]
  139. return initPulseDbListD
  140. def _get_duty_cycle( self, pitch, pulseUsec ):
  141. dutyPct = 50
  142. if pitch in self.cfg.holdDutyPctD:
  143. dutyPct = self.cfg.holdDutyPctD[pitch][0][1]
  144. for refUsec,refDuty in self.cfg.holdDutyPctD[pitch]:
  145. if pulseUsec < refUsec:
  146. break
  147. dutyPct = refDuty
  148. return dutyPct
  149. def _set_duty_cycle( self, pitch, pulseUsec ):
  150. dutyPct = self._get_duty_cycle( pitch, pulseUsec )
  151. if pitch not in self.curDutyPctD or self.curDutyPctD[pitch] != dutyPct:
  152. self.curDutyPctD[pitch] = dutyPct
  153. self.api.set_pwm_duty( pitch, dutyPct )
  154. print("Hold Duty Set:",dutyPct)
  155. self.skipMeasFl = True
  156. return dutyPct
  157. def _do_note_on(self,ms):
  158. self.curNoteStartMs = ms
  159. pitch = self.cfg.pitchL[ self.curPitchIdx]
  160. if self.midi is not None:
  161. self.midi.send_note_on( pitch, 60 )
  162. else:
  163. self._set_duty_cycle( pitch, self.curPulseUs )
  164. self.api.note_on_us( pitch, self.curPulseUs )
  165. print("note-on: ",pitch," ",self.curPulseUs," us")
  166. def _do_note_off(self,ms):
  167. self.noteAnnotationL.append( { 'beg_ms':self.curNoteStartMs-self.startMs, 'end_ms':ms-self.startMs, 'midi_pitch':self.cfg.pitchL[ self.curPitchIdx], 'pulse_us':self.curPulseUs } )
  168. if self.midi is not None:
  169. self.midi.send_note_off( self.cfg.pitchL[ self.curPitchIdx] )
  170. else:
  171. for pitch in self.cfg.pitchL:
  172. self.api.note_off( pitch )
  173. #print("note-off: ",self.cfg.pitchL[ self.curPitchIdx])
  174. def _proportional_step( self, targetDb, dbL, pulseL ):
  175. curPulse,curDb = self.pulseDbL[-1]
  176. # get the point closest to the target db
  177. i = np.argmin( np.array(dbL) - targetDb )
  178. # find the percentage difference to the target - based on the closest point
  179. pd = abs(curDb-targetDb) / abs(curDb - dbL[i])
  180. #
  181. delta_pulse = pd * abs(curPulse - pulseL[i])
  182. print("prop:",pd,"delta_pulse:",delta_pulse)
  183. return int(round(curPulse + np.sign(targetDb - curDb) * delta_pulse))
  184. def _step( self, targetDb ):
  185. # get the last two pulse/db samples
  186. pulse0,db0 = self.pulseDbL[-2]
  187. pulse1,db1 = self.pulseDbL[-1]
  188. # microseconds per decibel for the last two points
  189. us_per_db = abs(pulse0-pulse1) / abs(db0-db1)
  190. if us_per_db == 0:
  191. us_per_db = 10 # ************************************** CONSTANT ***********************
  192. # calcuate the decibels we need to move from the last point
  193. error_db = targetDb - db1
  194. print("us_per_db:",us_per_db," error db:", error_db )
  195. return pulse1 + us_per_db * error_db
  196. def _calc_next_pulse_us( self, targetDb ):
  197. # sort pulseDb ascending on db
  198. pulseDbL = sorted( self.pulseDbL, key=lambda x: x[1] )
  199. # get the set of us/db values tried so far
  200. pulseL,dbL = zip(*pulseDbL)
  201. max_i = np.argmax(dbL)
  202. min_i = np.argmin(dbL)
  203. # if the targetDb is greater than the max. db value achieved so far
  204. if targetDb > dbL[max_i]:
  205. pu = pulseL[max_i] + self.deltaUpMult * 500
  206. self.deltaUpMult += 1
  207. # if the targetDb is less than the min. db value achieved so far
  208. elif targetDb < dbL[min_i]:
  209. pu = pulseL[min_i] - self.deltaDnMult * 500
  210. self.deltaDnMult += 1
  211. if self.maxTooShortPulseUs is not None and pu < self.maxTooShortPulseUs:
  212. # BUG: this is a problem is self.pulseL[min_i] is <= than self.maxTooShortPulseUs
  213. # the abs() covers the problem to prevent decreasing from maxTooShortPulseus
  214. pu = self.maxTooShortPulseUs + (abs(pulseL[min_i] - self.maxTooShortPulseUs))/2
  215. self.deltaDnMult = 1
  216. else:
  217. # the targetDb value is inside the min/max range of the db values acheived so far
  218. self.deltaUpMult = 1
  219. self.deltaDnMult = 1
  220. # interpolate the new pulse value based on the values seen so far
  221. # TODO: use only closest 5 values rather than all values
  222. pu = np.interp([targetDb],dbL,pulseL)
  223. # the selected pulse has already been sampled
  224. if int(pu) in pulseL:
  225. pu = self._step(targetDb )
  226. return max(min(pu,self.cfg.maxPulseUs),self.cfg.minPulseUs)
  227. def _do_analysis(self,ms):
  228. analysisDoneFl = False
  229. midi_pitch = self.cfg.pitchL[self.curPitchIdx]
  230. pulse_us = self.curPulseUs
  231. measD = self._meas_note(midi_pitch,pulse_us)
  232. # if the the 'skip' flag is set then don't analyze this note
  233. if self.skipMeasFl:
  234. self.skipMeasFl = False
  235. print("SKIP")
  236. else:
  237. db = measD[self.cfg.dbSrcLabel]['db']
  238. durMs = measD['hm']['durMs']
  239. # if this note is shorter than the minimum allowable duration
  240. if durMs < self.cfg.minMeasDurMs:
  241. print("SHORT!")
  242. if self.maxTooShortPulseUs is None or self.curPulseUs > self.maxTooShortPulseUs:
  243. self.maxTooShortPulseUs = self.curPulseUs
  244. if self.lastAudiblePulseUs is not None and self.curPulseUs < self.lastAudiblePulseUs:
  245. self.curPulseUs = self.lastAudiblePulseUs
  246. else:
  247. self.curPulseUs = self.cfg.initPulseUs
  248. else:
  249. # this is a valid measurement, store it to the pulse-db table
  250. self.pulseDbL.append( (self.curPulseUs,db) )
  251. # track the most recent audible note (to return to if a successive note is too short)
  252. self.lastAudiblePulseUs = self.curPulseUs
  253. # calc the upper and lower bounds db range
  254. lwr_db = self.curTargetDb * ((100.0 - self.cfg.tolDbPct)/100.0)
  255. upr_db = self.curTargetDb * ((100.0 + self.cfg.tolDbPct)/100.0)
  256. # if this note was inside the db range then set the 'match' flag
  257. if lwr_db <= db and db <= upr_db:
  258. self.curMatchN += 1
  259. measD['matchFl'] = True
  260. print("MATCH!")
  261. # calculate the next pulse length
  262. self.curPulseUs = int(self._calc_next_pulse_us(self.curTargetDb))
  263. # if at least minMatchN matches have been made on this pitch/targetDb
  264. if self.curMatchN >= self.cfg.minMatchN:
  265. analysisDoneFl = True
  266. self.successN += 1
  267. print("Anysis Done: Success")
  268. # if at least maxAttemptN match attempts have been made without success
  269. self.curAttemptN += 1
  270. if self.curAttemptN >= self.cfg.maxAttemptN:
  271. analysisDoneFl = True
  272. self.failN += 1
  273. print("Analysis Done: Fail")
  274. if midi_pitch not in self.measD:
  275. self.measD[ midi_pitch ] = []
  276. self.measD[ midi_pitch ].append( measD )
  277. return analysisDoneFl
  278. def _meas_note(self,midi_pitch,pulse_us):
  279. # get the annotation information for the last note
  280. annD = self.noteAnnotationL[-1]
  281. buf_result = self.audio.linear_buffer()
  282. if buf_result:
  283. sigV = buf_result.value
  284. # get the annotated begin and end of the note as sample indexes into sigV
  285. bi = int(round(annD['beg_ms'] * self.audio.srate / 1000))
  286. ei = int(round(annD['end_ms'] * self.audio.srate / 1000))
  287. # calculate half the length of the note-off duration in samples
  288. noteOffSmp_o_2 = int(round( (self.cfg.noteOffDurMs/2) * self.audio.srate / 1000))
  289. # widen the note analysis space noteOffSmp_o_2 samples pre/post the annotated begin/end of the note
  290. bi = max(0,bi - noteOffSmp_o_2)
  291. ei = min(ei+noteOffSmp_o_2,sigV.shape[0]-1)
  292. ar = types.SimpleNamespace(**self.cfg.analysisD)
  293. # shift the annotatd begin/end of the note to be relative to index bi
  294. begMs = noteOffSmp_o_2 * 1000 / self.audio.srate
  295. endMs = begMs + (annD['end_ms'] - annD['beg_ms'])
  296. #print("MEAS:",begMs,endMs,bi,ei,sigV.shape,self.audio.is_recording_enabled(),ar)
  297. # analyze the note
  298. resD = rms_analyze_one_rt_note( sigV[bi:ei], self.audio.srate, begMs, endMs, midi_pitch, rmsWndMs=ar.rmsWndMs, rmsHopMs=ar.rmsHopMs, dbLinRef=ar.dbLinRef, harmCandN=ar.harmCandN, harmN=ar.harmN, durDecayPct=ar.durDecayPct )
  299. resD["pulse_us"] = pulse_us
  300. resD["midi_pitch"] = midi_pitch
  301. resD["beg_ms"] = annD['beg_ms']
  302. resD['end_ms'] = annD['end_ms']
  303. resD['skipMeasFl'] = self.skipMeasFl
  304. resD['matchFl'] = False
  305. resD['targetDb'] = self.curTargetDb
  306. resD['annIdx'] = len(self.noteAnnotationL)-1
  307. print( "%4.1f hm:%4.1f (%4.1f) %4i td:%4.1f (%4.1f) %4i" % (self.curTargetDb,resD['hm']['db'], resD['hm']['db']-self.curTargetDb, resD['hm']['durMs'], resD['td']['db'], resD['td']['db']-self.curTargetDb, resD['td']['durMs']))
  308. return resD
  309. def _start_new_db_target(self):
  310. self.curTargetDbIdx += 1
  311. # if all db targets have been queried then advance to the next pitch
  312. if self.curTargetDbIdx >= len(self.cfg.targetDbL):
  313. self.curTargetDbIdx = 0
  314. self.curPitchIdx += 1
  315. # if all pitches have been queried then we are done
  316. if self.curPitchIdx >= len(self.cfg.pitchL):
  317. return False
  318. # reset the variables prior to begining the next target search
  319. self.curTargetDb = self.cfg.targetDbL[ self.curTargetDbIdx ]
  320. self.curMatchN = 0
  321. self.curAttemptN = 0
  322. self.lastAudiblePulseUs = None
  323. self.maxTooShortPulseUs = None
  324. self.pulseDbL = []
  325. self.pulseDbL = self.initPulseDbListD[ self.cfg.pitchL[ self.curPitchIdx ] ]
  326. self.deltaUpMult = 1
  327. self.deltaDnMult = 1
  328. return True
  329. def _write_16_bit_wav_file( self, fn ):
  330. srate = int(self.audio.srate)
  331. buf_result = self.audio.linear_buffer()
  332. sigV = buf_result.value
  333. smpN = sigV.shape[0]
  334. chN = 1
  335. sigV = np.squeeze(sigV.reshape( smpN * chN, )) * 0x7fff
  336. sigL = [ int(round(sigV[i])) for i in range(smpN) ]
  337. sigA = array.array('h',sigL)
  338. with wave.open( fn, "wb") as f:
  339. bits = 16
  340. bits_per_byte = 8
  341. f.setparams((chN, bits//bits_per_byte, srate, 0, 'NONE', 'not compressed'))
  342. f.writeframes(sigA)
  343. def _save_results( self ):
  344. if self.measD is None or len(self.measD) == 0:
  345. return
  346. outDir = os.path.expanduser( self.cfg.outDir )
  347. if not os.path.isdir(outDir):
  348. os.mkdir(outDir)
  349. outDir = os.path.join( outDir, self.cfg.outLabel )
  350. if not os.path.isdir(outDir):
  351. os.mkdir(outDir)
  352. i = 0
  353. while( os.path.isdir( os.path.join(outDir,"%i" % i )) ):
  354. i += 1
  355. outDir = os.path.join( outDir, "%i" % i )
  356. os.mkdir(outDir)
  357. self._write_16_bit_wav_file( os.path.join(outDir,"audio.wav"))
  358. d = {'cfg':self.cfg.__dict__, 'measD': self.measD, 'annoteL':self.noteAnnotationL }
  359. with open( os.path.join(outDir,"meas.json"), "w") as f:
  360. json.dump(d,f)