cwDataSets.h/cpp : Initial commit. MNIST dataset implementation.
This commit is contained in:
parent
ded6a1ef4a
commit
11bad66e54
236
cwDataSets.cpp
Normal file
236
cwDataSets.cpp
Normal file
@ -0,0 +1,236 @@
|
||||
#include "cwCommon.h"
|
||||
#include "cwLog.h"
|
||||
#include "cwCommonImpl.h"
|
||||
#include "cwMem.h"
|
||||
#include "cwFile.h"
|
||||
#include "cwFileSys.h"
|
||||
#include "cwMtx.h"
|
||||
#include "cwDataSets.h"
|
||||
#include "cwSvg.h"
|
||||
|
||||
|
||||
namespace cw
|
||||
{
|
||||
namespace dataset
|
||||
{
|
||||
namespace mnist
|
||||
{
|
||||
typedef struct mnist_str
|
||||
{
|
||||
mtx::fmtx_t* train = nullptr;
|
||||
mtx::fmtx_t* valid = nullptr;
|
||||
mtx::fmtx_t* test = nullptr;
|
||||
|
||||
} mnist_t;
|
||||
|
||||
inline mnist_t* _handleToPtr(handle_t h )
|
||||
{ return handleToPtr<handle_t,mnist_t>(h); }
|
||||
|
||||
rc_t _destroy( mnist_t* p )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
mtx::release(p->train);
|
||||
mtx::release(p->valid);
|
||||
mtx::release(p->test);
|
||||
mem::release(p);
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t _read_file( const char* dir, const char* fn, mtx::fmtx_t*& m )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
file::handle_t fH;
|
||||
unsigned exampleN = 0;
|
||||
const unsigned kPixN = 784;
|
||||
const unsigned kRowN = kPixN+1;
|
||||
unsigned dimV[] = {kRowN,0};
|
||||
const unsigned dimN = sizeof(dimV)/sizeof(dimV[0]);
|
||||
float* v = nullptr;
|
||||
char* path = filesys::makeFn(dir, fn, ".bin", NULL );
|
||||
|
||||
// open the file
|
||||
if((rc = file::open(fH,path, file::kReadFl | file::kBinaryFl )) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"MNIST file open failed on '%s'.",cwStringNullGuard(path));
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
// read the count of examples
|
||||
if((rc = readUInt(fH,&exampleN)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"Unable to read MNIST example count.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
// allocate the data memory
|
||||
v = mem::alloc<float>( kRowN * exampleN );
|
||||
|
||||
// read each example
|
||||
for(unsigned i=0,j=0; i<exampleN; ++i,j+=kRowN)
|
||||
{
|
||||
unsigned digitLabel;
|
||||
|
||||
// read the digit image label
|
||||
if((rc = readUInt(fH,&digitLabel)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"Unable to read MNIST label on example %i.",i);
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
v[j] = digitLabel;
|
||||
|
||||
// read the image pixels
|
||||
if((rc = readFloat(fH,v+j+1,kPixN)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"Unable to read MNIST data vector on example %i.",i);
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
dimV[1] = exampleN;
|
||||
m = mtx::alloc<float>( dimN, dimV, v, mtx::kAliasReleaseFl );
|
||||
|
||||
errLabel:
|
||||
file::close(fH);
|
||||
mem::release(path);
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
cw::rc_t cw::dataset::mnist::create( handle_t& h, const char* dir )
|
||||
{
|
||||
rc_t rc;
|
||||
mnist_t* p = nullptr;
|
||||
|
||||
if((rc = destroy(h)) != kOkRC )
|
||||
return rc;
|
||||
|
||||
p = mem::allocZ<mnist_t>(1);
|
||||
|
||||
// read the training data
|
||||
if((rc = _read_file( dir, "mnist_train", p->train )) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"MNIST training set load failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
// read the validation data
|
||||
if((rc = _read_file( dir, "mnist_valid", p->valid )) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"MNIST validation set load failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
// read the testing data
|
||||
if((rc = _read_file( dir, "mnist_test", p->test )) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"MNIST test set load failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
h.set(p);
|
||||
|
||||
errLabel:
|
||||
if( rc != kOkRC )
|
||||
_destroy(p);
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
cw::rc_t cw::dataset::mnist::destroy( handle_t& h )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
if( !h.isValid())
|
||||
return rc;
|
||||
|
||||
mnist_t* p = _handleToPtr(h);
|
||||
|
||||
if((rc = _destroy(p)) != kOkRC )
|
||||
return rc;
|
||||
|
||||
h.clear();
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
const cw::mtx::fmtx_t* cw::dataset::mnist::train( handle_t h )
|
||||
{
|
||||
mnist_t* p = _handleToPtr(h);
|
||||
return p->train;
|
||||
}
|
||||
|
||||
const cw::mtx::fmtx_t* cw::dataset::mnist::validate( handle_t h )
|
||||
{
|
||||
mnist_t* p = _handleToPtr(h);
|
||||
return p->valid;
|
||||
}
|
||||
|
||||
const cw::mtx::fmtx_t* cw::dataset::mnist::test( handle_t h )
|
||||
{
|
||||
mnist_t* p = _handleToPtr(h);
|
||||
return p->test;
|
||||
}
|
||||
|
||||
|
||||
|
||||
cw::rc_t cw::dataset::mnist::test( const char* dir, const char* imageFn )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
handle_t h;
|
||||
if((rc = create(h, dir )) == kOkRC )
|
||||
{
|
||||
svg::handle_t svgH;
|
||||
|
||||
if((rc = svg::create(svgH)) != kOkRC )
|
||||
rc = cwLogError(rc,"SVG Test failed on create.");
|
||||
else
|
||||
{
|
||||
const mtx::fmtx_t* m = train(h);
|
||||
/*
|
||||
unsigned zn = 0;
|
||||
unsigned i = 1;
|
||||
for(; i<m->dimV[1]; ++i)
|
||||
{
|
||||
const float* v0 = m->base + (28*28+1) * (i-1) + 1;
|
||||
const float* v1 = m->base + (28*28+1) * (i-0) + 1;
|
||||
float d = 0;
|
||||
|
||||
for(unsigned j=0; j<28*28; ++j)
|
||||
d += fabs(v0[j]-v1[j]);
|
||||
|
||||
if( d==0 )
|
||||
++zn;
|
||||
else
|
||||
{
|
||||
printf("%i %i %f\n",i,zn,d);
|
||||
zn = 0;
|
||||
}
|
||||
}
|
||||
|
||||
printf("i:%i n:%i zn:%i\n",i,m->dimV[1],zn);
|
||||
*/
|
||||
|
||||
for(unsigned i=0; i<10; ++i)
|
||||
{
|
||||
svg::offset(svgH, 0, i*30*5 );
|
||||
svg::image(svgH, m->base + (28*28+1)*i, 28, 28, 5, svg::kInvGrayScaleColorMapId);
|
||||
}
|
||||
|
||||
svg::write(svgH, imageFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10);
|
||||
|
||||
|
||||
svg::destroy(svgH);
|
||||
}
|
||||
|
||||
rc = destroy(h);
|
||||
}
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
|
32
cwDataSets.h
Normal file
32
cwDataSets.h
Normal file
@ -0,0 +1,32 @@
|
||||
#ifndef cwDataSets_h
|
||||
#define cwDataSets_h
|
||||
|
||||
|
||||
namespace cw
|
||||
{
|
||||
namespace dataset
|
||||
{
|
||||
namespace mnist
|
||||
{
|
||||
typedef handle<struct mnist_str> handle_t;
|
||||
|
||||
rc_t create( handle_t& h, const char* dir );
|
||||
rc_t destroy( handle_t& h );
|
||||
|
||||
// Each column has one example.
|
||||
// The top row contains the labels.
|
||||
const mtx::fmtx_t* train( handle_t h );
|
||||
const mtx::fmtx_t* validate( handle_t h );
|
||||
const mtx::fmtx_t* test( handle_t h );
|
||||
|
||||
rc_t test(const char* dir, const char* imageFn );
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user