libcm is a C development framework with an emphasis on audio signal processing applications.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cmRbm.c 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. #include "cmPrefix.h"
  2. #include "cmGlobal.h"
  3. #include "cmFloatTypes.h"
  4. #include "cmComplexTypes.h"
  5. #include "cmRpt.h"
  6. #include "cmErr.h"
  7. #include "cmCtx.h"
  8. #include "cmMem.h"
  9. #include "cmMallocDebug.h"
  10. #include "cmLinkedHeap.h"
  11. #include "cmMath.h"
  12. #include "cmFile.h"
  13. #include "cmSymTbl.h"
  14. #include "cmTime.h"
  15. #include "cmMidi.h"
  16. #include "cmAudioFile.h"
  17. #include "cmVectOpsTemplateMain.h"
  18. #include "cmStack.h"
  19. #include "cmProcObj.h"
  20. #include "cmProcTemplateMain.h"
  21. #include "cmVectOps.h"
  22. #include "cmProc.h"
  23. #include "cmProc2.h"
  24. #include "cmRbm.h"
  // One training-progress record. The training routines below push one of
  // these onto the RBM's monitor stack at the end of every epoch.
  typedef struct
  {
    double trainErr;  // squared reconstruction error summed over the epoch's data
    double testErr;   // held-out error; not computed by the training routines in this file
  } cmRbmMonitor_t;
  30. cmRbmRC_t cmRbmWriteMonitorFile( cmCtx_t* c, cmStackH_t monH, const cmChar_t* fn )
  31. {
  32. cmRbmRC_t rc = kOkRbmRC;
  33. cmCtx* ctx = cmCtxAlloc(NULL, c->err.rpt, cmLHeapNullHandle, cmSymTblNullHandle );
  34. if( cmBinMtxFileWrite(fn, cmStackCount(monH), sizeof(cmRbmMonitor_t)/sizeof(double), NULL, cmStackFlatten(monH), ctx, c->err.rpt ) != cmOkRC )
  35. {
  36. rc = cmErrMsg(&c->err,kMonitorWrFailRbmRC,"Training monitor file '%s' write failed.",cmStringNullGuard(fn));
  37. goto errLabel;
  38. }
  39. errLabel:
  40. cmCtxFree(&ctx);
  41. return rc;
  42. }
  43. double* cmRbmReadDataFile( cmCtx_t* c, const char* fn, unsigned* dimNPtr, unsigned* pointCntPtr )
  44. {
  45. unsigned rowCnt,colCnt,eleByteCnt;
  46. *dimNPtr = 0;
  47. *pointCntPtr = 0;
  48. if( cmBinMtxFileSize(c, fn, &rowCnt, &colCnt, &eleByteCnt ) != cmOkRC )
  49. return NULL;
  50. double* buf = cmMemAllocZ(double,rowCnt*colCnt);
  51. if( cmBinMtxFileRead(c, fn, rowCnt, colCnt, sizeof(double), buf,NULL) != cmOkRC )
  52. {
  53. cmMemFree(buf);
  54. return NULL;
  55. }
  56. *dimNPtr = rowCnt;
  57. *pointCntPtr = colCnt;
  58. return buf;
  59. }
  60. // Generate a matrix of 'pointsN' random binary valued column vectors of dimension dimN.
  61. // The first i = {0...'dimN'-1} elements of each vector contain ones with prob probV[i]
  62. // (or zeros with prob 1 - probV[i].). probV[i] in [0.0,1.0].
  63. // The last element in each column is set to zero.
  64. // The returned matrix m[ dimN+1, pointsN ] is in column major order and
  65. // must be deleted by the caller (e.g. cmMemFree(m)).
  66. double* cmRbmGenBinaryTestData( cmCtx_t* c, const char* fn, const double* probV, unsigned dimN, unsigned pointsN )
  67. {
  68. if( dimN == 0 || pointsN == 0 )
  69. return NULL;
  70. double* m = cmMemAllocZ( double, dimN*pointsN );
  71. unsigned i,j;
  72. for(i=0; i<pointsN; ++i)
  73. for(j=0; j<dimN; ++j)
  74. m[ i*dimN + j ] = rand() < (probV[j] * RAND_MAX);
  75. if( fn != NULL )
  76. cmBinMtxFileWrite(fn,dimN,pointsN,NULL,m,NULL,c->err.rpt);
  77. return m;
  78. }
  79. typedef struct
  80. {
  81. unsigned vN;
  82. double* vs;
  83. double* vp;
  84. double* vb;
  85. double* vd; // std dev. (var = std_dev^2)
  86. unsigned hN;
  87. double* hs;
  88. double* hp;
  89. double* hb;
  90. double* W; // W[vN,hN]
  91. cmStackH_t monH;
  92. } cmRBM_t;
  93. void _cmRbmPrint( cmRBM_t* r, cmRpt_t* rpt )
  94. {
  95. cmVOD_PrintL("hb", rpt, 1, r->hN, r->hb );
  96. cmVOD_PrintL("hp", rpt, 1, r->hN, r->hp );
  97. cmVOD_PrintL("hs", rpt, 1, r->hN, r->hs );
  98. cmVOD_PrintL("vb", rpt, 1, r->vN, r->vb );
  99. cmVOD_PrintL("vp", rpt, 1, r->vN, r->vp );
  100. cmVOD_PrintL("vs", rpt, 1, r->vN, r->vs );
  101. cmVOD_PrintL("W", rpt, r->vN, r->hN, r->W );
  102. }
  103. void _cmRbmRelease( cmRBM_t* r )
  104. {
  105. cmStackFree(&r->monH);
  106. cmMemFree(r);
  107. }
  // Round each element count up to an even number (even counts are unchanged).
  // With 8 byte doubles an even count spans a multiple of 16 bytes, so arrays
  // laid out back-to-back with these counts keep the 16 byte offset alignment
  // of the first array. vNp and hNp must be valid; dNp may be NULL.
  void _cmRbmAdjustSizes( unsigned* vNp, unsigned* hNp, unsigned* dNp )
  {
    *vNp += (*vNp & 1u);
    *hNp += (*hNp & 1u);

    if( dNp == NULL )
      return;

    *dNp += (*dNp & 1u);
  }
  117. cmRBM_t* _cmRbmAlloc( cmCtx_t* ctx, unsigned vN, unsigned hN )
  118. {
  119. unsigned monInitCnt = 1000;
  120. unsigned monExpandCnt = 1000;
  121. // adjust sizes to force base array addresses to be a multiple of 16 bytes.
  122. unsigned vn = vN;
  123. unsigned hn = hN;
  124. _cmRbmAdjustSizes(&vn,&hn,NULL);
  125. unsigned rn = sizeof(cmRBM_t);
  126. // force record to be a multiple of 16
  127. if( rn % 16 )
  128. rn += 16 - (rn % 16);
  129. unsigned dn = 4*vn + 3*hn + vn*hn;
  130. unsigned bn = rn + dn*sizeof(double);
  131. char* cp = cmMemAllocZ(char,bn);
  132. cmRBM_t* r = (cmRBM_t*)cp;
  133. r->vs = (double*)(cp+rn);
  134. r->vp = r->vs + vn;
  135. r->vb = r->vp + vn;
  136. r->vd = r->vb + vn;
  137. r->hs = r->vd + vn;
  138. r->hp = r->hs + hn;
  139. r->hb = r->hp + hn;
  140. r->W = r->hb + hn;
  141. r->vN = vN;
  142. r->hN = hN;
  143. assert(cp+bn == (char*)(r->W + vn*hn));
  144. if( cmStackAlloc(ctx, &r->monH, monInitCnt, monExpandCnt, sizeof(cmRbmMonitor_t)) != kOkStRC )
  145. {
  146. cmErrMsg(&ctx->err,kStackFailRbmRC,"Stack allocation failed for the training monitor data array.");
  147. goto errLabel;
  148. }
  149. return r;
  150. errLabel:
  151. _cmRbmRelease(r);
  152. return NULL;
  153. }
  154. void cmRbmBinaryTrain(
  155. cmCtx_t* ctx,
  156. cmRBM_t* r,
  157. cmRbmTrainParms_t* p,
  158. unsigned dMN,
  159. const double* dM )
  160. {
  161. cmRpt_t* rpt = ctx->err.rpt;
  162. bool stochFl = true;
  163. unsigned i,j,k,ei,di;
  164. unsigned vN = r->vN;
  165. unsigned hN = r->hN;
  166. // adjust the memory sizes to align all arrays on 16 byte boundaries
  167. unsigned vn = vN;
  168. unsigned hn = hN;
  169. unsigned dn = p->batchCnt;
  170. _cmRbmAdjustSizes(&vn,&hn,&dn);
  171. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  172. double* m = cmMemAllocZ(double,mn);
  173. double* vh0M = m; // vh0M[ hN, vN ]
  174. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  175. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  176. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  177. double* hdbV = vdbV + vn; // hdbV[ hN ]
  178. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  179. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  180. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  181. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  182. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  183. assert( vs1M + vn * dn == m + mn );
  184. // initilaize the weights with random values
  185. // W = p->initW * randn(vN,hN,0.0,1.0)
  186. for(i=0; i<vN; ++i)
  187. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  188. cmVOD_MultVS( r->W, hN*vN, p->initW);
  189. if(0)
  190. {
  191. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  192. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  193. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  194. }
  195. cmVOD_Zero( dwM, vN*hN );
  196. cmVOD_Zero( vdbV, vN );
  197. cmVOD_Zero( hdbV, hN );
  198. for(ei=0; ei<p->epochCnt; ++ei)
  199. {
  200. unsigned dN = 0;
  201. double err = 0;
  202. for(di=0; di<dMN; di+=dN)
  203. {
  204. dN = cmMin(p->batchCnt,dMN-di);
  205. const double* d = dM + di * vN; // d[ vN, dN ]
  206. //
  207. // Update hidden layer from data
  208. //
  209. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  210. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  211. // calc hs0M[dN,hN]
  212. for(k=0; k<dN; ++k)
  213. for(j=0; j<hN; ++j)
  214. {
  215. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  216. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  217. if( !stochFl )
  218. hs0M[ j*dN + k ] = hp0M[ k*hN + j ] > 0.5;
  219. }
  220. //
  221. // Reconstruct visible layer from hidden
  222. //
  223. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  224. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  225. // calc vs1M[vN,dN]
  226. for(k=0; k<dN; ++k)
  227. for(i=0; i<vN; ++i)
  228. {
  229. vp1M[ i*dN + k ] = 1.0/(1.0 + exp(-( vp1M[ i*dN + k ] + r->vb[i]) ) );
  230. vs1M[ k*vN + i ] = rand() < vp1M[ i*dN + k ] * RAND_MAX;
  231. if( !stochFl )
  232. vs1M[ k*vN + i ] = vp1M[ i*dN + k ] > 0.5;
  233. // calc training error
  234. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  235. }
  236. //
  237. // Update hidden layer from reconstruction
  238. //
  239. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  240. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  241. // calc hp1M[hN,dN]
  242. for(k=0; k<dN; ++k)
  243. for(j=0; j<hN; ++j)
  244. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  245. if(0)
  246. {
  247. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  248. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  249. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  250. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  251. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  252. }
  253. //
  254. // Update Wieghts
  255. //
  256. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  257. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  258. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  259. for(i=0; i<hN*vN; ++i)
  260. {
  261. dwM[i] = p->momentum * dwM[i] + p->eta * ( (vh0M[i] - vh1M[i]) / dN );
  262. r->W[i] += dwM[i];
  263. }
  264. //
  265. // Update hidden bias
  266. //
  267. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  268. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  269. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  270. for(j=0; j<hN; ++j)
  271. {
  272. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / dN);
  273. r->hb[j] += hdbV[j];
  274. }
  275. //
  276. // Update visible bias
  277. //
  278. // sum(d - vs1M, 2)
  279. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  280. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  281. for(i=0; i<vN; ++i)
  282. {
  283. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  284. r->vb[i] += vdbV[i];
  285. }
  286. if(0)
  287. {
  288. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  289. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  290. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  291. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  292. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  293. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  294. }
  295. } // di
  296. cmRptPrintf(rpt,"err:%f\n",err);
  297. if( cmStackIsValid(r->monH))
  298. {
  299. cmRbmMonitor_t monErr;
  300. monErr.trainErr = err;
  301. cmStackPush(r->monH,&monErr,1);
  302. }
  303. } // ei
  304. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  305. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  306. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  307. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  308. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  309. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  310. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  311. cmMemFree(m);
  312. }
  313. void cmRbmRealTrain(
  314. cmCtx_t* ctx,
  315. cmRBM_t* r,
  316. cmRbmTrainParms_t* p,
  317. unsigned dMN,
  318. const double* dM )
  319. {
  320. cmRpt_t* rpt = ctx->err.rpt;
  321. unsigned i,j,k,ei,di;
  322. unsigned vN = r->vN;
  323. unsigned hN = r->hN;
  324. // adjust the memory sizes to align all arrays on 16 byte boundaries
  325. unsigned vn = vN;
  326. unsigned hn = hN;
  327. unsigned dn = p->batchCnt;
  328. _cmRbmAdjustSizes(&vn,&hn,&dn);
  329. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  330. double* m = cmMemAllocZ(double,mn);
  331. double* vh0M = m; // vh0M[ hN, vN ]
  332. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  333. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  334. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  335. double* hdbV = vdbV + vn; // hdbV[ hN ]
  336. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  337. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  338. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  339. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  340. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  341. assert( vs1M + vn * dn == m + mn );
  342. //
  343. // Initilaize the weights with small random values
  344. // W = p->initW * randn(vN,hN,0.0,1.0)
  345. for(i=0; i<vN; ++i)
  346. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  347. cmVOD_MultVS( r->W, hN*vN, p->initW);
  348. if(0)
  349. {
  350. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  351. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  352. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  353. }
  354. cmVOD_Zero( dwM, vN*hN );
  355. cmVOD_Zero( vdbV, vN );
  356. cmVOD_Zero( hdbV, hN );
  357. for(ei=0; ei<p->epochCnt; ++ei)
  358. {
  359. unsigned dN = 0;
  360. double err = 0;
  361. for(di=0; di<dMN; di+=dN)
  362. {
  363. dN = cmMin(p->batchCnt,dMN-di);
  364. const double* d = dM + di * vN; // d[ vN, dN ]
  365. //
  366. // Update hidden layer from data
  367. //
  368. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  369. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  370. // calc hs0M[dN,hN]
  371. for(k=0; k<dN; ++k)
  372. for(j=0; j<hN; ++j)
  373. {
  374. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  375. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  376. }
  377. //
  378. // Reconstruct visible layer from hidden
  379. //
  380. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  381. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  382. // calc vs1M[vN,dN]
  383. for(k=0; k<dN; ++k)
  384. for(i=0; i<vN; ++i)
  385. {
  386. vp1M[ i*dN + k ] = r->vd[i] * vp1M[ i*dN + k ] + r->vb[i];
  387. cmVOD_GaussPDF(vs1M + k*vN + i, 1, vp1M + i*dN + k, r->vb[i], r->vd[i] );
  388. // calc training error
  389. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  390. }
  391. //
  392. // Update hidden layer from reconstruction
  393. //
  394. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  395. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  396. // calc hp1M[hN,dN]
  397. for(k=0; k<dN; ++k)
  398. for(j=0; j<hN; ++j)
  399. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  400. if(0)
  401. {
  402. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  403. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  404. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  405. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  406. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  407. }
  408. //
  409. // Update Wieghts
  410. //
  411. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  412. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  413. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  414. for(i=0,k=0; i<vN; ++i)
  415. for(j=0; j<hN; ++j,++k)
  416. {
  417. dwM[k] = p->momentum * dwM[k] + p->eta * ( (vh0M[k] - vh1M[k]) / (dN * r->vd[i]) );
  418. r->W[k] += dwM[k];
  419. }
  420. //
  421. // Update hidden bias
  422. //
  423. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  424. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  425. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  426. for(j=0; j<hN; ++j)
  427. {
  428. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / (dN * r->vd[i] * r->vd[i]));
  429. r->hb[j] += hdbV[j];
  430. }
  431. //
  432. // Update visible bias
  433. //
  434. // sum(d - vs1M, 2)
  435. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  436. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  437. for(i=0; i<vN; ++i)
  438. {
  439. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  440. r->vb[i] += vdbV[i];
  441. }
  442. for(i=0; i<vN; ++i)
  443. {
  444. for(j=0; j<hN; ++j)
  445. {
  446. double sum_d = 0;
  447. double sum_m = 0;
  448. for(k=0; k<dN; ++k)
  449. {
  450. sum_d += hs0M[ j*dN + k ] * r->W[ i*hN + j ];
  451. sum_m += hp1M[ k*hN + j ] * r->W[ i*hN + j ];
  452. }
  453. }
  454. }
  455. if(0)
  456. {
  457. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  458. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  459. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  460. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  461. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  462. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  463. }
  464. } // di
  465. cmRptPrintf(rpt,"err:%f\n",err);
  466. if( cmStackIsValid(r->monH))
  467. {
  468. cmRbmMonitor_t monErr;
  469. monErr.trainErr = err;
  470. cmStackPush(r->monH,&monErr,1);
  471. }
  472. } // ei
  473. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  474. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  475. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  476. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  477. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  478. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  479. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  480. cmMemFree(m);
  481. }
  482. void cmRbmBinaryTest( cmCtx_t* ctx )
  483. {
  484. const char* monitorFn = "/home/kevin/temp/cmRbmMonitor0.mtx";
  485. const char* dataFn = "/home/kevin/temp/cmRbmData0.mtx";
  486. unsigned pointsN = 1000;
  487. unsigned dimN = 4;
  488. unsigned vN = dimN;
  489. unsigned hN = 32;
  490. //double probV[] = {0.1,0.2,0.8,0.7};
  491. cmRbmTrainParms_t r;
  492. cmRBM_t* rbm;
  493. r.maxX = 1.0;
  494. r.minX = 0.0;
  495. r.initW = 0.1;
  496. r.eta = 0.01;
  497. r.holdOutFrac = 0.1;
  498. r.epochCnt = 10;
  499. r.momentum = 0.5;
  500. r.batchCnt = 10;
  501. if(0)
  502. {
  503. vN = 4;
  504. hN = 6;
  505. double d[] = {
  506. 0, 1, 1, 1,
  507. 1, 1, 1, 1,
  508. 0, 0, 1, 0,
  509. 0, 0, 1, 0,
  510. 0, 1, 1, 0,
  511. 0, 1, 1, 1,
  512. 1, 1, 1, 1,
  513. 0, 0, 1, 0,
  514. 0, 0, 1, 0,
  515. 0, 1, 1, 0
  516. };
  517. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  518. return;
  519. pointsN = sizeof(d) / (sizeof(d[0]) * vN);
  520. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,d);
  521. return;
  522. }
  523. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  524. return;
  525. //double* data0M = cmRbmGenBinaryTestData(ctx,dataFn,probV,dimN,pointsN);
  526. double* data0M = cmRbmReadDataFile(ctx,dataFn,&dimN,&pointsN);
  527. double t[ vN ];
  528. // Sum the columns of sp[srn,scn] into dp[scn].
  529. // dp[] is zeroed prior to computing the sum.
  530. cmVOD_SumMN(data0M, dimN, pointsN, t );
  531. cmVOD_Print( &ctx->rpt, 1, dimN, t );
  532. if(0)
  533. {
  534. //
  535. // Standardize data (subtract mean and divide by standard deviation)
  536. // then set the visible layers initial standard deviation to 1.0.
  537. //
  538. cmVOD_StandardizeRows( data0M, rbm->vN, pointsN, NULL, NULL );
  539. cmVOD_Fill( rbm->vd, rbm->vN, 1.0 );
  540. cmRbmRealTrain(ctx,rbm,&r,pointsN,data0M);
  541. }
  542. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,data0M);
  543. cmRbmWriteMonitorFile(ctx, rbm->monH, monitorFn );
  544. cmMemFree(data0M);
  545. _cmRbmRelease(rbm);
  546. }