cwDataSets.h/cpp : Initial commit. MNIST dataset implementation.

This commit is contained in:
kevin 2020-08-19 20:10:35 -04:00
parent ded6a1ef4a
commit 11bad66e54
2 changed files with 268 additions and 0 deletions

236
cwDataSets.cpp Normal file
View File

@ -0,0 +1,236 @@
#include "cwCommon.h"
#include "cwLog.h"
#include "cwCommonImpl.h"
#include "cwMem.h"
#include "cwFile.h"
#include "cwFileSys.h"
#include "cwMtx.h"
#include "cwDataSets.h"
#include "cwSvg.h"
namespace cw
{
namespace dataset
{
namespace mnist
{
// Backing record for an MNIST handle: one matrix per data partition.
// A null pointer means the partition has not been loaded.
struct mnist_str
{
  mtx::fmtx_t* train = nullptr;  // training examples
  mtx::fmtx_t* valid = nullptr;  // validation examples
  mtx::fmtx_t* test  = nullptr;  // testing examples
};

typedef struct mnist_str mnist_t;
// Convert an MNIST handle into a pointer to the record it refers to.
inline mnist_t* _handleToPtr(handle_t h )
{
  return handleToPtr<handle_t,mnist_t>(h);
}
// Release the three data matrices held by 'p' and then the record itself.
// Always returns kOkRC.
rc_t _destroy( mnist_t* p )
{
  // matrices first, then the record which holds the pointers to them
  mtx::release(p->train);
  mtx::release(p->valid);
  mtx::release(p->test);

  mem::release(p);

  return kOkRC;
}
// Read one MNIST binary file into a newly allocated matrix.
//
// The file is read as: an unsigned example count followed, for each example,
// by an unsigned digit label and 784 float pixel values.
//
// dir  Directory containing the file.
// fn   File name without extension; ".bin" is passed to filesys::makeFn() as the extension.
// m    Receives the new matrix on success. Its dimensions are {785, exampleN}:
//      element 0 of each example is the label and the following 784 are the pixels.
//      'm' is left untouched on failure.
//
// Returns kOkRC on success or the logged error code of the first failure.
rc_t _read_file( const char* dir, const char* fn, mtx::fmtx_t*& m )
{
  rc_t           rc       = kOkRC;
  file::handle_t fH;
  unsigned       exampleN = 0;
  const unsigned kPixN    = 784;       // pixel values per image
  const unsigned kRowN    = kPixN+1;   // label + pixels
  unsigned       dimV[]   = {kRowN,0};
  const unsigned dimN     = sizeof(dimV)/sizeof(dimV[0]);
  float*         v        = nullptr;
  char*          path     = filesys::makeFn(dir, fn, ".bin", NULL );

  // open the file
  if((rc = file::open(fH,path, file::kReadFl | file::kBinaryFl )) != kOkRC )
  {
    rc = cwLogError(rc,"MNIST file open failed on '%s'.",cwStringNullGuard(path));
    goto errLabel;
  }

  // read the count of examples
  if((rc = readUInt(fH,&exampleN)) != kOkRC )
  {
    rc = cwLogError(rc,"Unable to read MNIST example count.");
    goto errLabel;
  }

  // allocate the data memory
  v = mem::alloc<float>( kRowN * exampleN );

  // read each example
  for(unsigned i=0,j=0; i<exampleN; ++i,j+=kRowN)
  {
    unsigned digitLabel = 0;

    // read the digit image label
    if((rc = readUInt(fH,&digitLabel)) != kOkRC )
    {
      rc = cwLogError(rc,"Unable to read MNIST label on example %i.",i);
      goto errLabel;
    }

    // the label is stored as the first value of the example
    v[j] = digitLabel;

    // read the image pixels
    if((rc = readFloat(fH,v+j+1,kPixN)) != kOkRC )
    {
      rc = cwLogError(rc,"Unable to read MNIST data vector on example %i.",i);
      goto errLabel;
    }
  }

  dimV[1] = exampleN;

  // NOTE(review): kAliasReleaseFl presumably makes the matrix take ownership of 'v' - confirm in cwMtx.h.
  m = mtx::alloc<float>( dimN, dimV, v, mtx::kAliasReleaseFl );

 errLabel:

  // On failure 'v' was never handed to a matrix, so it must be released here
  // or it would leak.
  if( rc != kOkRC && v != nullptr )
    mem::release(v);

  file::close(fH);
  mem::release(path);
  return rc;
}
}
}
}
// Create an MNIST data set object by loading the training, validation and
// test files ("mnist_train", "mnist_valid", "mnist_test") from 'dir'.
//
// Any object already referenced by 'h' is destroyed first.
// On failure the partially built object is released and 'h' is not set.
cw::rc_t cw::dataset::mnist::create( handle_t& h, const char* dir )
{
  rc_t rc = destroy(h);

  if( rc != kOkRC )
    return rc;

  mnist_t* rec = mem::allocZ<mnist_t>(1);

  // Load the three partitions in order, stopping at the first failure.
  if((rc = _read_file( dir, "mnist_train", rec->train )) != kOkRC )
    rc = cwLogError(rc,"MNIST training set load failed.");
  else if((rc = _read_file( dir, "mnist_valid", rec->valid )) != kOkRC )
    rc = cwLogError(rc,"MNIST validation set load failed.");
  else if((rc = _read_file( dir, "mnist_test", rec->test )) != kOkRC )
    rc = cwLogError(rc,"MNIST test set load failed.");
  else
    h.set(rec);

  // Release everything that was loaded before the failure.
  if( rc != kOkRC )
    _destroy(rec);

  return rc;
}
// Release the data set referenced by 'h' and clear the handle.
// An invalid handle is not an error: the call does nothing and returns kOkRC.
// The handle is only cleared when the release succeeds.
cw::rc_t cw::dataset::mnist::destroy( handle_t& h )
{
  rc_t rc = kOkRC;

  if( h.isValid() )
  {
    mnist_t* rec = _handleToPtr(h);

    if((rc = _destroy(rec)) == kOkRC )
      h.clear();
  }

  return rc;
}
// Return the training matrix held by the data set 'h'.
const cw::mtx::fmtx_t* cw::dataset::mnist::train( handle_t h )
{
  return _handleToPtr(h)->train;
}
// Return the validation matrix held by the data set 'h'.
const cw::mtx::fmtx_t* cw::dataset::mnist::validate( handle_t h )
{
  return _handleToPtr(h)->valid;
}
// Return the test matrix held by the data set 'h'.
const cw::mtx::fmtx_t* cw::dataset::mnist::test( handle_t h )
{
  return _handleToPtr(h)->test;
}
// Self test: load the MNIST data set from 'dir' and write the first few
// training images (at most 10) as a vertical strip to the SVG file 'imageFn'.
//
// Returns the first error encountered, or kOkRC if the data set was loaded,
// the SVG object was created and the data set was released without error.
cw::rc_t cw::dataset::mnist::test( const char* dir, const char* imageFn )
{
  rc_t     rc = kOkRC;
  handle_t h;

  if((rc = create(h, dir )) == kOkRC )
  {
    svg::handle_t svgH;

    if((rc = svg::create(svgH)) != kOkRC )
      rc = cwLogError(rc,"SVG Test failed on create.");
    else
    {
      const unsigned kImgSideN = 28;                        // images are 28x28 pixels
      const unsigned kExampleN = kImgSideN*kImgSideN + 1;   // floats per example: label + pixels
      const unsigned kPixSize  = 5;                         // drawn size of one image pixel
      const unsigned kCellN    = 30;                        // vertical spacing between images, in image pixels
      const unsigned kMaxImgN  = 10;                        // max. count of images to draw

      const mtx::fmtx_t* m = train(h);

      // Never read past the last example in the matrix.
      const unsigned imgN = m->dimV[1] < kMaxImgN ? m->dimV[1] : kMaxImgN;

      /*
        Disabled diagnostic: reports runs of consecutive examples whose
        images are identical.

        unsigned zn = 0;
        unsigned i = 1;
        for(; i<m->dimV[1]; ++i)
        {
          const float* v0 = m->base + (28*28+1) * (i-1) + 1;
          const float* v1 = m->base + (28*28+1) * (i-0) + 1;
          float d = 0;
          for(unsigned j=0; j<28*28; ++j)
            d += fabs(v0[j]-v1[j]);
          if( d==0 )
            ++zn;
          else
          {
            printf("%i %i %f\n",i,zn,d);
            zn = 0;
          }
        }
        printf("i:%i n:%i zn:%i\n",i,m->dimV[1],zn);
      */

      for(unsigned i=0; i<imgN; ++i)
      {
        svg::offset(svgH, 0, i*kCellN*kPixSize );

        // The first float of each example is the digit label,
        // so the image data begins one element later.
        svg::image(svgH, m->base + kExampleN*i + 1, kImgSideN, kImgSideN, kPixSize, svg::kInvGrayScaleColorMapId);
      }

      svg::write(svgH, imageFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10);
      svg::destroy(svgH);
    }

    // Always release the data set, but do not let a successful release
    // hide an earlier failure.
    rc_t destroyRC = destroy(h);
    if( rc == kOkRC )
      rc = destroyRC;
  }

  return rc;
}

32
cwDataSets.h Normal file
View File

@ -0,0 +1,32 @@
#ifndef cwDataSets_h
#define cwDataSets_h
namespace cw
{
namespace dataset
{
namespace mnist
{
// Handle to an MNIST data set object.
typedef handle<struct mnist_str> handle_t;
// Load the training, validation and test sets from the files
// "mnist_train", "mnist_valid" and "mnist_test" (extension ".bin") in 'dir'.
// Any data set already referenced by 'h' is destroyed first.
// On failure the partially loaded data is released and 'h' is not set.
rc_t create( handle_t& h, const char* dir );
// Release the data set referenced by 'h' and clear the handle.
// Returns kOkRC without doing anything if 'h' is not valid.
rc_t destroy( handle_t& h );
// Each column has one example.
// The top row contains the labels.
// Each matrix therefore has 785 rows (1 label + 784 pixel values)
// and one column per example.
// The returned matrices remain owned by the data set object;
// they are released by destroy().
const mtx::fmtx_t* train( handle_t h );
const mtx::fmtx_t* validate( handle_t h );
const mtx::fmtx_t* test( handle_t h );
// Self test: load the data set from 'dir' and write images taken from the
// training set to the SVG file 'imageFn'.
rc_t test(const char* dir, const char* imageFn );
}
}
}
#endif