//| Copyright: (C) 2020-2024 Kevin Larke //| License: GNU GPL version 3.0 or above. See the accompanying LICENSE file. #include "cwCommon.h" #include "cwLog.h" #include "cwCommonImpl.h" #include "cwMem.h" #include "cwTest.h" #include "cwObject.h" #include "cwVectOps.h" #include "cwMtx.h" namespace cw { namespace mtx { bool _mtx_object_is_list( const object_t* cfg, unsigned& dimN ) { if( cfg->is_list() ) { dimN += 1; return true; } return false; } unsigned _mtx_object_get_degree( const object_t* cfg ) { unsigned dimN = 0; const object_t* o = cfg; while( _mtx_object_is_list(o,dimN) ) o = o->child_ele(0); return dimN; } rc_t _mtx_object_get_shape( const object_t* cfg, unsigned i, unsigned* dimV, unsigned dimN, unsigned& eleN ) { rc_t rc = kOkRC; if( !cfg->is_list() ) return kOkRC; dimV[i] = cfg->child_count(); eleN = eleN == 0 ? dimV[i] : eleN * dimV[i]; if((rc = _mtx_object_get_shape(cfg->child_ele(0), i+1, dimV, dimN, eleN )) != kOkRC ) return rc; if( cfg->child_ele(0)->is_list() ) { unsigned ch0 = cfg->child_ele(0)->child_count(); for(unsigned j=1; jchild_ele(j)->child_count()) return cwLogError(kSyntaxErrorRC,"A matrix contains an inconsistent dimension length on dimension index %i",i+1); } return rc; } unsigned _offsetMulV( const unsigned* mulV, unsigned dimN, unsigned* idxV ) { unsigned n = 0; for(unsigned i=0; i=0; --j) m *= dimV[j]; n += m; } return n; } } } cw::rc_t cw::mtx::test( const test::test_args_t& args ) { rc_t rc = kOkRC; const object_t* cfg = args.test_args; d_t* mtx0 = nullptr; d_t* mtx1 = nullptr; d_t* mtx2 = nullptr; d_t* mtx3 = nullptr; d_t* mtx4 = nullptr; d_t* mtx_y0 = nullptr; d_t* mtx_y1 = nullptr; d_t* mtx_y2 = nullptr; d_t* mtx_y3 = nullptr; d_t* mtx_y4 = nullptr; d_t* mtx_y5 = nullptr; d_t y; const object_t* m0 = cfg->find("m0"); if( m0 != nullptr ) mtx0 = allocCfg(m0); const object_t* m1 = cfg->find("m1"); if( m1 != nullptr ) mtx1 = allocCfg(m1); const object_t* m2 = cfg->find("m2"); if( m2 != nullptr ) mtx2 = allocCfg(m2); const object_t* m3 = cfg->find("m3"); if( m3 != nullptr ) mtx3 = allocCfg(m3); const object_t* m4 = cfg->find("m4"); if( m4 != nullptr ) mtx4 = allocCfg(m4); const object_t* y0 = cfg->find("y0"); if( y0 != nullptr ) mtx_y0 = allocCfg(y0); const object_t* y1 = cfg->find("y1"); if( y1 != nullptr ) mtx_y1 = allocCfg(y1); unsigned n = offset(*mtx1,1,1); cwLogPrint("offset: %i\n",n); report(*mtx0,"m0"); report(*mtx1,"m1"); report(*mtx2,"m2"); report(*mtx3,"m3"); report(*mtx4,"m4"); report(*mtx_y0,"y0"); report(*mtx_y1,"y1"); if( mtx_mul(y,*mtx1,*mtx0) == kOkRC ) { report(y,"y0"); if( !is_equal(*mtx_y0,y) ) rc = cwLogError(kTestFailRC,"Test 0 fail."); } transpose(*mtx0); transpose(*mtx1); if( mtx_mul(y,*mtx1,*mtx0) == kOkRC ) { report(y,"y1"); if( !is_equal(*mtx_y1,y) ) rc = cwLogError(kTestFailRC,"Test 1 fail."); } transpose(*mtx0); report(*mtx0,"m0"); mtx_y2 = join(0,*mtx0,*mtx4); if( mtx_y2 != nullptr ) report(*mtx_y2,"y2"); report(*mtx0,"m0"); report(*mtx4,"m4"); mtx_y3 = join(1,*mtx0,*mtx4); if( mtx_y3 != nullptr ) report(*mtx_y3,"y3"); mtx_y4 = slice_alias(*mtx_y3,0,0,1); if( mtx_y4 != nullptr ) { report(*mtx_y4,"y4 - slice"); ele(*mtx_y4,2) = 1; ele(*mtx_y4,3) = 2; report(*mtx_y4,"y4 - mod"); mtx_y5 = alloc_one_hot(*mtx_y4); if( mtx_y5 != nullptr ) report(*mtx_y5,"y5 -(one_hot(y4))"); } release(mtx0); release(mtx1); release(mtx2); release(mtx3); release(mtx4); release(mtx_y0); release(mtx_y1); release(mtx_y2); release(mtx_y3); release(mtx_y4); release(mtx_y5); release(y); return rc; }