libcm is a C development framework with an emphasis on audio signal processing applications.
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder mit einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

cmRbm.c 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. //| Copyright: (C) 2009-2020 Kevin Larke <contact AT larke DOT org>
  2. //| License: GNU GPL version 3.0 or above. See the accompanying LICENSE file.
  3. #include "cmPrefix.h"
  4. #include "cmGlobal.h"
  5. #include "cmFloatTypes.h"
  6. #include "cmComplexTypes.h"
  7. #include "cmRpt.h"
  8. #include "cmErr.h"
  9. #include "cmCtx.h"
  10. #include "cmMem.h"
  11. #include "cmMallocDebug.h"
  12. #include "cmLinkedHeap.h"
  13. #include "cmMath.h"
  14. #include "cmFile.h"
  15. #include "cmSymTbl.h"
  16. #include "cmTime.h"
  17. #include "cmMidi.h"
  18. #include "cmAudioFile.h"
  19. #include "cmVectOpsTemplateMain.h"
  20. #include "cmStack.h"
  21. #include "cmProcObj.h"
  22. #include "cmProcTemplateMain.h"
  23. #include "cmVectOps.h"
  24. #include "cmProc.h"
  25. #include "cmProc2.h"
  26. #include "cmRbm.h"
  // One training-monitor record, pushed once per training epoch.
  // Monitor files are written with sizeof(cmRbmMonitor_t)/sizeof(double) values
  // per record, so this record must contain only double fields.
  typedef struct cmRbmMonitor_str
  {
    double trainErr;  // summed squared reconstruction error of one epoch
    double testErr;   // hold-out/test error (not computed by the trainers in this file)
  } cmRbmMonitor_t;
  32. cmRbmRC_t cmRbmWriteMonitorFile( cmCtx_t* c, cmStackH_t monH, const cmChar_t* fn )
  33. {
  34. cmRbmRC_t rc = kOkRbmRC;
  35. cmCtx* ctx = cmCtxAlloc(NULL, c->err.rpt, cmLHeapNullHandle, cmSymTblNullHandle );
  36. if( cmBinMtxFileWrite(fn, cmStackCount(monH), sizeof(cmRbmMonitor_t)/sizeof(double), NULL, cmStackFlatten(monH), ctx, c->err.rpt ) != cmOkRC )
  37. {
  38. rc = cmErrMsg(&c->err,kMonitorWrFailRbmRC,"Training monitor file '%s' write failed.",cmStringNullGuard(fn));
  39. goto errLabel;
  40. }
  41. errLabel:
  42. cmCtxFree(&ctx);
  43. return rc;
  44. }
  45. double* cmRbmReadDataFile( cmCtx_t* c, const char* fn, unsigned* dimNPtr, unsigned* pointCntPtr )
  46. {
  47. unsigned rowCnt,colCnt,eleByteCnt;
  48. *dimNPtr = 0;
  49. *pointCntPtr = 0;
  50. if( cmBinMtxFileSize(c, fn, &rowCnt, &colCnt, &eleByteCnt ) != cmOkRC )
  51. return NULL;
  52. double* buf = cmMemAllocZ(double,rowCnt*colCnt);
  53. if( cmBinMtxFileRead(c, fn, rowCnt, colCnt, sizeof(double), buf,NULL) != cmOkRC )
  54. {
  55. cmMemFree(buf);
  56. return NULL;
  57. }
  58. *dimNPtr = rowCnt;
  59. *pointCntPtr = colCnt;
  60. return buf;
  61. }
  62. // Generate a matrix of 'pointsN' random binary valued column vectors of dimension dimN.
  63. // The first i = {0...'dimN'-1} elements of each vector contain ones with prob probV[i]
  64. // (or zeros with prob 1 - probV[i].). probV[i] in [0.0,1.0].
  65. // The last element in each column is set to zero.
  66. // The returned matrix m[ dimN+1, pointsN ] is in column major order and
  67. // must be deleted by the caller (e.g. cmMemFree(m)).
  68. double* cmRbmGenBinaryTestData( cmCtx_t* c, const char* fn, const double* probV, unsigned dimN, unsigned pointsN )
  69. {
  70. if( dimN == 0 || pointsN == 0 )
  71. return NULL;
  72. double* m = cmMemAllocZ( double, dimN*pointsN );
  73. unsigned i,j;
  74. for(i=0; i<pointsN; ++i)
  75. for(j=0; j<dimN; ++j)
  76. m[ i*dimN + j ] = rand() < (probV[j] * RAND_MAX);
  77. if( fn != NULL )
  78. cmBinMtxFileWrite(fn,dimN,pointsN,NULL,m,NULL,c->err.rpt);
  79. return m;
  80. }
  81. typedef struct
  82. {
  83. unsigned vN;
  84. double* vs;
  85. double* vp;
  86. double* vb;
  87. double* vd; // std dev. (var = std_dev^2)
  88. unsigned hN;
  89. double* hs;
  90. double* hp;
  91. double* hb;
  92. double* W; // W[vN,hN]
  93. cmStackH_t monH;
  94. } cmRBM_t;
  95. void _cmRbmPrint( cmRBM_t* r, cmRpt_t* rpt )
  96. {
  97. cmVOD_PrintL("hb", rpt, 1, r->hN, r->hb );
  98. cmVOD_PrintL("hp", rpt, 1, r->hN, r->hp );
  99. cmVOD_PrintL("hs", rpt, 1, r->hN, r->hs );
  100. cmVOD_PrintL("vb", rpt, 1, r->vN, r->vb );
  101. cmVOD_PrintL("vp", rpt, 1, r->vN, r->vp );
  102. cmVOD_PrintL("vs", rpt, 1, r->vN, r->vs );
  103. cmVOD_PrintL("W", rpt, r->vN, r->hN, r->W );
  104. }
  105. void _cmRbmRelease( cmRBM_t* r )
  106. {
  107. cmStackFree(&r->monH);
  108. cmMemFree(r);
  109. }
  // Return 'n' rounded up to the next even value ('n' itself when already even).
  static unsigned _cmRbmEvenCeil( unsigned n )
  {
    return (n % 2u) != 0 ? n + 1u : n;
  }

  // Adjust the layer geometry so that every size is a multiple of 16 bytes.
  // All data is assumed to be 8-byte doubles, so rounding each element count up
  // to an even number is sufficient. 'dNp' is optional and may be NULL; 'vNp'
  // and 'hNp' are always updated.
  void _cmRbmAdjustSizes( unsigned* vNp, unsigned* hNp, unsigned* dNp )
  {
    *vNp = _cmRbmEvenCeil(*vNp);
    *hNp = _cmRbmEvenCeil(*hNp);

    if( dNp != NULL )
      *dNp = _cmRbmEvenCeil(*dNp);
  }
  119. cmRBM_t* _cmRbmAlloc( cmCtx_t* ctx, unsigned vN, unsigned hN )
  120. {
  121. unsigned monInitCnt = 1000;
  122. unsigned monExpandCnt = 1000;
  123. // adjust sizes to force base array addresses to be a multiple of 16 bytes.
  124. unsigned vn = vN;
  125. unsigned hn = hN;
  126. _cmRbmAdjustSizes(&vn,&hn,NULL);
  127. unsigned rn = sizeof(cmRBM_t);
  128. // force record to be a multiple of 16
  129. if( rn % 16 )
  130. rn += 16 - (rn % 16);
  131. unsigned dn = 4*vn + 3*hn + vn*hn;
  132. unsigned bn = rn + dn*sizeof(double);
  133. char* cp = cmMemAllocZ(char,bn);
  134. cmRBM_t* r = (cmRBM_t*)cp;
  135. r->vs = (double*)(cp+rn);
  136. r->vp = r->vs + vn;
  137. r->vb = r->vp + vn;
  138. r->vd = r->vb + vn;
  139. r->hs = r->vd + vn;
  140. r->hp = r->hs + hn;
  141. r->hb = r->hp + hn;
  142. r->W = r->hb + hn;
  143. r->vN = vN;
  144. r->hN = hN;
  145. assert(cp+bn == (char*)(r->W + vn*hn));
  146. if( cmStackAlloc(ctx, &r->monH, monInitCnt, monExpandCnt, sizeof(cmRbmMonitor_t)) != kOkStRC )
  147. {
  148. cmErrMsg(&ctx->err,kStackFailRbmRC,"Stack allocation failed for the training monitor data array.");
  149. goto errLabel;
  150. }
  151. return r;
  152. errLabel:
  153. _cmRbmRelease(r);
  154. return NULL;
  155. }
  156. void cmRbmBinaryTrain(
  157. cmCtx_t* ctx,
  158. cmRBM_t* r,
  159. cmRbmTrainParms_t* p,
  160. unsigned dMN,
  161. const double* dM )
  162. {
  163. cmRpt_t* rpt = ctx->err.rpt;
  164. bool stochFl = true;
  165. unsigned i,j,k,ei,di;
  166. unsigned vN = r->vN;
  167. unsigned hN = r->hN;
  168. // adjust the memory sizes to align all arrays on 16 byte boundaries
  169. unsigned vn = vN;
  170. unsigned hn = hN;
  171. unsigned dn = p->batchCnt;
  172. _cmRbmAdjustSizes(&vn,&hn,&dn);
  173. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  174. double* m = cmMemAllocZ(double,mn);
  175. double* vh0M = m; // vh0M[ hN, vN ]
  176. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  177. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  178. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  179. double* hdbV = vdbV + vn; // hdbV[ hN ]
  180. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  181. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  182. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  183. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  184. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  185. assert( vs1M + vn * dn == m + mn );
  186. // initilaize the weights with random values
  187. // W = p->initW * randn(vN,hN,0.0,1.0)
  188. for(i=0; i<vN; ++i)
  189. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  190. cmVOD_MultVS( r->W, hN*vN, p->initW);
  191. if(0)
  192. {
  193. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  194. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  195. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  196. }
  197. cmVOD_Zero( dwM, vN*hN );
  198. cmVOD_Zero( vdbV, vN );
  199. cmVOD_Zero( hdbV, hN );
  200. for(ei=0; ei<p->epochCnt; ++ei)
  201. {
  202. unsigned dN = 0;
  203. double err = 0;
  204. for(di=0; di<dMN; di+=dN)
  205. {
  206. dN = cmMin(p->batchCnt,dMN-di);
  207. const double* d = dM + di * vN; // d[ vN, dN ]
  208. //
  209. // Update hidden layer from data
  210. //
  211. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  212. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  213. // calc hs0M[dN,hN]
  214. for(k=0; k<dN; ++k)
  215. for(j=0; j<hN; ++j)
  216. {
  217. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  218. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  219. if( !stochFl )
  220. hs0M[ j*dN + k ] = hp0M[ k*hN + j ] > 0.5;
  221. }
  222. //
  223. // Reconstruct visible layer from hidden
  224. //
  225. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  226. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  227. // calc vs1M[vN,dN]
  228. for(k=0; k<dN; ++k)
  229. for(i=0; i<vN; ++i)
  230. {
  231. vp1M[ i*dN + k ] = 1.0/(1.0 + exp(-( vp1M[ i*dN + k ] + r->vb[i]) ) );
  232. vs1M[ k*vN + i ] = rand() < vp1M[ i*dN + k ] * RAND_MAX;
  233. if( !stochFl )
  234. vs1M[ k*vN + i ] = vp1M[ i*dN + k ] > 0.5;
  235. // calc training error
  236. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  237. }
  238. //
  239. // Update hidden layer from reconstruction
  240. //
  241. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  242. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  243. // calc hp1M[hN,dN]
  244. for(k=0; k<dN; ++k)
  245. for(j=0; j<hN; ++j)
  246. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  247. if(0)
  248. {
  249. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  250. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  251. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  252. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  253. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  254. }
  255. //
  256. // Update Wieghts
  257. //
  258. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  259. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  260. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  261. for(i=0; i<hN*vN; ++i)
  262. {
  263. dwM[i] = p->momentum * dwM[i] + p->eta * ( (vh0M[i] - vh1M[i]) / dN );
  264. r->W[i] += dwM[i];
  265. }
  266. //
  267. // Update hidden bias
  268. //
  269. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  270. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  271. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  272. for(j=0; j<hN; ++j)
  273. {
  274. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / dN);
  275. r->hb[j] += hdbV[j];
  276. }
  277. //
  278. // Update visible bias
  279. //
  280. // sum(d - vs1M, 2)
  281. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  282. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  283. for(i=0; i<vN; ++i)
  284. {
  285. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  286. r->vb[i] += vdbV[i];
  287. }
  288. if(0)
  289. {
  290. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  291. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  292. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  293. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  294. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  295. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  296. }
  297. } // di
  298. cmRptPrintf(rpt,"err:%f\n",err);
  299. if( cmStackIsValid(r->monH))
  300. {
  301. cmRbmMonitor_t monErr;
  302. monErr.trainErr = err;
  303. cmStackPush(r->monH,&monErr,1);
  304. }
  305. } // ei
  306. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  307. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  308. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  309. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  310. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  311. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  312. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  313. cmMemFree(m);
  314. }
  315. void cmRbmRealTrain(
  316. cmCtx_t* ctx,
  317. cmRBM_t* r,
  318. cmRbmTrainParms_t* p,
  319. unsigned dMN,
  320. const double* dM )
  321. {
  322. cmRpt_t* rpt = ctx->err.rpt;
  323. unsigned i,j,k,ei,di;
  324. unsigned vN = r->vN;
  325. unsigned hN = r->hN;
  326. // adjust the memory sizes to align all arrays on 16 byte boundaries
  327. unsigned vn = vN;
  328. unsigned hn = hN;
  329. unsigned dn = p->batchCnt;
  330. _cmRbmAdjustSizes(&vn,&hn,&dn);
  331. unsigned mn = (3 * hn * vn) + (1 * vn) + (1 * hn) + (3 * hn * dn) + (2 * vn * dn);
  332. double* m = cmMemAllocZ(double,mn);
  333. double* vh0M = m; // vh0M[ hN, vN ]
  334. double* vh1M = vh0M + hn*vn; // vh1M[ hN, vN ]
  335. double* dwM = vh1M + hn*vn; // dwM[ hN, vN ]
  336. double* vdbV = dwM + hn*vn; // vdbV[ vN ]
  337. double* hdbV = vdbV + vn; // hdbV[ hN ]
  338. double* hp0M = hdbV + hn; // hp0M[ hN, dN ]
  339. double* hs0M = hp0M + dn * hn; // hs0M[ dN, hN ]
  340. double* hp1M = hs0M + dn * hn; // hp1M[ hN, dN ]
  341. double* vp1M = hp1M + dn * hn; // vp1M[ dN, vN ]
  342. double* vs1M = vp1M + dn * vn; // vs1M[ vN, dN ]
  343. assert( vs1M + vn * dn == m + mn );
  344. //
  345. // Initilaize the weights with small random values
  346. // W = p->initW * randn(vN,hN,0.0,1.0)
  347. for(i=0; i<vN; ++i)
  348. cmVOD_RandomGauss( r->W + i*hN, hN, 0.0, 1.0 );
  349. cmVOD_MultVS( r->W, hN*vN, p->initW);
  350. if(0)
  351. {
  352. const cmChar_t* fn = "/home/kevin/temp/cmRbmWeight.mtx";
  353. //cmBinMtxFileWrite(fn,hN, vN,NULL,dM,NULL,ctx->err.rpt);
  354. cmBinMtxFileRead( ctx, fn, hN, vN, sizeof(double), r->W,NULL);
  355. }
  356. cmVOD_Zero( dwM, vN*hN );
  357. cmVOD_Zero( vdbV, vN );
  358. cmVOD_Zero( hdbV, hN );
  359. for(ei=0; ei<p->epochCnt; ++ei)
  360. {
  361. unsigned dN = 0;
  362. double err = 0;
  363. for(di=0; di<dMN; di+=dN)
  364. {
  365. dN = cmMin(p->batchCnt,dMN-di);
  366. const double* d = dM + di * vN; // d[ vN, dN ]
  367. //
  368. // Update hidden layer from data
  369. //
  370. // hp0M[hN,dN] = W[hN,vN] * d[vN,dN]
  371. cmVOD_MultMMM(hp0M,hN,dN,r->W,d,vN);
  372. // calc hs0M[dN,hN]
  373. for(k=0; k<dN; ++k)
  374. for(j=0; j<hN; ++j)
  375. {
  376. hp0M[ k*hN + j ] = 1.0/(1.0 + exp(-(hp0M[ k*hN + j] + r->hb[j])));
  377. hs0M[ j*dN + k ] = rand() < hp0M[ k*hN + j ] * RAND_MAX;
  378. }
  379. //
  380. // Reconstruct visible layer from hidden
  381. //
  382. // vp1M[dN,vN] = hs0M[dN,hN] * W[hN,vN]
  383. cmVOD_MultMMM(vp1M,dN,vN,hs0M,r->W,hN);
  384. // calc vs1M[vN,dN]
  385. for(k=0; k<dN; ++k)
  386. for(i=0; i<vN; ++i)
  387. {
  388. vp1M[ i*dN + k ] = r->vd[i] * vp1M[ i*dN + k ] + r->vb[i];
  389. cmVOD_GaussPDF(vs1M + k*vN + i, 1, vp1M + i*dN + k, r->vb[i], r->vd[i] );
  390. // calc training error
  391. err += pow(d[ k*vN + i ] - vp1M[ i*dN + k ],2.0);
  392. }
  393. //
  394. // Update hidden layer from reconstruction
  395. //
  396. // hp1M[hN,dN] = W[hN,vN] * vs1[vN,dN]
  397. cmVOD_MultMMM(hp1M,hN,dN,r->W,vs1M,vN);
  398. // calc hp1M[hN,dN]
  399. for(k=0; k<dN; ++k)
  400. for(j=0; j<hN; ++j)
  401. hp1M[ k*hN + j ] = 1.0/(1.0 + exp( -hp1M[ k*hN + j ] - r->hb[j] ));
  402. if(0)
  403. {
  404. cmVOD_PrintL("hp0M",rpt,hN,dN,hp0M);
  405. cmVOD_PrintL("hs0M",rpt,dN,hN,hs0M);
  406. cmVOD_PrintL("vp1M",rpt,dN,vN,vp1M);
  407. cmVOD_PrintL("vs1M",rpt,vN,dN,vs1M);
  408. cmVOD_PrintL("hp1M",rpt,hN,dN,hp1M);
  409. }
  410. //
  411. // Update Wieghts
  412. //
  413. // vh0M[hN,vN] = hp0M[hN,dN] * d[vN,dN]'
  414. cmVOD_MultMMMt(vh0M, hN, vN, hp0M, d, dN );
  415. cmVOD_MultMMMt(vh1M, hN, vN, hp1M, vs1M, dN );
  416. for(i=0,k=0; i<vN; ++i)
  417. for(j=0; j<hN; ++j,++k)
  418. {
  419. dwM[k] = p->momentum * dwM[k] + p->eta * ( (vh0M[k] - vh1M[k]) / (dN * r->vd[i]) );
  420. r->W[k] += dwM[k];
  421. }
  422. //
  423. // Update hidden bias
  424. //
  425. // sum(hp0M - hp1M,2) - sum the difference of rows of hp0M and hp1M
  426. cmVOD_SubVV(hp0M,hN*dN,hp1M); // hp0M -= hp1M
  427. cmVOD_SumMN(hp0M,hN,dN,hp1M); // hp1M[1:hN] = sum(hp0M,2) (note: hp1M is rused as temp space)
  428. for(j=0; j<hN; ++j)
  429. {
  430. hdbV[j] = p->momentum * hdbV[j] + p->eta * (hp1M[j] / (dN * r->vd[i] * r->vd[i]));
  431. r->hb[j] += hdbV[j];
  432. }
  433. //
  434. // Update visible bias
  435. //
  436. // sum(d - vs1M, 2)
  437. cmVOD_SubVVV(vp1M,vN*dN,d,vs1M); // vp1M = d - vs1M (vp1M is reused as temp space)
  438. cmVOD_SumMN(vp1M,vN,dN,vs1M); // vs1M[1:vn] = sum(vp1M,2) (vs1M is reused as temp space)
  439. for(i=0; i<vN; ++i)
  440. {
  441. vdbV[i] = p->momentum * vdbV[i] + p->eta * (vs1M[i] / dN );
  442. r->vb[i] += vdbV[i];
  443. }
  444. for(i=0; i<vN; ++i)
  445. {
  446. for(j=0; j<hN; ++j)
  447. {
  448. double sum_d = 0;
  449. double sum_m = 0;
  450. for(k=0; k<dN; ++k)
  451. {
  452. sum_d += hs0M[ j*dN + k ] * r->W[ i*hN + j ];
  453. sum_m += hp1M[ k*hN + j ] * r->W[ i*hN + j ];
  454. }
  455. }
  456. }
  457. if(0)
  458. {
  459. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  460. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  461. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  462. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  463. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  464. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  465. }
  466. } // di
  467. cmRptPrintf(rpt,"err:%f\n",err);
  468. if( cmStackIsValid(r->monH))
  469. {
  470. cmRbmMonitor_t monErr;
  471. monErr.trainErr = err;
  472. cmStackPush(r->monH,&monErr,1);
  473. }
  474. } // ei
  475. cmRptPrintf(rpt,"eta:%f momentum:%f\n",p->eta,p->momentum);
  476. cmVOD_PrintL("dwM", rpt, vN, hN, dwM );
  477. cmVOD_PrintL("vdbV",rpt, 1, vN, vdbV );
  478. cmVOD_PrintL("hdbV",rpt, 1, hN, hdbV );
  479. cmVOD_PrintL("W", rpt, vN, hN, r->W );
  480. cmVOD_PrintL("vb", rpt, 1, vN, r->vb );
  481. cmVOD_PrintL("hb", rpt, 1, hN, r->hb );
  482. cmMemFree(m);
  483. }
  484. void cmRbmBinaryTest( cmCtx_t* ctx )
  485. {
  486. const char* monitorFn = "/home/kevin/temp/cmRbmMonitor0.mtx";
  487. const char* dataFn = "/home/kevin/temp/cmRbmData0.mtx";
  488. unsigned pointsN = 1000;
  489. unsigned dimN = 4;
  490. unsigned vN = dimN;
  491. unsigned hN = 32;
  492. //double probV[] = {0.1,0.2,0.8,0.7};
  493. cmRbmTrainParms_t r;
  494. cmRBM_t* rbm;
  495. r.maxX = 1.0;
  496. r.minX = 0.0;
  497. r.initW = 0.1;
  498. r.eta = 0.01;
  499. r.holdOutFrac = 0.1;
  500. r.epochCnt = 10;
  501. r.momentum = 0.5;
  502. r.batchCnt = 10;
  503. if(0)
  504. {
  505. vN = 4;
  506. hN = 6;
  507. double d[] = {
  508. 0, 1, 1, 1,
  509. 1, 1, 1, 1,
  510. 0, 0, 1, 0,
  511. 0, 0, 1, 0,
  512. 0, 1, 1, 0,
  513. 0, 1, 1, 1,
  514. 1, 1, 1, 1,
  515. 0, 0, 1, 0,
  516. 0, 0, 1, 0,
  517. 0, 1, 1, 0
  518. };
  519. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  520. return;
  521. pointsN = sizeof(d) / (sizeof(d[0]) * vN);
  522. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,d);
  523. return;
  524. }
  525. if( (rbm = _cmRbmAlloc(ctx, vN, hN )) == NULL )
  526. return;
  527. //double* data0M = cmRbmGenBinaryTestData(ctx,dataFn,probV,dimN,pointsN);
  528. double* data0M = cmRbmReadDataFile(ctx,dataFn,&dimN,&pointsN);
  529. double t[ vN ];
  530. // Sum the columns of sp[srn,scn] into dp[scn].
  531. // dp[] is zeroed prior to computing the sum.
  532. cmVOD_SumMN(data0M, dimN, pointsN, t );
  533. cmVOD_Print( &ctx->rpt, 1, dimN, t );
  534. if(0)
  535. {
  536. //
  537. // Standardize data (subtract mean and divide by standard deviation)
  538. // then set the visible layers initial standard deviation to 1.0.
  539. //
  540. cmVOD_StandardizeRows( data0M, rbm->vN, pointsN, NULL, NULL );
  541. cmVOD_Fill( rbm->vd, rbm->vN, 1.0 );
  542. cmRbmRealTrain(ctx,rbm,&r,pointsN,data0M);
  543. }
  544. cmRbmBinaryTrain(ctx,rbm,&r,pointsN,data0M);
  545. cmRbmWriteMonitorFile(ctx, rbm->monH, monitorFn );
  546. cmMemFree(data0M);
  547. _cmRbmRelease(rbm);
  548. }