libcm is a C development framework with an emphasis on audio signal processing applications.
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

cmRbm.c 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. #include "cmPrefix.h"
  2. #include "cmGlobal.h"
  3. #include "cmFloatTypes.h"
  4. #include "cmComplexTypes.h"
  5. #include "cmRpt.h"
  6. #include "cmErr.h"
  7. #include "cmCtx.h"
  8. #include "cmMem.h"
  9. #include "cmMallocDebug.h"
  10. #include "cmLinkedHeap.h"
  11. #include "cmMath.h"
  12. #include "cmFile.h"
  13. #include "cmSymTbl.h"
  14. #include "cmTime.h"
  15. #include "cmMidi.h"
  16. #include "cmAudioFile.h"
  17. #include "cmVectOpsTemplateMain.h"
  18. #include "cmStack.h"
  19. #include "cmProcObj.h"
  20. #include "cmProcTemplateMain.h"
  21. #include "cmVectOps.h"
  22. #include "cmProc.h"
  23. #include "cmProc2.h"
  24. #include "cmRbm.h"
  // One record of the training monitor.
  // Records are stored on a cmStack and written by cmRbmWriteMonitorFile() as
  // a matrix with sizeof(cmRbmMonitor_t)/sizeof(double) columns, so every
  // member of this record must be a double.
  typedef struct
  {
    double trainErr;  // error accumulated over the training data during one epoch
    double testErr;   // test error slot; the training loops in this file do not compute one
  } cmRbmMonitor_t;
  30. cmRbmRC_t cmRbmWriteMonitorFile( cmCtx_t* c, cmStackH_t monH, const cmChar_t* fn )
  31. {
  32. cmRbmRC_t rc = kOkRbmRC;
  33. cmCtx* ctx = cmCtxAlloc(NULL, c->err.rpt, cmLHeapNullHandle, cmSymTblNullHandle );
  34. if( cmBinMtxFileWrite(fn, cmStackCount(monH), sizeof(cmRbmMonitor_t)/sizeof(double), NULL, cmStackFlatten(monH), ctx, c->err.rpt ) != cmOkRC )
  35. {
  36. rc = cmErrMsg(&c->err,kMonitorWrFailRbmRC,"Training monitor file '%s' write failed.",cmStringNullGuard(fn));
  37. goto errLabel;
  38. }
  39. errLabel:
  40. cmCtxFree(&ctx);
  41. return rc;
  42. }
  43. double* cmRbmReadDataFile( cmCtx_t* c, const char* fn, unsigned* dimNPtr, unsigned* pointCntPtr )
  44. {
  45. unsigned rowCnt,colCnt,eleByteCnt;
  46. *dimNPtr = 0;
  47. *pointCntPtr = 0;
  48. if( cmBinMtxFileSize(c, fn, &rowCnt, &colCnt, &eleByteCnt ) != cmOkRC )
  49. return NULL;
  50. double* buf = cmMemAllocZ(double,rowCnt*colCnt);
  51. if( cmBinMtxFileRead(c, fn, rowCnt, colCnt, sizeof(double), buf,NULL) != cmOkRC )
  52. {
  53. cmMemFree(buf);
  54. return NULL;
  55. }
  56. *dimNPtr = rowCnt;
  57. *pointCntPtr = colCnt;
  58. return buf;
  59. }
  60. // Generate a matrix of 'pointsN' random binary valued column vectors of dimension dimN.
  61. // The first i = {0...'dimN'-1} elements of each vector contain ones with prob probV[i]
  62. // (or zeros with prob 1 - probV[i].). probV[i] in [0.0,1.0].
  63. // The last element in each column is set to zero.
  64. // The returned matrix m[ dimN+1, pointsN ] is in column major order and
  65. // must be deleted by the caller (e.g. cmMemFree(m)).
  66. double* cmRbmGenBinaryTestData( cmCtx_t* c, const char* fn, const double* probV, unsigned dimN, unsigned pointsN )
  67. {
  68. if( dimN == 0 || pointsN == 0 )
  69. return NULL;
  70. double* m = cmMemAllocZ( double, dimN*pointsN );
  71. unsigned i,j;
  72. for(i=0; i<pointsN; ++i)
  73. for(j=0; j<dimN; ++j)
  74. m[ i*dimN + j ] = rand() < (probV[j] * RAND_MAX);
  75. if( fn != NULL )
  76. cmBinMtxFileWrite(fn,dimN,pointsN,NULL,m,NULL,c->err.rpt);
  77. return m;
  78. }
  79. typedef struct
  80. {
  81. unsigned vN;
  82. double* vs;
  83. double* vp;
  84. double* vb;
  85. double* vd; // std dev. (var = std_dev^2)
  86. unsigned hN;
  87. double* hs;
  88. double* hp;
  89. double* hb;
  90. double* W; // W[vN,hN]
  91. cmStackH_t monH;
  92. } cmRBM_t;
  93. void _cmRbmPrint( cmRBM_t* r, cmRpt_t* rpt )
  94. {
  95. cmVOD_PrintL("hb", rpt, 1, r->hN, r->hb );
  96. cmVOD_PrintL("hp", rpt, 1, r->hN, r->hp );
  97. cmVOD_PrintL("hs", rpt, 1, r->hN, r->hs );
  98. cmVOD_PrintL("vb", rpt, 1, r->vN, r->vb );
  99. cmVOD_PrintL("vp", rpt, 1, r->vN, r->vp );
  100. cmVOD_PrintL("vs", rpt, 1, r->vN, r->vs );
  101. cmVOD_PrintL("W", rpt, r->vN, r->hN, r->W );
  102. }
  103. void _cmRbmRelease( cmRBM_t* r )
  104. {
  105. cmStackFree(&r->monH);
  106. cmMemFree(r);
  107. }
  // Round each layer size up to the next even count so that an array of that
  // many 8 byte doubles occupies a multiple of 16 bytes.
  // vNp and hNp are updated in place and must not be NULL.
  // dNp is optional and is skipped when it is NULL.
  void _cmRbmAdjustSizes( unsigned* vNp, unsigned* hNp, unsigned* dNp )
  {
    if( cmIsOddU(*vNp) )
      *vNp += 1;

    if( cmIsOddU(*hNp) )
      *hNp += 1;

    if( dNp != NULL && cmIsOddU(*dNp) )
      *dNp += 1;
  }
  117. cmRBM_t* _cmRbmAlloc( cmCtx_t* ctx, unsigned vN, unsigned hN )
  118. {
  119. unsigned monInitCnt = 1000;
  120. unsigned monExpandCnt = 1000;
  121. // adjust sizes to force base array addresses to be a multiple of 16 bytes.
  122. unsigned vn = vN;
  123. unsigned hn = hN;
  124. _cmRbmAdjustSizes(&vn,&hn,NULL);
  125. unsigned rn = sizeof(cmRBM_t);
  126. // force record to be a multiple of 16
  127. if( rn % 16 )
  128. rn += 16 - (rn % 16);
  129. unsigned dn = 4*vn + 3*hn + vn*hn;
  130. unsigned bn = rn + dn*sizeof(double);
  131. char* cp = cmMemAllocZ(char,bn);
  132. cmRBM_t* r = (cmRBM_t*)cp;
  133. r->vs = (double*)(cp+rn);
  134. r->vp = r->vs + vn;
  135. r->vb = r->vp + vn;
  136. r->vd = r->vb + vn;
  137. r->hs = r->vd + vn;
  138. r->hp = r->hs + hn;
  139. r->hb = r->hp + hn;
  140. r->W = r->hb + hn;
  141. r->vN = vN;
  142. r->hN = hN;
  143. assert(cp+bn == (char*)(r->W + vn*hn));
  144. if( cmStackAlloc(ctx, &r->monH, monInitCnt, monExpandCnt, sizeof(cmRbmMonitor_t)) != kOkStRC )
  145. {
  146. cmErrMsg(&ctx->err,kStackFailRbmRC,"Stack allocation failed for the training monitor data array.");
  147. goto errLabel;
  148. }
  149. return r;
  150. errLabel:
  151. _cmRbmRelease(r);
  152. return NULL;
  153. }
  154. void cmRbmBinaryTrain(
  155. cmCtx_t* ctx,
  156. cmRBM_t* r,
  157. cmRbmTrainParms_t* p,
  158. unsigned dMN,
  159. const double* dM )
  160. {
  161. cmRpt_t* rpt = ctx->err.rpt;
  162. bool stochFl = true;
  163. unsigned i,j,k,ei,di;
  164. unsigned vN = r->vN;
  165. unsigned hN = r->hN;
  166. // adjust the memory sizes to align all arrays on 16 byte boundaries
  167. unsigned vn = vN;
  168. unsigned hn = hN;
  169. unsigned dn = p->batchCnt;
  170. _cmRbmAdjustSizes(&vn,&hn,&dn);
  171. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  172. double* m = cmMemAllocZ(double,mn);
  173. double* vh0M = m; // vh0M[ hN, vN ]
  174. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  175. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  176. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  177. double* hdbV = vdbV + vn; // hdbV[ hN ]
  178. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  179. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  180. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  181. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  182. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  183. assert( vs1M + vn * dn == m + mn );
  184. // initilaize the weights with random values
  185. // W = p->initW * randn(vN,hN,0.0,1.0)
  186. for(i=0; i<vN; ++i)
  187. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  188. cmVOD_MultVS( r->W, hN*vN, p->initW);
  189. if(0)
  190. {
  191. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  192. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  193. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  194. }
  195. cmVOD_Zero( dwM, vN*hN );
  196. cmVOD_Zero( vdbV, vN );
  197. cmVOD_Zero( hdbV, hN );
  198. for(ei=0; ei<p->epochCnt; ++ei)
  199. {
  200. unsigned dN = 0;
  201. double err = 0;
  202. for(di=0; di<dMN; di+=dN)
  203. {
  204. dN = cmMin(p->batchCnt,dMN-di);
  205. const double* d = dM + di * vN; // d[ vN, dN ]
  206. //
  207. // Update hidden layer from data
  208. //
  209. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  210. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  211. // calc hs0M[dN,hN]
  212. for(k=0; k<dN; ++k)
  213. for(j=0; j<hN; ++j)
  214. {
  215. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  216. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  217. if( !stochFl )
  218. hs0M[ j*dN + k ] = hp0M[ k*hN + j ] > 0.5;
  219. }
  220. //
  221. // Reconstruct visible layer from hidden
  222. //
  223. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  224. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  225. // calc vs1M[vN,dN]
  226. for(k=0; k<dN; ++k)
  227. for(i=0; i<vN; ++i)
  228. {
  229. vp1M[ i*dN + k ] = 1.0/(1.0 + exp(-( vp1M[ i*dN + k ] + r->vb[i]) ) );
  230. vs1M[ k*vN + i ] = rand() < vp1M[ i*dN + k ] * RAND_MAX;
  231. if( !stochFl )
  232. vs1M[ k*vN + i ] = vp1M[ i*dN + k ] > 0.5;
  233. // calc training error
  234. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  235. }
  236. //
  237. // Update hidden layer from reconstruction
  238. //
  239. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  240. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  241. // calc hp1M[hN,dN]
  242. for(k=0; k<dN; ++k)
  243. for(j=0; j<hN; ++j)
  244. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  245. if(0)
  246. {
  247. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  248. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  249. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  250. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  251. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  252. }
  253. //
  254. // Update Wieghts
  255. //
  256. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  257. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  258. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  259. for(i=0; i<hN*vN; ++i)
  260. {
  261. dwM[i] = p->momentum * dwM[i] + p->eta * ( (vh0M[i] - vh1M[i]) / dN );
  262. r->W[i] += dwM[i];
  263. }
  264. //
  265. // Update hidden bias
  266. //
  267. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  268. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  269. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  270. for(j=0; j<hN; ++j)
  271. {
  272. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / dN);
  273. r->hb[j] += hdbV[j];
  274. }
  275. //
  276. // Update visible bias
  277. //
  278. // sum(d - vs1M, 2)
  279. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  280. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  281. for(i=0; i<vN; ++i)
  282. {
  283. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  284. r->vb[i] += vdbV[i];
  285. }
  286. if(0)
  287. {
  288. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  289. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  290. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  291. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  292. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  293. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  294. }
  295. } // di
  296. cmRptPrintf(rpt,"err:%f\n",err);
  297. if( cmStackIsValid(r->monH))
  298. {
  299. cmRbmMonitor_t monErr;
  300. monErr.trainErr = err;
  301. cmStackPush(r->monH,&monErr,1);
  302. }
  303. } // ei
  304. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  305. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  306. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  307. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  308. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  309. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  310. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  311. cmMemFree(m);
  312. }
  313. void cmRbmRealTrain(
  314. cmCtx_t* ctx,
  315. cmRBM_t* r,
  316. cmRbmTrainParms_t* p,
  317. unsigned dMN,
  318. const double* dM )
  319. {
  320. cmRpt_t* rpt = ctx->err.rpt;
  321. unsigned i,j,k,ei,di;
  322. unsigned vN = r->vN;
  323. unsigned hN = r->hN;
  324. // adjust the memory sizes to align all arrays on 16 byte boundaries
  325. unsigned vn = vN;
  326. unsigned hn = hN;
  327. unsigned dn = p->batchCnt;
  328. _cmRbmAdjustSizes(&vn,&hn,&dn);
  329. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  330. double* m = cmMemAllocZ(double,mn);
  331. double* vh0M = m; // vh0M[ hN, vN ]
  332. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  333. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  334. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  335. double* hdbV = vdbV + vn; // hdbV[ hN ]
  336. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  337. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  338. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  339. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  340. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  341. assert( vs1M + vn * dn == m + mn );
  342. //
  343. // Initilaize the weights with small random values
  344. // W = p->initW * randn(vN,hN,0.0,1.0)
  345. for(i=0; i<vN; ++i)
  346. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  347. cmVOD_MultVS( r->W, hN*vN, p->initW);
  348. if(0)
  349. {
  350. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  351. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  352. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  353. }
  354. cmVOD_Zero( dwM, vN*hN );
  355. cmVOD_Zero( vdbV, vN );
  356. cmVOD_Zero( hdbV, hN );
  357. for(ei=0; ei<p->epochCnt; ++ei)
  358. {
  359. unsigned dN = 0;
  360. double err = 0;
  361. for(di=0; di<dMN; di+=dN)
  362. {
  363. dN = cmMin(p->batchCnt,dMN-di);
  364. const double* d = dM + di * vN; // d[ vN, dN ]
  365. //
  366. // Update hidden layer from data
  367. //
  368. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  369. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  370. // calc hs0M[dN,hN]
  371. for(k=0; k<dN; ++k)
  372. for(j=0; j<hN; ++j)
  373. {
  374. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  375. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  376. }
  377. //
  378. // Reconstruct visible layer from hidden
  379. //
  380. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  381. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  382. // calc vs1M[vN,dN]
  383. for(k=0; k<dN; ++k)
  384. for(i=0; i<vN; ++i)
  385. {
  386. vp1M[ i*dN + k ] = r->vd[i] * vp1M[ i*dN + k ] + r->vb[i];
  387. cmVOD_GaussPDF(vs1M + k*vN + i, 1, vp1M + i*dN + k, r->vb[i], r->vd[i] );
  388. // calc training error
  389. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  390. }
  391. //
  392. // Update hidden layer from reconstruction
  393. //
  394. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  395. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  396. // calc hp1M[hN,dN]
  397. for(k=0; k<dN; ++k)
  398. for(j=0; j<hN; ++j)
  399. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  400. if(0)
  401. {
  402. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  403. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  404. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  405. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  406. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  407. }
  408. //
  409. // Update Wieghts
  410. //
  411. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  412. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  413. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  414. for(i=0,k=0; i<vN; ++i)
  415. for(j=0; j<hN; ++j,++k)
  416. {
  417. dwM[k] = p->momentum * dwM[k] + p->eta * ( (vh0M[k] - vh1M[k]) / (dN * r->vd[i]) );
  418. r->W[k] += dwM[k];
  419. }
  420. //
  421. // Update hidden bias
  422. //
  423. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  424. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  425. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  426. for(j=0; j<hN; ++j)
  427. {
  428. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / (dN * r->vd[i] * r->vd[i]));
  429. r->hb[j] += hdbV[j];
  430. }
  431. //
  432. // Update visible bias
  433. //
  434. // sum(d - vs1M, 2)
  435. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  436. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  437. for(i=0; i<vN; ++i)
  438. {
  439. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  440. r->vb[i] += vdbV[i];
  441. }
  442. for(i=0; i<vN; ++i)
  443. {
  444. for(j=0; j<hN; ++j)
  445. {
  446. double sum_d = 0;
  447. double sum_m = 0;
  448. for(k=0; k<dN; ++k)
  449. {
  450. sum_d += hs0M[ j*dN + k ] * r->W[ i*hN + j ];
  451. sum_m += hp1M[ k*hN + j ] * r->W[ i*hN + j ];
  452. }
  453. }
  454. }
  455. if(0)
  456. {
  457. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  458. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  459. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  460. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  461. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  462. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  463. }
  464. } // di
  465. cmRptPrintf(rpt,"err:%f\n",err);
  466. if( cmStackIsValid(r->monH))
  467. {
  468. cmRbmMonitor_t monErr;
  469. monErr.trainErr = err;
  470. cmStackPush(r->monH,&monErr,1);
  471. }
  472. } // ei
  473. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  474. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  475. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  476. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  477. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  478. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  479. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  480. cmMemFree(m);
  481. }
  482. void cmRbmBinaryTest( cmCtx_t* ctx )
  483. {
  484. const char* monitorFn = "/home/kevin/temp/cmRbmMonitor0.mtx";
  485. const char* dataFn = "/home/kevin/temp/cmRbmData0.mtx";
  486. unsigned pointsN = 1000;
  487. unsigned dimN = 4;
  488. unsigned vN = dimN;
  489. unsigned hN = 32;
  490. //double probV[] = {0.1,0.2,0.8,0.7};
  491. cmRbmTrainParms_t r;
  492. cmRBM_t* rbm;
  493. r.maxX = 1.0;
  494. r.minX = 0.0;
  495. r.initW = 0.1;
  496. r.eta = 0.01;
  497. r.holdOutFrac = 0.1;
  498. r.epochCnt = 10;
  499. r.momentum = 0.5;
  500. r.batchCnt = 10;
  501. if(0)
  502. {
  503. vN = 4;
  504. hN = 6;
  505. double d[] = {
  506. 0, 1, 1, 1,
  507. 1, 1, 1, 1,
  508. 0, 0, 1, 0,
  509. 0, 0, 1, 0,
  510. 0, 1, 1, 0,
  511. 0, 1, 1, 1,
  512. 1, 1, 1, 1,
  513. 0, 0, 1, 0,
  514. 0, 0, 1, 0,
  515. 0, 1, 1, 0
  516. };
  517. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  518. return;
  519. pointsN = sizeof(d) / (sizeof(d[0]) * vN);
  520. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,d);
  521. return;
  522. }
  523. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  524. return;
  525. //double* data0M = cmRbmGenBinaryTestData(ctx,dataFn,probV,dimN,pointsN);
  526. double* data0M = cmRbmReadDataFile(ctx,dataFn,&dimN,&pointsN);
  527. double t[ vN ];
  528. // Sum the columns of sp[srn,scn] into dp[scn].
  529. // dp[] is zeroed prior to computing the sum.
  530. cmVOD_SumMN(data0M, dimN, pointsN, t );
  531. cmVOD_Print( &ctx->rpt, 1, dimN, t );
  532. if(0)
  533. {
  534. //
  535. // Standardize data (subtract mean and divide by standard deviation)
  536. // then set the visible layers initial standard deviation to 1.0.
  537. //
  538. cmVOD_StandardizeRows( data0M, rbm->vN, pointsN, NULL, NULL );
  539. cmVOD_Fill( rbm->vd, rbm->vN, 1.0 );
  540. cmRbmRealTrain(ctx,rbm,&r,pointsN,data0M);
  541. }
  542. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,data0M);
  543. cmRbmWriteMonitorFile(ctx, rbm->monH, monitorFn );
  544. cmMemFree(data0M);
  545. _cmRbmRelease(rbm);
  546. }