libcm is a C development framework with an emphasis on audio signal processing applications.
You can select up to 25 topics. Topics must start with a letter or number, may contain hyphens (-), and can be at most 35 characters long.

cmRbm.c 18KB

(Line-number gutter for lines 1–698 omitted; the listing below carries its own line-number prefixes.)
  1. #include "cmPrefix.h"
  2. #include "cmGlobal.h"
  3. #include "cmFloatTypes.h"
  4. #include "cmComplexTypes.h"
  5. #include "cmRpt.h"
  6. #include "cmErr.h"
  7. #include "cmCtx.h"
  8. #include "cmMem.h"
  9. #include "cmMallocDebug.h"
  10. #include "cmLinkedHeap.h"
  11. #include "cmMath.h"
  12. #include "cmFile.h"
  13. #include "cmSymTbl.h"
  14. #include "cmMidi.h"
  15. #include "cmAudioFile.h"
  16. #include "cmVectOpsTemplateMain.h"
  17. #include "cmStack.h"
  18. #include "cmProcObj.h"
  19. #include "cmProcTemplateMain.h"
  20. #include "cmVectOps.h"
  21. #include "cmProc.h"
  22. #include "cmProc2.h"
  23. #include "cmRbm.h"
  24. typedef struct
  25. {
  26. double trainErr;
  27. double testErr;
  28. } cmRbmMonitor_t;
  29. cmRbmRC_t cmRbmWriteMonitorFile( cmCtx_t* c, cmStackH_t monH, const cmChar_t* fn )
  30. {
  31. cmRbmRC_t rc = kOkRbmRC;
  32. cmCtx* ctx = cmCtxAlloc(NULL, c->err.rpt, cmLHeapNullHandle, cmSymTblNullHandle );
  33. if( cmBinMtxFileWrite(fn, cmStackCount(monH), sizeof(cmRbmMonitor_t)/sizeof(double), NULL, cmStackFlatten(monH), ctx, c->err.rpt ) != cmOkRC )
  34. {
  35. rc = cmErrMsg(&c->err,kMonitorWrFailRbmRC,"Training monitor file '%s' write failed.",cmStringNullGuard(fn));
  36. goto errLabel;
  37. }
  38. errLabel:
  39. cmCtxFree(&ctx);
  40. return rc;
  41. }
  42. double* cmRbmReadDataFile( cmCtx_t* c, const char* fn, unsigned* dimNPtr, unsigned* pointCntPtr )
  43. {
  44. unsigned rowCnt,colCnt,eleByteCnt;
  45. *dimNPtr = 0;
  46. *pointCntPtr = 0;
  47. if( cmBinMtxFileSize(c, fn, &rowCnt, &colCnt, &eleByteCnt ) != cmOkRC )
  48. return NULL;
  49. double* buf = cmMemAllocZ(double,rowCnt*colCnt);
  50. if( cmBinMtxFileRead(c, fn, rowCnt, colCnt, sizeof(double), buf,NULL) != cmOkRC )
  51. {
  52. cmMemFree(buf);
  53. return NULL;
  54. }
  55. *dimNPtr = rowCnt;
  56. *pointCntPtr = colCnt;
  57. return buf;
  58. }
  59. // Generate a matrix of 'pointsN' random binary valued column vectors of dimension dimN.
  60. // The first i = {0...'dimN'-1} elements of each vector contain ones with prob probV[i]
  61. // (or zeros with prob 1 - probV[i].). probV[i] in [0.0,1.0].
  62. // The last element in each column is set to zero.
  63. // The returned matrix m[ dimN+1, pointsN ] is in column major order and
  64. // must be deleted by the caller (e.g. cmMemFree(m)).
  65. double* cmRbmGenBinaryTestData( cmCtx_t* c, const char* fn, const double* probV, unsigned dimN, unsigned pointsN )
  66. {
  67. if( dimN == 0 || pointsN == 0 )
  68. return NULL;
  69. double* m = cmMemAllocZ( double, dimN*pointsN );
  70. unsigned i,j;
  71. for(i=0; i<pointsN; ++i)
  72. for(j=0; j<dimN; ++j)
  73. m[ i*dimN + j ] = rand() < (probV[j] * RAND_MAX);
  74. if( fn != NULL )
  75. cmBinMtxFileWrite(fn,dimN,pointsN,NULL,m,NULL,c->err.rpt);
  76. return m;
  77. }
  78. typedef struct
  79. {
  80. unsigned vN;
  81. double* vs;
  82. double* vp;
  83. double* vb;
  84. double* vd; // std dev. (var = std_dev^2)
  85. unsigned hN;
  86. double* hs;
  87. double* hp;
  88. double* hb;
  89. double* W; // W[vN,hN]
  90. cmStackH_t monH;
  91. } cmRBM_t;
  92. void _cmRbmPrint( cmRBM_t* r, cmRpt_t* rpt )
  93. {
  94. cmVOD_PrintL("hb", rpt, 1, r->hN, r->hb );
  95. cmVOD_PrintL("hp", rpt, 1, r->hN, r->hp );
  96. cmVOD_PrintL("hs", rpt, 1, r->hN, r->hs );
  97. cmVOD_PrintL("vb", rpt, 1, r->vN, r->vb );
  98. cmVOD_PrintL("vp", rpt, 1, r->vN, r->vp );
  99. cmVOD_PrintL("vs", rpt, 1, r->vN, r->vs );
  100. cmVOD_PrintL("W", rpt, r->vN, r->hN, r->W );
  101. }
  102. void _cmRbmRelease( cmRBM_t* r )
  103. {
  104. cmStackFree(&r->monH);
  105. cmMemFree(r);
  106. }
  107. // Adjust the layer geometry to force all sizes to be a multiple of 16 bytes.
  108. // This assumes that all data will be 8 byte doubles.
  109. void _cmRbmAdjustSizes( unsigned* vNp, unsigned* hNp, unsigned* dNp )
  110. {
  111. *vNp = *vNp + (cmIsOddU(*vNp) ? 1 : 0);
  112. *hNp = *hNp + (cmIsOddU(*hNp) ? 1 : 0);
  113. if( dNp != NULL )
  114. *dNp = *dNp + (cmIsOddU(*dNp) ? 1 : 0);
  115. }
  116. cmRBM_t* _cmRbmAlloc( cmCtx_t* ctx, unsigned vN, unsigned hN )
  117. {
  118. unsigned monInitCnt = 1000;
  119. unsigned monExpandCnt = 1000;
  120. // adjust sizes to force base array addresses to be a multiple of 16 bytes.
  121. unsigned vn = vN;
  122. unsigned hn = hN;
  123. _cmRbmAdjustSizes(&vn,&hn,NULL);
  124. unsigned rn = sizeof(cmRBM_t);
  125. // force record to be a multiple of 16
  126. if( rn % 16 )
  127. rn += 16 - (rn % 16);
  128. unsigned dn = 4*vn + 3*hn + vn*hn;
  129. unsigned bn = rn + dn*sizeof(double);
  130. char* cp = cmMemAllocZ(char,bn);
  131. cmRBM_t* r = (cmRBM_t*)cp;
  132. r->vs = (double*)(cp+rn);
  133. r->vp = r->vs + vn;
  134. r->vb = r->vp + vn;
  135. r->vd = r->vb + vn;
  136. r->hs = r->vd + vn;
  137. r->hp = r->hs + hn;
  138. r->hb = r->hp + hn;
  139. r->W = r->hb + hn;
  140. r->vN = vN;
  141. r->hN = hN;
  142. assert(cp+bn == (char*)(r->W + vn*hn));
  143. if( cmStackAlloc(ctx, &r->monH, monInitCnt, monExpandCnt, sizeof(cmRbmMonitor_t)) != kOkStRC )
  144. {
  145. cmErrMsg(&ctx->err,kStackFailRbmRC,"Stack allocation failed for the training monitor data array.");
  146. goto errLabel;
  147. }
  148. return r;
  149. errLabel:
  150. _cmRbmRelease(r);
  151. return NULL;
  152. }
  153. void cmRbmBinaryTrain(
  154. cmCtx_t* ctx,
  155. cmRBM_t* r,
  156. cmRbmTrainParms_t* p,
  157. unsigned dMN,
  158. const double* dM )
  159. {
  160. cmRpt_t* rpt = ctx->err.rpt;
  161. bool stochFl = true;
  162. unsigned i,j,k,ei,di;
  163. unsigned vN = r->vN;
  164. unsigned hN = r->hN;
  165. // adjust the memory sizes to align all arrays on 16 byte boundaries
  166. unsigned vn = vN;
  167. unsigned hn = hN;
  168. unsigned dn = p->batchCnt;
  169. _cmRbmAdjustSizes(&vn,&hn,&dn);
  170. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  171. double* m = cmMemAllocZ(double,mn);
  172. double* vh0M = m; // vh0M[ hN, vN ]
  173. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  174. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  175. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  176. double* hdbV = vdbV + vn; // hdbV[ hN ]
  177. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  178. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  179. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  180. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  181. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  182. assert( vs1M + vn * dn == m + mn );
  183. // initilaize the weights with random values
  184. // W = p->initW * randn(vN,hN,0.0,1.0)
  185. for(i=0; i<vN; ++i)
  186. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  187. cmVOD_MultVS( r->W, hN*vN, p->initW);
  188. if(0)
  189. {
  190. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  191. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  192. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  193. }
  194. cmVOD_Zero( dwM, vN*hN );
  195. cmVOD_Zero( vdbV, vN );
  196. cmVOD_Zero( hdbV, hN );
  197. for(ei=0; ei<p->epochCnt; ++ei)
  198. {
  199. unsigned dN = 0;
  200. double err = 0;
  201. for(di=0; di<dMN; di+=dN)
  202. {
  203. dN = cmMin(p->batchCnt,dMN-di);
  204. const double* d = dM + di * vN; // d[ vN, dN ]
  205. //
  206. // Update hidden layer from data
  207. //
  208. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  209. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  210. // calc hs0M[dN,hN]
  211. for(k=0; k<dN; ++k)
  212. for(j=0; j<hN; ++j)
  213. {
  214. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  215. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  216. if( !stochFl )
  217. hs0M[ j*dN + k ] = hp0M[ k*hN + j ] > 0.5;
  218. }
  219. //
  220. // Reconstruct visible layer from hidden
  221. //
  222. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  223. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  224. // calc vs1M[vN,dN]
  225. for(k=0; k<dN; ++k)
  226. for(i=0; i<vN; ++i)
  227. {
  228. vp1M[ i*dN + k ] = 1.0/(1.0 + exp(-( vp1M[ i*dN + k ] + r->vb[i]) ) );
  229. vs1M[ k*vN + i ] = rand() < vp1M[ i*dN + k ] * RAND_MAX;
  230. if( !stochFl )
  231. vs1M[ k*vN + i ] = vp1M[ i*dN + k ] > 0.5;
  232. // calc training error
  233. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  234. }
  235. //
  236. // Update hidden layer from reconstruction
  237. //
  238. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  239. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  240. // calc hp1M[hN,dN]
  241. for(k=0; k<dN; ++k)
  242. for(j=0; j<hN; ++j)
  243. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  244. if(0)
  245. {
  246. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  247. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  248. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  249. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  250. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  251. }
  252. //
  253. // Update Wieghts
  254. //
  255. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  256. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  257. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  258. for(i=0; i<hN*vN; ++i)
  259. {
  260. dwM[i] = p->momentum * dwM[i] + p->eta * ( (vh0M[i] - vh1M[i]) / dN );
  261. r->W[i] += dwM[i];
  262. }
  263. //
  264. // Update hidden bias
  265. //
  266. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  267. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  268. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  269. for(j=0; j<hN; ++j)
  270. {
  271. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / dN);
  272. r->hb[j] += hdbV[j];
  273. }
  274. //
  275. // Update visible bias
  276. //
  277. // sum(d - vs1M, 2)
  278. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  279. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  280. for(i=0; i<vN; ++i)
  281. {
  282. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  283. r->vb[i] += vdbV[i];
  284. }
  285. if(0)
  286. {
  287. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  288. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  289. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  290. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  291. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  292. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  293. }
  294. } // di
  295. cmRptPrintf(rpt,"err:%f\n",err);
  296. if( cmStackIsValid(r->monH))
  297. {
  298. cmRbmMonitor_t monErr;
  299. monErr.trainErr = err;
  300. cmStackPush(r->monH,&monErr,1);
  301. }
  302. } // ei
  303. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  304. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  305. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  306. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  307. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  308. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  309. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  310. cmMemFree(m);
  311. }
  312. void cmRbmRealTrain(
  313. cmCtx_t* ctx,
  314. cmRBM_t* r,
  315. cmRbmTrainParms_t* p,
  316. unsigned dMN,
  317. const double* dM )
  318. {
  319. cmRpt_t* rpt = ctx->err.rpt;
  320. unsigned i,j,k,ei,di;
  321. unsigned vN = r->vN;
  322. unsigned hN = r->hN;
  323. // adjust the memory sizes to align all arrays on 16 byte boundaries
  324. unsigned vn = vN;
  325. unsigned hn = hN;
  326. unsigned dn = p->batchCnt;
  327. _cmRbmAdjustSizes(&vn,&hn,&dn);
  328. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  329. double* m = cmMemAllocZ(double,mn);
  330. double* vh0M = m; // vh0M[ hN, vN ]
  331. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  332. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  333. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  334. double* hdbV = vdbV + vn; // hdbV[ hN ]
  335. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  336. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  337. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  338. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  339. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  340. assert( vs1M + vn * dn == m + mn );
  341. //
  342. // Initilaize the weights with small random values
  343. // W = p->initW * randn(vN,hN,0.0,1.0)
  344. for(i=0; i<vN; ++i)
  345. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  346. cmVOD_MultVS( r->W, hN*vN, p->initW);
  347. if(0)
  348. {
  349. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  350. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  351. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  352. }
  353. cmVOD_Zero( dwM, vN*hN );
  354. cmVOD_Zero( vdbV, vN );
  355. cmVOD_Zero( hdbV, hN );
  356. for(ei=0; ei<p->epochCnt; ++ei)
  357. {
  358. unsigned dN = 0;
  359. double err = 0;
  360. for(di=0; di<dMN; di+=dN)
  361. {
  362. dN = cmMin(p->batchCnt,dMN-di);
  363. const double* d = dM + di * vN; // d[ vN, dN ]
  364. //
  365. // Update hidden layer from data
  366. //
  367. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  368. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  369. // calc hs0M[dN,hN]
  370. for(k=0; k<dN; ++k)
  371. for(j=0; j<hN; ++j)
  372. {
  373. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  374. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  375. }
  376. //
  377. // Reconstruct visible layer from hidden
  378. //
  379. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  380. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  381. // calc vs1M[vN,dN]
  382. for(k=0; k<dN; ++k)
  383. for(i=0; i<vN; ++i)
  384. {
  385. vp1M[ i*dN + k ] = r->vd[i] * vp1M[ i*dN + k ] + r->vb[i];
  386. cmVOD_GaussPDF(vs1M + k*vN + i, 1, vp1M + i*dN + k, r->vb[i], r->vd[i] );
  387. // calc training error
  388. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  389. }
  390. //
  391. // Update hidden layer from reconstruction
  392. //
  393. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  394. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  395. // calc hp1M[hN,dN]
  396. for(k=0; k<dN; ++k)
  397. for(j=0; j<hN; ++j)
  398. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  399. if(0)
  400. {
  401. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  402. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  403. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  404. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  405. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  406. }
  407. //
  408. // Update Wieghts
  409. //
  410. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  411. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  412. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  413. for(i=0,k=0; i<vN; ++i)
  414. for(j=0; j<hN; ++j,++k)
  415. {
  416. dwM[k] = p->momentum * dwM[k] + p->eta * ( (vh0M[k] - vh1M[k]) / (dN * r->vd[i]) );
  417. r->W[k] += dwM[k];
  418. }
  419. //
  420. // Update hidden bias
  421. //
  422. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  423. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  424. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  425. for(j=0; j<hN; ++j)
  426. {
  427. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / (dN * r->vd[i] * r->vd[i]));
  428. r->hb[j] += hdbV[j];
  429. }
  430. //
  431. // Update visible bias
  432. //
  433. // sum(d - vs1M, 2)
  434. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  435. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  436. for(i=0; i<vN; ++i)
  437. {
  438. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  439. r->vb[i] += vdbV[i];
  440. }
  441. for(i=0; i<vN; ++i)
  442. {
  443. for(j=0; j<hN; ++j)
  444. {
  445. double sum_d = 0;
  446. double sum_m = 0;
  447. for(k=0; k<dN; ++k)
  448. {
  449. sum_d += hs0M[ j*dN + k ] * r->W[ i*hN + j ];
  450. sum_m += hp1M[ k*hN + j ] * r->W[ i*hN + j ];
  451. }
  452. }
  453. }
  454. if(0)
  455. {
  456. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  457. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  458. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  459. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  460. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  461. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  462. }
  463. } // di
  464. cmRptPrintf(rpt,"err:%f\n",err);
  465. if( cmStackIsValid(r->monH))
  466. {
  467. cmRbmMonitor_t monErr;
  468. monErr.trainErr = err;
  469. cmStackPush(r->monH,&monErr,1);
  470. }
  471. } // ei
  472. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  473. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  474. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  475. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  476. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  477. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  478. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  479. cmMemFree(m);
  480. }
  481. void cmRbmBinaryTest( cmCtx_t* ctx )
  482. {
  483. const char* monitorFn = "/home/kevin/temp/cmRbmMonitor0.mtx";
  484. const char* dataFn = "/home/kevin/temp/cmRbmData0.mtx";
  485. unsigned pointsN = 1000;
  486. unsigned dimN = 4;
  487. unsigned vN = dimN;
  488. unsigned hN = 32;
  489. //double probV[] = {0.1,0.2,0.8,0.7};
  490. cmRbmTrainParms_t r;
  491. cmRBM_t* rbm;
  492. r.maxX = 1.0;
  493. r.minX = 0.0;
  494. r.initW = 0.1;
  495. r.eta = 0.01;
  496. r.holdOutFrac = 0.1;
  497. r.epochCnt = 10;
  498. r.momentum = 0.5;
  499. r.batchCnt = 10;
  500. if(0)
  501. {
  502. vN = 4;
  503. hN = 6;
  504. double d[] = {
  505. 0, 1, 1, 1,
  506. 1, 1, 1, 1,
  507. 0, 0, 1, 0,
  508. 0, 0, 1, 0,
  509. 0, 1, 1, 0,
  510. 0, 1, 1, 1,
  511. 1, 1, 1, 1,
  512. 0, 0, 1, 0,
  513. 0, 0, 1, 0,
  514. 0, 1, 1, 0
  515. };
  516. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  517. return;
  518. pointsN = sizeof(d) / (sizeof(d[0]) * vN);
  519. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,d);
  520. return;
  521. }
  522. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  523. return;
  524. //double* data0M = cmRbmGenBinaryTestData(ctx,dataFn,probV,dimN,pointsN);
  525. double* data0M = cmRbmReadDataFile(ctx,dataFn,&dimN,&pointsN);
  526. double t[ vN ];
  527. // Sum the columns of sp[srn,scn] into dp[scn].
  528. // dp[] is zeroed prior to computing the sum.
  529. cmVOD_SumMN(data0M, dimN, pointsN, t );
  530. cmVOD_Print( &ctx->rpt, 1, dimN, t );
  531. if(0)
  532. {
  533. //
  534. // Standardize data (subtract mean and divide by standard deviation)
  535. // then set the visible layers initial standard deviation to 1.0.
  536. //
  537. cmVOD_StandardizeRows( data0M, rbm->vN, pointsN, NULL, NULL );
  538. cmVOD_Fill( rbm->vd, rbm->vN, 1.0 );
  539. cmRbmRealTrain(ctx,rbm,&r,pointsN,data0M);
  540. }
  541. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,data0M);
  542. cmRbmWriteMonitorFile(ctx, rbm->monH, monitorFn );
  543. cmMemFree(data0M);
  544. _cmRbmRelease(rbm);
  545. }