diff --git a/cwDataSets.cpp b/cwDataSets.cpp new file mode 100644 index 0000000..373de70 --- /dev/null +++ b/cwDataSets.cpp @@ -0,0 +1,236 @@ +#include "cwCommon.h" +#include "cwLog.h" +#include "cwCommonImpl.h" +#include "cwMem.h" +#include "cwFile.h" +#include "cwFileSys.h" +#include "cwMtx.h" +#include "cwDataSets.h" +#include "cwSvg.h" + + +namespace cw +{ + namespace dataset + { + namespace mnist + { + typedef struct mnist_str + { + mtx::fmtx_t* train = nullptr; + mtx::fmtx_t* valid = nullptr; + mtx::fmtx_t* test = nullptr; + + } mnist_t; + + inline mnist_t* _handleToPtr(handle_t h ) + { return handleToPtr(h); } + + rc_t _destroy( mnist_t* p ) + { + rc_t rc = kOkRC; + + mtx::release(p->train); + mtx::release(p->valid); + mtx::release(p->test); + mem::release(p); + return rc; + } + + rc_t _read_file( const char* dir, const char* fn, mtx::fmtx_t*& m ) + { + rc_t rc = kOkRC; + file::handle_t fH; + unsigned exampleN = 0; + const unsigned kPixN = 784; + const unsigned kRowN = kPixN+1; + unsigned dimV[] = {kRowN,0}; + const unsigned dimN = sizeof(dimV)/sizeof(dimV[0]); + float* v = nullptr; + char* path = filesys::makeFn(dir, fn, ".bin", NULL ); + + // open the file + if((rc = file::open(fH,path, file::kReadFl | file::kBinaryFl )) != kOkRC ) + { + rc = cwLogError(rc,"MNIST file open failed on '%s'.",cwStringNullGuard(path)); + goto errLabel; + } + + // read the count of examples + if((rc = readUInt(fH,&exampleN)) != kOkRC ) + { + rc = cwLogError(rc,"Unable to read MNIST example count."); + goto errLabel; + } + + // allocate the data memory + v = mem::alloc( kRowN * exampleN ); + + // read each example + for(unsigned i=0,j=0; i( dimN, dimV, v, mtx::kAliasReleaseFl ); + + errLabel: + file::close(fH); + mem::release(path); + return rc; + } + } + } +} + + +cw::rc_t cw::dataset::mnist::create( handle_t& h, const char* dir ) +{ + rc_t rc; + mnist_t* p = nullptr; + + if((rc = destroy(h)) != kOkRC ) + return rc; + + p = mem::allocZ(1); + + // read the training data + if((rc = _read_file( dir, "mnist_train", p->train )) != kOkRC ) + { + rc = cwLogError(rc,"MNIST training set load failed."); + goto errLabel; + } + + // read the validation data + if((rc = _read_file( dir, "mnist_valid", p->valid )) != kOkRC ) + { + rc = cwLogError(rc,"MNIST validation set load failed."); + goto errLabel; + } + + // read the testing data + if((rc = _read_file( dir, "mnist_test", p->test )) != kOkRC ) + { + rc = cwLogError(rc,"MNIST test set load failed."); + goto errLabel; + } + + h.set(p); + + errLabel: + if( rc != kOkRC ) + _destroy(p); + + return rc; +} + +cw::rc_t cw::dataset::mnist::destroy( handle_t& h ) +{ + rc_t rc = kOkRC; + if( !h.isValid()) + return rc; + + mnist_t* p = _handleToPtr(h); + + if((rc = _destroy(p)) != kOkRC ) + return rc; + + h.clear(); + + return rc; +} + +const cw::mtx::fmtx_t* cw::dataset::mnist::train( handle_t h ) +{ + mnist_t* p = _handleToPtr(h); + return p->train; +} + +const cw::mtx::fmtx_t* cw::dataset::mnist::validate( handle_t h ) +{ + mnist_t* p = _handleToPtr(h); + return p->valid; +} + +const cw::mtx::fmtx_t* cw::dataset::mnist::test( handle_t h ) +{ + mnist_t* p = _handleToPtr(h); + return p->test; +} + + + +cw::rc_t cw::dataset::mnist::test( const char* dir, const char* imageFn ) +{ + rc_t rc = kOkRC; + handle_t h; + if((rc = create(h, dir )) == kOkRC ) + { + svg::handle_t svgH; + + if((rc = svg::create(svgH)) != kOkRC ) + rc = cwLogError(rc,"SVG Test failed on create."); + else + { + const mtx::fmtx_t* m = train(h); + /* + unsigned zn = 0; + unsigned i = 1; + for(; idimV[1]; ++i) + { + const float* v0 = m->base + (28*28+1) * (i-1) + 1; + const float* v1 = m->base + (28*28+1) * (i-0) + 1; + float d = 0; + + for(unsigned j=0; j<28*28; ++j) + d += fabs(v0[j]-v1[j]); + + if( d==0 ) + ++zn; + else + { + printf("%i %i %f\n",i,zn,d); + zn = 0; + } + } + + printf("i:%i n:%i zn:%i\n",i,m->dimV[1],zn); + */ + + for(unsigned i=0; i<10; ++i) + { + svg::offset(svgH, 0, i*30*5 ); + svg::image(svgH, m->base + (28*28+1)*i, 28, 28, 5, svg::kInvGrayScaleColorMapId); + } + + svg::write(svgH, imageFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10); + + + svg::destroy(svgH); + } + + rc = destroy(h); + } + + return rc; +} + + diff --git a/cwDataSets.h b/cwDataSets.h new file mode 100644 index 0000000..8cf27db --- /dev/null +++ b/cwDataSets.h @@ -0,0 +1,32 @@ +#ifndef cwDataSets_h +#define cwDataSets_h + + +namespace cw +{ + namespace dataset + { + namespace mnist + { + typedef handle handle_t; + + rc_t create( handle_t& h, const char* dir ); + rc_t destroy( handle_t& h ); + + // Each column has one example. + // The top row contains the labels. + const mtx::fmtx_t* train( handle_t h ); + const mtx::fmtx_t* validate( handle_t h ); + const mtx::fmtx_t* test( handle_t h ); + + rc_t test(const char* dir, const char* imageFn ); + + + } + } + + +} + + +#endif