libcm is a C development framework with an emphasis on audio signal processing applications.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cmRbm.c 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. //| Copyright: (C) 2009-2020 Kevin Larke <contact AT larke DOT org>
  2. //| License: GNU GPL version 3.0 or above. See the accompanying LICENSE file.
  3. #include "cmPrefix.h"
  4. #include "cmGlobal.h"
  5. #include "cmFloatTypes.h"
  6. #include "cmComplexTypes.h"
  7. #include "cmRpt.h"
  8. #include "cmErr.h"
  9. #include "cmCtx.h"
  10. #include "cmMem.h"
  11. #include "cmMallocDebug.h"
  12. #include "cmLinkedHeap.h"
  13. #include "cmMath.h"
  14. #include "cmFile.h"
  15. #include "cmSymTbl.h"
  16. #include "cmTime.h"
  17. #include "cmMidi.h"
  18. #include "cmAudioFile.h"
  19. #include "cmVectOpsTemplateMain.h"
  20. #include "cmStack.h"
  21. #include "cmProcObj.h"
  22. #include "cmProcTemplateMain.h"
  23. #include "cmVectOps.h"
  24. #include "cmProc.h"
  25. #include "cmProc2.h"
  26. #include "cmRbm.h"
  // Per-epoch training-monitor record pushed onto cmRBM_t.monH.
// cmRbmWriteMonitorFile() derives its matrix column count from
// sizeof(cmRbmMonitor_t)/sizeof(double), so only double members belong here.
typedef struct
{
  double trainErr;  // summed squared reconstruction error over one epoch
  double testErr;   // hold-out error (not computed by the training code in this file)
} cmRbmMonitor_t;
  32. cmRbmRC_t cmRbmWriteMonitorFile( cmCtx_t* c, cmStackH_t monH, const cmChar_t* fn )
  33. {
  34. cmRbmRC_t rc = kOkRbmRC;
  35. cmCtx* ctx = cmCtxAlloc(NULL, c->err.rpt, cmLHeapNullHandle, cmSymTblNullHandle );
  36. if( cmBinMtxFileWrite(fn, cmStackCount(monH), sizeof(cmRbmMonitor_t)/sizeof(double), NULL, cmStackFlatten(monH), ctx, c->err.rpt ) != cmOkRC )
  37. {
  38. rc = cmErrMsg(&c->err,kMonitorWrFailRbmRC,"Training monitor file '%s' write failed.",cmStringNullGuard(fn));
  39. goto errLabel;
  40. }
  41. errLabel:
  42. cmCtxFree(&ctx);
  43. return rc;
  44. }
  45. double* cmRbmReadDataFile( cmCtx_t* c, const char* fn, unsigned* dimNPtr, unsigned* pointCntPtr )
  46. {
  47. unsigned rowCnt,colCnt,eleByteCnt;
  48. *dimNPtr = 0;
  49. *pointCntPtr = 0;
  50. if( cmBinMtxFileSize(c, fn, &rowCnt, &colCnt, &eleByteCnt ) != cmOkRC )
  51. return NULL;
  52. double* buf = cmMemAllocZ(double,rowCnt*colCnt);
  53. if( cmBinMtxFileRead(c, fn, rowCnt, colCnt, sizeof(double), buf,NULL) != cmOkRC )
  54. {
  55. cmMemFree(buf);
  56. return NULL;
  57. }
  58. *dimNPtr = rowCnt;
  59. *pointCntPtr = colCnt;
  60. return buf;
  61. }
  62. // Generate a matrix of 'pointsN' random binary valued column vectors of dimension dimN.
  63. // The first i = {0...'dimN'-1} elements of each vector contain ones with prob probV[i]
  64. // (or zeros with prob 1 - probV[i].). probV[i] in [0.0,1.0].
  65. // The last element in each column is set to zero.
  66. // The returned matrix m[ dimN+1, pointsN ] is in column major order and
  67. // must be deleted by the caller (e.g. cmMemFree(m)).
  68. double* cmRbmGenBinaryTestData( cmCtx_t* c, const char* fn, const double* probV, unsigned dimN, unsigned pointsN )
  69. {
  70. if( dimN == 0 || pointsN == 0 )
  71. return NULL;
  72. double* m = cmMemAllocZ( double, dimN*pointsN );
  73. unsigned i,j;
  74. for(i=0; i<pointsN; ++i)
  75. for(j=0; j<dimN; ++j)
  76. m[ i*dimN + j ] = rand() < (probV[j] * RAND_MAX);
  77. if( fn != NULL )
  78. cmBinMtxFileWrite(fn,dimN,pointsN,NULL,m,NULL,c->err.rpt);
  79. return m;
  80. }
  81. typedef struct
  82. {
  83. unsigned vN;
  84. double* vs;
  85. double* vp;
  86. double* vb;
  87. double* vd; // std dev. (var = std_dev^2)
  88. unsigned hN;
  89. double* hs;
  90. double* hp;
  91. double* hb;
  92. double* W; // W[vN,hN]
  93. cmStackH_t monH;
  94. } cmRBM_t;
  95. void _cmRbmPrint( cmRBM_t* r, cmRpt_t* rpt )
  96. {
  97. cmVOD_PrintL("hb", rpt, 1, r->hN, r->hb );
  98. cmVOD_PrintL("hp", rpt, 1, r->hN, r->hp );
  99. cmVOD_PrintL("hs", rpt, 1, r->hN, r->hs );
  100. cmVOD_PrintL("vb", rpt, 1, r->vN, r->vb );
  101. cmVOD_PrintL("vp", rpt, 1, r->vN, r->vp );
  102. cmVOD_PrintL("vs", rpt, 1, r->vN, r->vs );
  103. cmVOD_PrintL("W", rpt, r->vN, r->hN, r->W );
  104. }
  105. void _cmRbmRelease( cmRBM_t* r )
  106. {
  107. cmStackFree(&r->monH);
  108. cmMemFree(r);
  109. }
  110. // Adjust the layer geometry to force all sizes to be a multiple of 16 bytes.
  111. // This assumes that all data will be 8 byte doubles.
  112. void _cmRbmAdjustSizes( unsigned* vNp, unsigned* hNp, unsigned* dNp )
  113. {
  114. *vNp = *vNp + (cmIsOddU(*vNp) ? 1 : 0);
  115. *hNp = *hNp + (cmIsOddU(*hNp) ? 1 : 0);
  116. if( dNp != NULL )
  117. *dNp = *dNp + (cmIsOddU(*dNp) ? 1 : 0);
  118. }
  119. cmRBM_t* _cmRbmAlloc( cmCtx_t* ctx, unsigned vN, unsigned hN )
  120. {
  121. unsigned monInitCnt = 1000;
  122. unsigned monExpandCnt = 1000;
  123. // adjust sizes to force base array addresses to be a multiple of 16 bytes.
  124. unsigned vn = vN;
  125. unsigned hn = hN;
  126. _cmRbmAdjustSizes(&vn,&hn,NULL);
  127. unsigned rn = sizeof(cmRBM_t);
  128. // force record to be a multiple of 16
  129. if( rn % 16 )
  130. rn += 16 - (rn % 16);
  131. unsigned dn = 4*vn + 3*hn + vn*hn;
  132. unsigned bn = rn + dn*sizeof(double);
  133. char* cp = cmMemAllocZ(char,bn);
  134. cmRBM_t* r = (cmRBM_t*)cp;
  135. r->vs = (double*)(cp+rn);
  136. r->vp = r->vs + vn;
  137. r->vb = r->vp + vn;
  138. r->vd = r->vb + vn;
  139. r->hs = r->vd + vn;
  140. r->hp = r->hs + hn;
  141. r->hb = r->hp + hn;
  142. r->W = r->hb + hn;
  143. r->vN = vN;
  144. r->hN = hN;
  145. assert(cp+bn == (char*)(r->W + vn*hn));
  146. if( cmStackAlloc(ctx, &r->monH, monInitCnt, monExpandCnt, sizeof(cmRbmMonitor_t)) != kOkStRC )
  147. {
  148. cmErrMsg(&ctx->err,kStackFailRbmRC,"Stack allocation failed for the training monitor data array.");
  149. goto errLabel;
  150. }
  151. return r;
  152. errLabel:
  153. _cmRbmRelease(r);
  154. return NULL;
  155. }
  156. void cmRbmBinaryTrain(
  157. cmCtx_t* ctx,
  158. cmRBM_t* r,
  159. cmRbmTrainParms_t* p,
  160. unsigned dMN,
  161. const double* dM )
  162. {
  163. cmRpt_t* rpt = ctx->err.rpt;
  164. bool stochFl = true;
  165. unsigned i,j,k,ei,di;
  166. unsigned vN = r->vN;
  167. unsigned hN = r->hN;
  168. // adjust the memory sizes to align all arrays on 16 byte boundaries
  169. unsigned vn = vN;
  170. unsigned hn = hN;
  171. unsigned dn = p->batchCnt;
  172. _cmRbmAdjustSizes(&vn,&hn,&dn);
  173. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  174. double* m = cmMemAllocZ(double,mn);
  175. double* vh0M = m; // vh0M[ hN, vN ]
  176. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  177. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  178. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  179. double* hdbV = vdbV + vn; // hdbV[ hN ]
  180. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  181. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  182. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  183. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  184. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  185. assert( vs1M + vn * dn == m + mn );
  186. // initilaize the weights with random values
  187. // W = p->initW * randn(vN,hN,0.0,1.0)
  188. for(i=0; i<vN; ++i)
  189. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  190. cmVOD_MultVS( r->W, hN*vN, p->initW);
  191. if(0)
  192. {
  193. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  194. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  195. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  196. }
  197. cmVOD_Zero( dwM, vN*hN );
  198. cmVOD_Zero( vdbV, vN );
  199. cmVOD_Zero( hdbV, hN );
  200. for(ei=0; ei<p->epochCnt; ++ei)
  201. {
  202. unsigned dN = 0;
  203. double err = 0;
  204. for(di=0; di<dMN; di+=dN)
  205. {
  206. dN = cmMin(p->batchCnt,dMN-di);
  207. const double* d = dM + di * vN; // d[ vN, dN ]
  208. //
  209. // Update hidden layer from data
  210. //
  211. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  212. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  213. // calc hs0M[dN,hN]
  214. for(k=0; k<dN; ++k)
  215. for(j=0; j<hN; ++j)
  216. {
  217. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  218. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  219. if( !stochFl )
  220. hs0M[ j*dN + k ] = hp0M[ k*hN + j ] > 0.5;
  221. }
  222. //
  223. // Reconstruct visible layer from hidden
  224. //
  225. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  226. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  227. // calc vs1M[vN,dN]
  228. for(k=0; k<dN; ++k)
  229. for(i=0; i<vN; ++i)
  230. {
  231. vp1M[ i*dN + k ] = 1.0/(1.0 + exp(-( vp1M[ i*dN + k ] + r->vb[i]) ) );
  232. vs1M[ k*vN + i ] = rand() < vp1M[ i*dN + k ] * RAND_MAX;
  233. if( !stochFl )
  234. vs1M[ k*vN + i ] = vp1M[ i*dN + k ] > 0.5;
  235. // calc training error
  236. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  237. }
  238. //
  239. // Update hidden layer from reconstruction
  240. //
  241. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  242. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  243. // calc hp1M[hN,dN]
  244. for(k=0; k<dN; ++k)
  245. for(j=0; j<hN; ++j)
  246. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  247. if(0)
  248. {
  249. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  250. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  251. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  252. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  253. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  254. }
  255. //
  256. // Update Wieghts
  257. //
  258. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  259. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  260. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  261. for(i=0; i<hN*vN; ++i)
  262. {
  263. dwM[i] = p->momentum * dwM[i] + p->eta * ( (vh0M[i] - vh1M[i]) / dN );
  264. r->W[i] += dwM[i];
  265. }
  266. //
  267. // Update hidden bias
  268. //
  269. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  270. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  271. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  272. for(j=0; j<hN; ++j)
  273. {
  274. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / dN);
  275. r->hb[j] += hdbV[j];
  276. }
  277. //
  278. // Update visible bias
  279. //
  280. // sum(d - vs1M, 2)
  281. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  282. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  283. for(i=0; i<vN; ++i)
  284. {
  285. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  286. r->vb[i] += vdbV[i];
  287. }
  288. if(0)
  289. {
  290. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  291. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  292. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  293. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  294. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  295. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  296. }
  297. } // di
  298. cmRptPrintf(rpt,"err:%f\n",err);
  299. if( cmStackIsValid(r->monH))
  300. {
  301. cmRbmMonitor_t monErr;
  302. monErr.trainErr = err;
  303. cmStackPush(r->monH,&monErr,1);
  304. }
  305. } // ei
  306. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  307. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  308. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  309. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  310. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  311. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  312. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  313. cmMemFree(m);
  314. }
  315. void cmRbmRealTrain(
  316. cmCtx_t* ctx,
  317. cmRBM_t* r,
  318. cmRbmTrainParms_t* p,
  319. unsigned dMN,
  320. const double* dM )
  321. {
  322. cmRpt_t* rpt = ctx->err.rpt;
  323. unsigned i,j,k,ei,di;
  324. unsigned vN = r->vN;
  325. unsigned hN = r->hN;
  326. // adjust the memory sizes to align all arrays on 16 byte boundaries
  327. unsigned vn = vN;
  328. unsigned hn = hN;
  329. unsigned dn = p->batchCnt;
  330. _cmRbmAdjustSizes(&vn,&hn,&dn);
  331. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  332. double* m = cmMemAllocZ(double,mn);
  333. double* vh0M = m; // vh0M[ hN, vN ]
  334. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  335. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  336. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  337. double* hdbV = vdbV + vn; // hdbV[ hN ]
  338. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  339. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  340. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  341. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  342. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  343. assert( vs1M + vn * dn == m + mn );
  344. //
  345. // Initilaize the weights with small random values
  346. // W = p->initW * randn(vN,hN,0.0,1.0)
  347. for(i=0; i<vN; ++i)
  348. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  349. cmVOD_MultVS( r->W, hN*vN, p->initW);
  350. if(0)
  351. {
  352. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  353. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  354. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  355. }
  356. cmVOD_Zero( dwM, vN*hN );
  357. cmVOD_Zero( vdbV, vN );
  358. cmVOD_Zero( hdbV, hN );
  359. for(ei=0; ei<p->epochCnt; ++ei)
  360. {
  361. unsigned dN = 0;
  362. double err = 0;
  363. for(di=0; di<dMN; di+=dN)
  364. {
  365. dN = cmMin(p->batchCnt,dMN-di);
  366. const double* d = dM + di * vN; // d[ vN, dN ]
  367. //
  368. // Update hidden layer from data
  369. //
  370. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  371. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  372. // calc hs0M[dN,hN]
  373. for(k=0; k<dN; ++k)
  374. for(j=0; j<hN; ++j)
  375. {
  376. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  377. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  378. }
  379. //
  380. // Reconstruct visible layer from hidden
  381. //
  382. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  383. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  384. // calc vs1M[vN,dN]
  385. for(k=0; k<dN; ++k)
  386. for(i=0; i<vN; ++i)
  387. {
  388. vp1M[ i*dN + k ] = r->vd[i] * vp1M[ i*dN + k ] + r->vb[i];
  389. cmVOD_GaussPDF(vs1M + k*vN + i, 1, vp1M + i*dN + k, r->vb[i], r->vd[i] );
  390. // calc training error
  391. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  392. }
  393. //
  394. // Update hidden layer from reconstruction
  395. //
  396. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  397. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  398. // calc hp1M[hN,dN]
  399. for(k=0; k<dN; ++k)
  400. for(j=0; j<hN; ++j)
  401. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  402. if(0)
  403. {
  404. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  405. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  406. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  407. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  408. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  409. }
  410. //
  411. // Update Wieghts
  412. //
  413. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  414. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  415. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  416. for(i=0,k=0; i<vN; ++i)
  417. for(j=0; j<hN; ++j,++k)
  418. {
  419. dwM[k] = p->momentum * dwM[k] + p->eta * ( (vh0M[k] - vh1M[k]) / (dN * r->vd[i]) );
  420. r->W[k] += dwM[k];
  421. }
  422. //
  423. // Update hidden bias
  424. //
  425. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  426. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  427. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  428. for(j=0; j<hN; ++j)
  429. {
  430. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / (dN * r->vd[i] * r->vd[i]));
  431. r->hb[j] += hdbV[j];
  432. }
  433. //
  434. // Update visible bias
  435. //
  436. // sum(d - vs1M, 2)
  437. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  438. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  439. for(i=0; i<vN; ++i)
  440. {
  441. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  442. r->vb[i] += vdbV[i];
  443. }
  444. for(i=0; i<vN; ++i)
  445. {
  446. for(j=0; j<hN; ++j)
  447. {
  448. double sum_d = 0;
  449. double sum_m = 0;
  450. for(k=0; k<dN; ++k)
  451. {
  452. sum_d += hs0M[ j*dN + k ] * r->W[ i*hN + j ];
  453. sum_m += hp1M[ k*hN + j ] * r->W[ i*hN + j ];
  454. }
  455. }
  456. }
  457. if(0)
  458. {
  459. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  460. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  461. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  462. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  463. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  464. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  465. }
  466. } // di
  467. cmRptPrintf(rpt,"err:%f\n",err);
  468. if( cmStackIsValid(r->monH))
  469. {
  470. cmRbmMonitor_t monErr;
  471. monErr.trainErr = err;
  472. cmStackPush(r->monH,&monErr,1);
  473. }
  474. } // ei
  475. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  476. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  477. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  478. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  479. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  480. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  481. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  482. cmMemFree(m);
  483. }
  484. void cmRbmBinaryTest( cmCtx_t* ctx )
  485. {
  486. const char* monitorFn = "/home/kevin/temp/cmRbmMonitor0.mtx";
  487. const char* dataFn = "/home/kevin/temp/cmRbmData0.mtx";
  488. unsigned pointsN = 1000;
  489. unsigned dimN = 4;
  490. unsigned vN = dimN;
  491. unsigned hN = 32;
  492. //double probV[] = {0.1,0.2,0.8,0.7};
  493. cmRbmTrainParms_t r;
  494. cmRBM_t* rbm;
  495. r.maxX = 1.0;
  496. r.minX = 0.0;
  497. r.initW = 0.1;
  498. r.eta = 0.01;
  499. r.holdOutFrac = 0.1;
  500. r.epochCnt = 10;
  501. r.momentum = 0.5;
  502. r.batchCnt = 10;
  503. if(0)
  504. {
  505. vN = 4;
  506. hN = 6;
  507. double d[] = {
  508. 0, 1, 1, 1,
  509. 1, 1, 1, 1,
  510. 0, 0, 1, 0,
  511. 0, 0, 1, 0,
  512. 0, 1, 1, 0,
  513. 0, 1, 1, 1,
  514. 1, 1, 1, 1,
  515. 0, 0, 1, 0,
  516. 0, 0, 1, 0,
  517. 0, 1, 1, 0
  518. };
  519. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  520. return;
  521. pointsN = sizeof(d) / (sizeof(d[0]) * vN);
  522. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,d);
  523. return;
  524. }
  525. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  526. return;
  527. //double* data0M = cmRbmGenBinaryTestData(ctx,dataFn,probV,dimN,pointsN);
  528. double* data0M = cmRbmReadDataFile(ctx,dataFn,&dimN,&pointsN);
  529. double t[ vN ];
  530. // Sum the columns of sp[srn,scn] into dp[scn].
  531. // dp[] is zeroed prior to computing the sum.
  532. cmVOD_SumMN(data0M, dimN, pointsN, t );
  533. cmVOD_Print( &ctx->rpt, 1, dimN, t );
  534. if(0)
  535. {
  536. //
  537. // Standardize data (subtract mean and divide by standard deviation)
  538. // then set the visible layers initial standard deviation to 1.0.
  539. //
  540. cmVOD_StandardizeRows( data0M, rbm->vN, pointsN, NULL, NULL );
  541. cmVOD_Fill( rbm->vd, rbm->vN, 1.0 );
  542. cmRbmRealTrain(ctx,rbm,&r,pointsN,data0M);
  543. }
  544. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,data0M);
  545. cmRbmWriteMonitorFile(ctx, rbm->monH, monitorFn );
  546. cmMemFree(data0M);
  547. _cmRbmRelease(rbm);
  548. }