libcm is a C development framework with an emphasis on audio signal processing applications.
You cannot select more than 25 topics. Topics must start with a letter or a number, may include dashes ('-'), and may be up to 35 characters long.

cmRbm.c 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. #include "cmPrefix.h"
  2. #include "cmGlobal.h"
  3. #include "cmFloatTypes.h"
  4. #include "cmComplexTypes.h"
  5. #include "cmRpt.h"
  6. #include "cmErr.h"
  7. #include "cmCtx.h"
  8. #include "cmMem.h"
  9. #include "cmMallocDebug.h"
  10. #include "cmLinkedHeap.h"
  11. #include "cmMath.h"
  12. #include "cmFile.h"
  13. #include "cmSymTbl.h"
  14. #include "cmMidi.h"
  15. #include "cmAudioFile.h"
  16. #include "cmVectOpsTemplateMain.h"
  17. #include "cmStack.h"
  18. #include "cmProcObj.h"
  19. #include "cmProcTemplateMain.h"
  20. #include "cmVectOps.h"
  21. #include "cmProc.h"
  22. #include "cmProc2.h"
  23. #include "cmRbm.h"
  // One record of training-monitor data.
  // Records of this type are pushed onto the RBM monitor stack during training
  // and written to file as rows of sizeof(cmRbmMonitor_t)/sizeof(double) doubles,
  // so the record must consist only of double fields.
  typedef struct
  {
    double trainErr;  // training error accumulated over one epoch
    double testErr;   // test error (not assigned by the training code in this file)
  } cmRbmMonitor_t;
  29. cmRbmRC_t cmRbmWriteMonitorFile( cmCtx_t* c, cmStackH_t monH, const cmChar_t* fn )
  30. {
  31. cmRbmRC_t rc = kOkRbmRC;
  32. cmCtx* ctx = cmCtxAlloc(NULL, c->err.rpt, cmLHeapNullHandle, cmSymTblNullHandle );
  33. if( cmBinMtxFileWrite(fn, cmStackCount(monH), sizeof(cmRbmMonitor_t)/sizeof(double), NULL, cmStackFlatten(monH), ctx, c->err.rpt ) != cmOkRC )
  34. {
  35. rc = cmErrMsg(&c->err,kMonitorWrFailRbmRC,"Training monitor file '%s' write failed.",cmStringNullGuard(fn));
  36. goto errLabel;
  37. }
  38. errLabel:
  39. cmCtxFree(&ctx);
  40. return rc;
  41. }
  42. double* cmRbmReadDataFile( cmCtx_t* c, const char* fn, unsigned* dimNPtr, unsigned* pointCntPtr )
  43. {
  44. unsigned rowCnt,colCnt,eleByteCnt;
  45. *dimNPtr = 0;
  46. *pointCntPtr = 0;
  47. if( cmBinMtxFileSize(c, fn, &rowCnt, &colCnt, &eleByteCnt ) != cmOkRC )
  48. return NULL;
  49. double* buf = cmMemAllocZ(double,rowCnt*colCnt);
  50. if( cmBinMtxFileRead(c, fn, rowCnt, colCnt, sizeof(double), buf,NULL) != cmOkRC )
  51. {
  52. cmMemFree(buf);
  53. return NULL;
  54. }
  55. *dimNPtr = rowCnt;
  56. *pointCntPtr = colCnt;
  57. return buf;
  58. }
  59. // Generate a matrix of 'pointsN' random binary valued column vectors of dimension dimN.
  60. // The first i = {0...'dimN'-1} elements of each vector contain ones with prob probV[i]
  61. // (or zeros with prob 1 - probV[i].). probV[i] in [0.0,1.0].
  62. // The last element in each column is set to zero.
  63. // The returned matrix m[ dimN+1, pointsN ] is in column major order and
  64. // must be deleted by the caller (e.g. cmMemFree(m)).
  65. double* cmRbmGenBinaryTestData( cmCtx_t* c, const char* fn, const double* probV, unsigned dimN, unsigned pointsN )
  66. {
  67. if( dimN == 0 || pointsN == 0 )
  68. return NULL;
  69. double* m = cmMemAllocZ( double, dimN*pointsN );
  70. unsigned i,j;
  71. for(i=0; i<pointsN; ++i)
  72. for(j=0; j<dimN; ++j)
  73. m[ i*dimN + j ] = rand() < (probV[j] * RAND_MAX);
  74. if( fn != NULL )
  75. cmBinMtxFileWrite(fn,dimN,pointsN,NULL,m,NULL,c->err.rpt);
  76. return m;
  77. }
  78. typedef struct
  79. {
  80. unsigned vN;
  81. double* vs;
  82. double* vp;
  83. double* vb;
  84. double* vd; // std dev. (var = std_dev^2)
  85. unsigned hN;
  86. double* hs;
  87. double* hp;
  88. double* hb;
  89. double* W; // W[vN,hN]
  90. cmStackH_t monH;
  91. } cmRBM_t;
  92. void _cmRbmPrint( cmRBM_t* r, cmRpt_t* rpt )
  93. {
  94. cmVOD_PrintL("hb", rpt, 1, r->hN, r->hb );
  95. cmVOD_PrintL("hp", rpt, 1, r->hN, r->hp );
  96. cmVOD_PrintL("hs", rpt, 1, r->hN, r->hs );
  97. cmVOD_PrintL("vb", rpt, 1, r->vN, r->vb );
  98. cmVOD_PrintL("vp", rpt, 1, r->vN, r->vp );
  99. cmVOD_PrintL("vs", rpt, 1, r->vN, r->vs );
  100. cmVOD_PrintL("W", rpt, r->vN, r->hN, r->W );
  101. }
  102. void _cmRbmRelease( cmRBM_t* r )
  103. {
  104. cmStackFree(&r->monH);
  105. cmMemFree(r);
  106. }
  // Round each of the given element counts up to the next even number.
  //
  // With 8 byte doubles an even element count makes the byte size of every
  // array a multiple of 16, which keeps arrays that are laid out back-to-back
  // on 16 byte boundaries.
  //
  // vNp - visible unit count (updated in place).
  // hNp - hidden unit count (updated in place).
  // dNp - optional data point count (updated in place); may be NULL.
  void _cmRbmAdjustSizes( unsigned* vNp, unsigned* hNp, unsigned* dNp )
  {
    if( cmIsOddU(*vNp) )
      *vNp += 1;

    if( cmIsOddU(*hNp) )
      *hNp += 1;

    if( dNp == NULL )
      return;

    if( cmIsOddU(*dNp) )
      *dNp += 1;
  }
  116. cmRBM_t* _cmRbmAlloc( cmCtx_t* ctx, unsigned vN, unsigned hN )
  117. {
  118. unsigned monInitCnt = 1000;
  119. unsigned monExpandCnt = 1000;
  120. // adjust sizes to force base array addresses to be a multiple of 16 bytes.
  121. unsigned vn = vN;
  122. unsigned hn = hN;
  123. _cmRbmAdjustSizes(&vn,&hn,NULL);
  124. unsigned rn = sizeof(cmRBM_t);
  125. // force record to be a multiple of 16
  126. if( rn % 16 )
  127. rn += 16 - (rn % 16);
  128. unsigned dn = 4*vn + 3*hn + vn*hn;
  129. unsigned bn = rn + dn*sizeof(double);
  130. char* cp = cmMemAllocZ(char,bn);
  131. cmRBM_t* r = (cmRBM_t*)cp;
  132. r->vs = (double*)(cp+rn);
  133. r->vp = r->vs + vn;
  134. r->vb = r->vp + vn;
  135. r->vd = r->vb + vn;
  136. r->hs = r->vd + vn;
  137. r->hp = r->hs + hn;
  138. r->hb = r->hp + hn;
  139. r->W = r->hb + hn;
  140. r->vN = vN;
  141. r->hN = hN;
  142. assert(cp+bn == (char*)(r->W + vn*hn));
  143. if( cmStackAlloc(ctx, &r->monH, monInitCnt, monExpandCnt, sizeof(cmRbmMonitor_t)) != kOkStRC )
  144. {
  145. cmErrMsg(&ctx->err,kStackFailRbmRC,"Stack allocation failed for the training monitor data array.");
  146. goto errLabel;
  147. }
  148. return r;
  149. errLabel:
  150. _cmRbmRelease(r);
  151. return NULL;
  152. }
  153. void cmRbmBinaryTrain(
  154. cmCtx_t* ctx,
  155. cmRBM_t* r,
  156. cmRbmTrainParms_t* p,
  157. unsigned dMN,
  158. const double* dM )
  159. {
  160. cmRpt_t* rpt = ctx->err.rpt;
  161. bool stochFl = true;
  162. unsigned i,j,k,ei,di;
  163. unsigned vN = r->vN;
  164. unsigned hN = r->hN;
  165. // adjust the memory sizes to align all arrays on 16 byte boundaries
  166. unsigned vn = vN;
  167. unsigned hn = hN;
  168. unsigned dn = p->batchCnt;
  169. _cmRbmAdjustSizes(&vn,&hn,&dn);
  170. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  171. double* m = cmMemAllocZ(double,mn);
  172. double* vh0M = m; // vh0M[ hN, vN ]
  173. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  174. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  175. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  176. double* hdbV = vdbV + vn; // hdbV[ hN ]
  177. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  178. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  179. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  180. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  181. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  182. assert( vs1M + vn * dn == m + mn );
  183. // initilaize the weights with random values
  184. // W = p->initW * randn(vN,hN,0.0,1.0)
  185. for(i=0; i<vN; ++i)
  186. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  187. cmVOD_MultVS( r->W, hN*vN, p->initW);
  188. if(0)
  189. {
  190. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  191. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  192. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  193. }
  194. cmVOD_Zero( dwM, vN*hN );
  195. cmVOD_Zero( vdbV, vN );
  196. cmVOD_Zero( hdbV, hN );
  197. for(ei=0; ei<p->epochCnt; ++ei)
  198. {
  199. unsigned dN = 0;
  200. double err = 0;
  201. for(di=0; di<dMN; di+=dN)
  202. {
  203. dN = cmMin(p->batchCnt,dMN-di);
  204. const double* d = dM + di * vN; // d[ vN, dN ]
  205. //
  206. // Update hidden layer from data
  207. //
  208. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  209. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  210. // calc hs0M[dN,hN]
  211. for(k=0; k<dN; ++k)
  212. for(j=0; j<hN; ++j)
  213. {
  214. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  215. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  216. if( !stochFl )
  217. hs0M[ j*dN + k ] = hp0M[ k*hN + j ] > 0.5;
  218. }
  219. //
  220. // Reconstruct visible layer from hidden
  221. //
  222. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  223. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  224. // calc vs1M[vN,dN]
  225. for(k=0; k<dN; ++k)
  226. for(i=0; i<vN; ++i)
  227. {
  228. vp1M[ i*dN + k ] = 1.0/(1.0 + exp(-( vp1M[ i*dN + k ] + r->vb[i]) ) );
  229. vs1M[ k*vN + i ] = rand() < vp1M[ i*dN + k ] * RAND_MAX;
  230. if( !stochFl )
  231. vs1M[ k*vN + i ] = vp1M[ i*dN + k ] > 0.5;
  232. // calc training error
  233. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  234. }
  235. //
  236. // Update hidden layer from reconstruction
  237. //
  238. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  239. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  240. // calc hp1M[hN,dN]
  241. for(k=0; k<dN; ++k)
  242. for(j=0; j<hN; ++j)
  243. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  244. if(0)
  245. {
  246. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  247. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  248. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  249. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  250. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  251. }
  252. //
  253. // Update Wieghts
  254. //
  255. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  256. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  257. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  258. for(i=0; i<hN*vN; ++i)
  259. {
  260. dwM[i] = p->momentum * dwM[i] + p->eta * ( (vh0M[i] - vh1M[i]) / dN );
  261. r->W[i] += dwM[i];
  262. }
  263. //
  264. // Update hidden bias
  265. //
  266. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  267. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  268. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  269. for(j=0; j<hN; ++j)
  270. {
  271. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / dN);
  272. r->hb[j] += hdbV[j];
  273. }
  274. //
  275. // Update visible bias
  276. //
  277. // sum(d - vs1M, 2)
  278. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  279. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  280. for(i=0; i<vN; ++i)
  281. {
  282. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  283. r->vb[i] += vdbV[i];
  284. }
  285. if(0)
  286. {
  287. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  288. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  289. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  290. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  291. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  292. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  293. }
  294. } // di
  295. cmRptPrintf(rpt,"err:%f\n",err);
  296. if( cmStackIsValid(r->monH))
  297. {
  298. cmRbmMonitor_t monErr;
  299. monErr.trainErr = err;
  300. cmStackPush(r->monH,&monErr,1);
  301. }
  302. } // ei
  303. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  304. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  305. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  306. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  307. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  308. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  309. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  310. cmMemFree(m);
  311. }
  312. void cmRbmRealTrain(
  313. cmCtx_t* ctx,
  314. cmRBM_t* r,
  315. cmRbmTrainParms_t* p,
  316. unsigned dMN,
  317. const double* dM )
  318. {
  319. cmRpt_t* rpt = ctx->err.rpt;
  320. unsigned i,j,k,ei,di;
  321. unsigned vN = r->vN;
  322. unsigned hN = r->hN;
  323. // adjust the memory sizes to align all arrays on 16 byte boundaries
  324. unsigned vn = vN;
  325. unsigned hn = hN;
  326. unsigned dn = p->batchCnt;
  327. _cmRbmAdjustSizes(&vn,&hn,&dn);
  328. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  329. double* m = cmMemAllocZ(double,mn);
  330. double* vh0M = m; // vh0M[ hN, vN ]
  331. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  332. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  333. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  334. double* hdbV = vdbV + vn; // hdbV[ hN ]
  335. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  336. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  337. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  338. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  339. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  340. assert( vs1M + vn * dn == m + mn );
  341. //
  342. // Initilaize the weights with small random values
  343. // W = p->initW * randn(vN,hN,0.0,1.0)
  344. for(i=0; i<vN; ++i)
  345. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  346. cmVOD_MultVS( r->W, hN*vN, p->initW);
  347. if(0)
  348. {
  349. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  350. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  351. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  352. }
  353. cmVOD_Zero( dwM, vN*hN );
  354. cmVOD_Zero( vdbV, vN );
  355. cmVOD_Zero( hdbV, hN );
  356. for(ei=0; ei<p->epochCnt; ++ei)
  357. {
  358. unsigned dN = 0;
  359. double err = 0;
  360. for(di=0; di<dMN; di+=dN)
  361. {
  362. dN = cmMin(p->batchCnt,dMN-di);
  363. const double* d = dM + di * vN; // d[ vN, dN ]
  364. //
  365. // Update hidden layer from data
  366. //
  367. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  368. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  369. // calc hs0M[dN,hN]
  370. for(k=0; k<dN; ++k)
  371. for(j=0; j<hN; ++j)
  372. {
  373. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  374. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  375. }
  376. //
  377. // Reconstruct visible layer from hidden
  378. //
  379. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  380. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  381. // calc vs1M[vN,dN]
  382. for(k=0; k<dN; ++k)
  383. for(i=0; i<vN; ++i)
  384. {
  385. vp1M[ i*dN + k ] = r->vd[i] * vp1M[ i*dN + k ] + r->vb[i];
  386. cmVOD_GaussPDF(vs1M + k*vN + i, 1, vp1M + i*dN + k, r->vb[i], r->vd[i] );
  387. // calc training error
  388. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  389. }
  390. //
  391. // Update hidden layer from reconstruction
  392. //
  393. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  394. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  395. // calc hp1M[hN,dN]
  396. for(k=0; k<dN; ++k)
  397. for(j=0; j<hN; ++j)
  398. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  399. if(0)
  400. {
  401. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  402. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  403. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  404. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  405. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  406. }
  407. //
  408. // Update Wieghts
  409. //
  410. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  411. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  412. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  413. for(i=0,k=0; i<vN; ++i)
  414. for(j=0; j<hN; ++j,++k)
  415. {
  416. dwM[k] = p->momentum * dwM[k] + p->eta * ( (vh0M[k] - vh1M[k]) / (dN * r->vd[i]) );
  417. r->W[k] += dwM[k];
  418. }
  419. //
  420. // Update hidden bias
  421. //
  422. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  423. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  424. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  425. for(j=0; j<hN; ++j)
  426. {
  427. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / (dN * r->vd[i] * r->vd[i]));
  428. r->hb[j] += hdbV[j];
  429. }
  430. //
  431. // Update visible bias
  432. //
  433. // sum(d - vs1M, 2)
  434. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  435. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  436. for(i=0; i<vN; ++i)
  437. {
  438. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  439. r->vb[i] += vdbV[i];
  440. }
  441. for(i=0; i<vN; ++i)
  442. {
  443. for(j=0; j<hN; ++j)
  444. {
  445. double sum_d = 0;
  446. double sum_m = 0;
  447. for(k=0; k<dN; ++k)
  448. {
  449. sum_d += hs0M[ j*dN + k ] * r->W[ i*hN + j ];
  450. sum_m += hp1M[ k*hN + j ] * r->W[ i*hN + j ];
  451. }
  452. }
  453. }
  454. if(0)
  455. {
  456. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  457. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  458. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  459. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  460. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  461. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  462. }
  463. } // di
  464. cmRptPrintf(rpt,"err:%f\n",err);
  465. if( cmStackIsValid(r->monH))
  466. {
  467. cmRbmMonitor_t monErr;
  468. monErr.trainErr = err;
  469. cmStackPush(r->monH,&monErr,1);
  470. }
  471. } // ei
  472. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  473. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  474. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  475. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  476. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  477. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  478. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  479. cmMemFree(m);
  480. }
  481. void cmRbmBinaryTest( cmCtx_t* ctx )
  482. {
  483. const char* monitorFn = "/home/kevin/temp/cmRbmMonitor0.mtx";
  484. const char* dataFn = "/home/kevin/temp/cmRbmData0.mtx";
  485. unsigned pointsN = 1000;
  486. unsigned dimN = 4;
  487. unsigned vN = dimN;
  488. unsigned hN = 32;
  489. //double probV[] = {0.1,0.2,0.8,0.7};
  490. cmRbmTrainParms_t r;
  491. cmRBM_t* rbm;
  492. r.maxX = 1.0;
  493. r.minX = 0.0;
  494. r.initW = 0.1;
  495. r.eta = 0.01;
  496. r.holdOutFrac = 0.1;
  497. r.epochCnt = 10;
  498. r.momentum = 0.5;
  499. r.batchCnt = 10;
  500. if(0)
  501. {
  502. vN = 4;
  503. hN = 6;
  504. double d[] = {
  505. 0, 1, 1, 1,
  506. 1, 1, 1, 1,
  507. 0, 0, 1, 0,
  508. 0, 0, 1, 0,
  509. 0, 1, 1, 0,
  510. 0, 1, 1, 1,
  511. 1, 1, 1, 1,
  512. 0, 0, 1, 0,
  513. 0, 0, 1, 0,
  514. 0, 1, 1, 0
  515. };
  516. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  517. return;
  518. pointsN = sizeof(d) / (sizeof(d[0]) * vN);
  519. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,d);
  520. return;
  521. }
  522. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  523. return;
  524. //double* data0M = cmRbmGenBinaryTestData(ctx,dataFn,probV,dimN,pointsN);
  525. double* data0M = cmRbmReadDataFile(ctx,dataFn,&dimN,&pointsN);
  526. double t[ vN ];
  527. // Sum the columns of sp[srn,scn] into dp[scn].
  528. // dp[] is zeroed prior to computing the sum.
  529. cmVOD_SumMN(data0M, dimN, pointsN, t );
  530. cmVOD_Print( &ctx->rpt, 1, dimN, t );
  531. if(0)
  532. {
  533. //
  534. // Standardize data (subtract mean and divide by standard deviation)
  535. // then set the visible layers initial standard deviation to 1.0.
  536. //
  537. cmVOD_StandardizeRows( data0M, rbm->vN, pointsN, NULL, NULL );
  538. cmVOD_Fill( rbm->vd, rbm->vN, 1.0 );
  539. cmRbmRealTrain(ctx,rbm,&r,pointsN,data0M);
  540. }
  541. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,data0M);
  542. cmRbmWriteMonitorFile(ctx, rbm->monH, monitorFn );
  543. cmMemFree(data0M);
  544. _cmRbmRelease(rbm);
  545. }