cwDataSets.h/cpp : Add use of cache_t file cache to rdr to improve file read performance and add shuffling option.

This commit is contained in:
kevin 2020-12-29 11:22:29 -05:00
parent 46a7633e00
commit cbf4870410
2 changed files with 388 additions and 185 deletions

View File

@ -12,7 +12,7 @@
#include "cwSvg.h"
#include "cwTime.h"
#include "cwText.h"
#include "cwMath.h"
//----------------------------------------------------------------------------------------------------------------------------
//----------------------------------------------------------------------------------------------------------------------------
@ -26,7 +26,7 @@ namespace cw
{
typedef struct col_str
{
rdr::col_t col; // Public fields.
rdr::col_t col; // Public fields - See rdr::col_t.
unsigned char* cur; // Cache of the current column data contents.
unsigned curByteN; // Count of bytes in cur[].
unsigned* curDimV; // Cache of the current column dimensions.
@ -159,17 +159,17 @@ namespace cw
++col_count;
if((rc = file::write( p->fH, p->record_count )) != kOkRC ) goto errLabel;
if((rc = file::write( p->fH, col_count )) != kOkRC ) goto errLabel;
if((rc = file::write( p->fH, col_count )) != kOkRC ) goto errLabel;
for(c=p->colL; c!=nullptr; c=c->link)
{
if((rc = file::writeStr( p->fH, c->col.label )) != kOkRC ) goto errLabel;
if((rc = file::writeStr( p->fH, c->col.label )) != kOkRC ) goto errLabel;
if((rc = file::write( p->fH, c->col.id )) != kOkRC ) goto errLabel;
if((rc = file::write( p->fH, c->col.varDimN )) != kOkRC ) goto errLabel;
if((rc = file::write( p->fH, c->col.rankN )) != kOkRC ) goto errLabel;
if((rc = file::write( p->fH, c->col.maxEleN )) != kOkRC ) goto errLabel;
if((rc = variant::write( p->fH, c->col.max)) != kOkRC ) goto errLabel;
if((rc = variant::write( p->fH, c->col.min )) != kOkRC ) goto errLabel;
if((rc = variant::write( p->fH, c->col.max)) != kOkRC ) goto errLabel;
if((rc = variant::write( p->fH, c->col.min )) != kOkRC ) goto errLabel;
for(unsigned i=0; i<c->col.rankN; ++i)
{
@ -640,6 +640,265 @@ namespace cw
kSizeofRecordHeader = sizeof(unsigned)
};
typedef struct cache_str
{
file::handle_t fH;
unsigned totalRecdN; // Total count of records in the file
std::uint8_t * buf; // File buffer memory
unsigned bufMaxByteN; // Allocated size of buf[]
unsigned bufByteN; // Bytes in buf[]
unsigned baseFileOffs; // Offset of the first record in the file
unsigned* tocV; // tocV[tocN] Cached record byte offsets
unsigned tocN; // Count of records in the cache
unsigned tocBaseIdx; // Record index of the first record in the cache
unsigned tocIdx; // Record index of next record to return
unsigned state; // See rdr::k???State
bool shuffleFl; // shuffle the file buffer each time it is filled
} cache_t;
// Backup the file position to the beginning of the last (partial) record in the cache.
// Note that the last record overlaps the end of the cache is is therefore incomplete.
rc_t _cache_backup( cache_t* p, unsigned actByteN, unsigned cacheByteN )
{
rc_t rc = kOkRC;
if( p->state == kEofRC )
return kEofRC;
assert( actByteN >= cacheByteN);
if((rc = file::seek(p->fH,file::kCurFl, -(int)(actByteN-cacheByteN) )) != kOkRC )
return cwLogError(rc,"Dataset rdr cache align failed.");
return rc;
}
// Count the records in the case and re-align the current file position to the last (partial) record in the cache
rc_t _cache_count_and_align( cache_t* p, unsigned actByteN )
{
p->bufByteN = 0;
for(p->tocN=0; p->bufByteN < actByteN; ++p->tocN )
{
if( p->bufByteN + kSizeofRecordHeader >= actByteN )
{
_cache_backup( p, actByteN, p->bufByteN);
break;
}
unsigned recdByteN = *reinterpret_cast<unsigned*>(p->buf + p->bufByteN);
// TODO: handle case where the whole buffer has less than one record
if( p->tocN==0 && actByteN < kSizeofRecordHeader + recdByteN )
{
assert(0);
}
if( p->bufByteN + recdByteN > actByteN )
{
_cache_backup( p, actByteN, p->bufByteN);
break;
}
p->bufByteN += kSizeofRecordHeader + recdByteN;
}
return kOkRC;
}
void _cache_shuffle_toc( cache_t* p )
{
// for each record address in tocV[]
for(unsigned i=0; i<p->tocN; ++i)
{
// generate a random index into tocV[]
unsigned idx = math::randUInt(0,p->tocN-1);
// swap location i with a random location
unsigned tmp = p->tocV[i];
p->tocV[i] = p->tocV[idx];
p->tocV[idx] = tmp;
}
}
void _cache_fill_toc( cache_t* p )
{
unsigned cacheByteOffs = 0;
for(unsigned i=0; i<p->tocN; ++i)
{
p->tocV[i] = cacheByteOffs;
unsigned recdByteN = *reinterpret_cast<unsigned*>(p->buf + cacheByteOffs);
cacheByteOffs += kSizeofRecordHeader + recdByteN;
}
}
rc_t _cache_fill( cache_t* p )
{
rc_t rc = kOkRC;
unsigned actByteN = 0;
// Note that his function is always called when the file is pointing to the
// record length at the start of a record
// Fill the cache with as much data as possible from the file
if((rc = file::read(p->fH, p->buf, p->bufMaxByteN, &actByteN)) != kOkRC )
{
if(rc == kEofRC)
p->state = kEofState;
else
return cwLogError(rc,"dataset rdr cache fill failed.");
}
// Get a count of the records in the cache (p->tocN) and adjust the file position such that
// it is left pointing to the beginning of the first record after the cache.
if((rc = _cache_count_and_align(p,actByteN)) != kOkRC )
return rc;
// alllocate the TOC
p->tocV = mem::resize<unsigned>(p->tocV,p->tocN);
// Fill the p->tocV[]
_cache_fill_toc(p);
if( p->shuffleFl)
_cache_shuffle_toc(p);
return rc;
}
rc_t _cache_rewind( cache_t* p )
{
rc_t rc;
// rewind the file to the beginning of the
if(( rc = file::seek(p->fH,file::kBeginFl,p->baseFileOffs)) != kOkRC )
{
rc = cwLogError(rc,"rdr cache file seek failed.");
goto errLabel;
}
if((rc = _cache_fill(p)) != kOkRC )
goto errLabel;
p->tocBaseIdx = 0;
p->tocIdx = 0;
errLabel:
return rc;
}
rc_t _cache_advance( cache_t* p )
{
rc_t rc = kOkRC;
unsigned n = p->tocN;
if((rc = _cache_fill(p)) != kOkRC )
goto errLabel;
p->tocBaseIdx += n;
errLabel:
return rc;
}
rc_t cache_setup( cache_t* p, file::handle_t fH, unsigned bufMaxByteN, unsigned baseFileOffs, unsigned totalRecordN, bool shuffleFl )
{
rc_t rc = kOkRC;
p->fH = fH;
p->buf = mem::alloc<uint8_t>( bufMaxByteN );
p->bufMaxByteN = bufMaxByteN;
p->bufByteN = 0;
p->baseFileOffs = baseFileOffs;
p->state = kOkState;
p->totalRecdN = totalRecordN;
p->shuffleFl = shuffleFl;
rc = _cache_rewind(p);
return rc;
}
rc_t cache_close( cache_t* p )
{
mem::release(p->tocV);
mem::release(p->buf);
return kOkRC;
}
rc_t cache_read( cache_t* p, const std::uint8_t*& recdRef, unsigned& recdByteN )
{
rc_t rc = kOkRC;
unsigned tocIdx;
if( p->tocIdx == p->totalRecdN )
{
rc = kEofRC;
p->state = kEofState;
goto errLabel;
}
if( p->tocIdx == p->tocBaseIdx + p->tocN )
if((rc = _cache_advance(p)) != kOkRC )
goto errLabel;
tocIdx = p->tocIdx - p->tocBaseIdx;
recdByteN = *reinterpret_cast<unsigned*>(p->buf + p->tocV[ tocIdx ]);
recdRef = p->buf + (kSizeofRecordHeader + p->tocV[ tocIdx ]);
p->tocIdx += 1;
errLabel:
return rc;
}
rc_t cache_seek( cache_t* p, unsigned recordIdx )
{
rc_t rc = kOkRC;
if( recordIdx >= p->totalRecdN )
return cwLogError(kSeekFailRC,"rdr cache seek index %i greater than last index: %i.",recordIdx,p->totalRecdN-1);
// if the requested record index is inside the cache
if( p->tocBaseIdx <= recordIdx && recordIdx < p->tocBaseIdx + p->tocN )
p->tocIdx = recordIdx;
else
{
// if the requested record index is prior to the cache
if( recordIdx < p->tocBaseIdx )
if((rc = _cache_rewind(p)) != kOkRC )
goto errLabel;
// recordIdx now must be past the beginning of the cache
assert( recordIdx >= p->tocBaseIdx );
// advance the cache until recordIdx is inside of it
while( recordIdx >= p->tocBaseIdx + p->tocN )
{
if((rc = _cache_advance(p)) != kOkRC )
goto errLabel;
}
assert( p->tocBaseIdx <= recordIdx && recordIdx < p->tocBaseIdx + p->tocN );
p->tocIdx = recordIdx;
}
errLabel:
return rc;
}
typedef struct
{
col_t col; // Public record
@ -653,15 +912,18 @@ namespace cw
unsigned column_count; // Count of elements in colA[].
unsigned record_count; // Count of total examples.
file::handle_t fH; // Backing data file handle.
std::uint8_t* buf; // buf[ bufMaxByteN ] File read buffer
const std::uint8_t* buf; // buf[ bufMaxByteN ] File read buffer
unsigned bufMaxByteN; // Allocated size of buf[] in bytes. (also sizeof fixed size records)
unsigned bufCurByteN; // Current count of bytes used in buf[].
bool isFixedSizeFl; // True if all fields are fixed size
unsigned flags; // kShuffleFl
unsigned curRecordIdx; // Index of record in buf[].
unsigned nextRecordIdx; // Index of the next record to read.
long baseFileByteOffs; // File byte offset of the first data record
cache_t* cache;
unsigned state; // See k???State enum
@ -752,15 +1014,17 @@ namespace cw
mem::free( const_cast<char*>(p->colA[i].col.label) );
}
cache_close(p->cache);
mem::release(p->cache);
file::close(p->fH);
mem::release(p->colA);
mem::release(p->buf);
//mem::free(const_cast<std::uint8_t*>(p->buf));
mem::release(p);
return kOkRC;
}
rc_t _readHdr( rdr_t* p )
rc_t _readHdr( rdr_t* p, unsigned cacheByteN, unsigned flags )
{
rc_t rc = kOkRC;
unsigned bufOffsByteN = 0;
@ -783,8 +1047,8 @@ namespace cw
if((rc = read(p->fH,c->col.varDimN)) != kOkRC ) goto errLabel;
if((rc = read(p->fH,c->col.rankN )) != kOkRC ) goto errLabel;
if((rc = read(p->fH,c->col.maxEleN )) != kOkRC ) goto errLabel;
if((rc = variant::read( p->fH, c->col.max)) != kOkRC ) goto errLabel;
if((rc = variant::read( p->fH, c->col.min )) != kOkRC ) goto errLabel;
if((rc = variant::read( p->fH, c->col.max)) != kOkRC ) goto errLabel;
if((rc = variant::read( p->fH, c->col.min )) != kOkRC ) goto errLabel;
c->col.dimV = mem::allocZ<unsigned>( c->col.rankN );
@ -827,10 +1091,18 @@ namespace cw
bufOffsByteN = p->bufMaxByteN;
}
p->buf = mem::alloc<std::uint8_t>(p->bufMaxByteN);
p->buf = nullptr; //mem::alloc<std::uint8_t>(p->bufMaxByteN);
p->cache = mem::allocZ<cache_t>(1);
// store the file offset to the first data record
rc = tell(p->fH,&p->baseFileByteOffs);
if((rc = tell(p->fH,&p->baseFileByteOffs)) != kOkRC )
{
rc = cwLogError(rc,"rdr dataset tell file position failed.");
goto errLabel;
}
rc = cache_setup( p->cache, p->fH, cacheByteN, p->baseFileByteOffs, p->record_count, cwIsFlag(flags,kShuffleFl) );
errLabel:
if( rc != kOkRC )
@ -842,75 +1114,14 @@ namespace cw
return rc;
}
rc_t _rewind( rdr_t* p )
{
rc_t rc;
if((rc = file::seek( p->fH, file::kBeginFl, p->baseFileByteOffs)) != kOkRC )
p->state = kErrorState;
else
{
p->curRecordIdx = kInvalidIdx;
p->nextRecordIdx = 0;
}
return rc;
}
rc_t _var_seek( rdr_t* p, unsigned recdIdx )
{
rc_t rc = kOkRC;
if( recdIdx < p->nextRecordIdx )
if((rc = _rewind(p)) != kOkRC )
goto errLabel;
for(; recdIdx < p->nextRecordIdx; ++recdIdx )
{
unsigned recdByteN;
if((rc = file::read(p->fH,recdByteN)) != kOkRC )
{
p->state = kErrorState;
goto errLabel;
}
if((rc = file::seek(p->fH, file::kCurFl, recdByteN )) != kOkRC )
{
p->state = kErrorState;
goto errLabel;
}
}
errLabel:
return rc;
}
// Seek to the a record, but don't actually read it.
rc_t _seek( rdr_t* p, unsigned recdIdx )
{
rc_t rc = kOkRC;
if( p->nextRecordIdx == recdIdx )
return rc;
if( recdIdx >= p->record_count )
{
rc = cwLogError(kInvalidArgRC,"The seek index %i is invalid. Record Count=%i", recdIdx, p->record_count);
goto errLabel;
}
if( p->isFixedSizeFl )
rc = _var_seek(p,recdIdx);
else
{
// fixed size recds offset = baseOffset + (recd_index * (sizeof(recd_byte_cnt) + sizeof(data_record)))
rc = file::seek( p->fH, file::kBeginFl, p->baseFileByteOffs + recdIdx * (kSizeofRecordHeader + p->bufMaxByteN));
}
if( rc == kOkRC )
p->nextRecordIdx = recdIdx;
errLabel:
rc_t rc;
if((rc = cache_seek(p->cache,recdIdx)) != kOkRC )
p->state = p->cache->state;
return rc;
}
rc_t _parse_var_record( rdr_t* p )
@ -926,7 +1137,7 @@ namespace cw
// if this is a variabled sized column
if( c->col.varDimN != 0 )
{
unsigned* varDimV = reinterpret_cast<unsigned*>(p->buf + p->bufCurByteN );
const unsigned* varDimV = reinterpret_cast<const unsigned*>(p->buf + p->bufCurByteN );
unsigned eleN = c->col.rankN==0 ? 0 : 1;
// for each dim. of this column
@ -963,24 +1174,12 @@ namespace cw
unsigned recordByteN;
// Read the byte length of this record
if((rc = file::read(p->fH, recordByteN )) != kOkRC )
if((rc = cache_read( p->cache, p->buf, recordByteN )) != kOkRC )
{
if( file::eof(p->fH) )
{
p->state = kEofState;
return kEofRC;
}
p->state = p->cache->state;
goto errLabel;
}
assert( recordByteN <= p->bufMaxByteN );
// read the record data into p->buf[]
if((rc = file::read( p->fH, p->buf, recordByteN )) != kOkRC )
goto errLabel;
// if all columns in the record do not have a fixed size then update
// the column pointers into the data record
if( !p->isFixedSizeFl )
@ -993,7 +1192,7 @@ namespace cw
return rc;
}
rc_t _get( rdr_t* p, unsigned columnId, void*& vpRef, unsigned& nRef, const unsigned*& dimVRef, unsigned reqTypeId )
rc_t _get( rdr_t* p, unsigned columnId, const void*& vpRef, unsigned& nRef, const unsigned*& dimVRef, unsigned reqTypeId )
{
const c_t* c;;
@ -1013,7 +1212,7 @@ namespace cw
}
}
cw::rc_t cw::dataset::rdr::create( handle_t& h, const char* fn )
cw::rc_t cw::dataset::rdr::create( handle_t& h, const char* fn, unsigned cacheBufByteN, unsigned flags )
{
rc_t rc;
if((rc = destroy(h)) != kOkRC )
@ -1022,11 +1221,12 @@ cw::rc_t cw::dataset::rdr::create( handle_t& h, const char* fn )
auto p = mem::allocZ<rdr_t>(1);
if((rc = file::open(p->fH, fn,file::kReadFl)) == kOkRC )
if((rc = _readHdr(p)) != kOkRC )
if((rc = _readHdr(p,cacheBufByteN,flags)) != kOkRC )
goto errLabel;
p->state = kOkState;
p->curRecordIdx = kInvalidIdx;
p->flags = flags;
h.set(p);
errLabel:
@ -1126,7 +1326,7 @@ cw::rc_t cw::dataset::rdr::read( handle_t h, unsigned record_index )
cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const int*& vRef, unsigned& nRef, const unsigned*& dimVRef )
{
rdr_t* p = _handleToPtr(h);
void* vp = nullptr;
const void* vp = nullptr;
rc_t rc = _get(p, columnId, vp, nRef, dimVRef, kIntRdrFl );
vRef = rc!=kOkRC ? nullptr : static_cast<const int*>(vp);
@ -1137,7 +1337,7 @@ cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const int*& vR
cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const float*& vRef, unsigned& nRef, const unsigned*& dimVRef )
{
rdr_t* p = _handleToPtr(h);
void* vp = nullptr;
const void* vp = nullptr;
rc_t rc = _get(p, columnId, vp, nRef, dimVRef, kFloatRdrFl );
vRef = rc!=kOkRC ? nullptr : static_cast<const float*>(vp);
@ -1148,7 +1348,7 @@ cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const float*& vR
cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const double*& vRef, unsigned& nRef, const unsigned*& dimVRef )
{
rdr_t* p = _handleToPtr(h);
void* vp = nullptr;
const void* vp = nullptr;
rc_t rc = _get(p, columnId, vp, nRef, dimVRef, kDoubleRdrFl );
vRef = rc!=kOkRC ? nullptr : static_cast<const double*>(vp);
@ -1189,14 +1389,15 @@ cw::rc_t cw::dataset::rdr::test( const object_t* cfg )
{
rc_t rc = kOkRC;
char* inFn = nullptr;
unsigned cacheByteN = 128;
handle_t h;
if((rc = cfg->getv("inFn",inFn)) != kOkRC )
if((rc = cfg->getv("inFn",inFn,"cacheByteN",cacheByteN)) != kOkRC )
return cwLogError(rc,"rdr test failed. Argument parse failed.");
inFn = filesys::expandPath(inFn);
if((rc = create(h,inFn)) != kOkRC )
if((rc = create(h,inFn,cacheByteN,kShuffleFl)) != kOkRC )
{
rc = cwLogError(rc,"rdr create failed.");
}
@ -1222,6 +1423,7 @@ cw::rc_t cw::dataset::rdr::test( const object_t* cfg )
destroy(h);
}
mem::release(inFn);
return rc;
}
@ -1650,7 +1852,7 @@ namespace cw {
}
}
cw::rc_t cw::dataset::adapter::create( handle_t& hRef, const char* fn, unsigned maxBatchN )
cw::rc_t cw::dataset::adapter::create( handle_t& hRef, const char* fn, unsigned maxBatchN, unsigned cacheByteN, unsigned flags )
{
rc_t rc = kOkRC;
@ -1659,7 +1861,7 @@ cw::rc_t cw::dataset::adapter::create( handle_t& hRef, const char* fn, unsigned
adapter_t* p = mem::allocZ<adapter_t>(1);
if((rc = rdr::create(p->rdrH,fn)) != kOkRC )
if((rc = rdr::create(p->rdrH,fn,cacheByteN,flags)) != kOkRC )
goto errLabel;
p->maxBatchN = maxBatchN;
@ -1882,9 +2084,11 @@ cw::rc_t cw::dataset::adapter::print_field( handle_t h, unsigned fieldId, const
cw::rc_t cw::dataset::adapter::test( const object_t* cfg )
{
rc_t rc = kOkRC;
char* inFn = nullptr;
unsigned batchN = 0;
rc_t rc = kOkRC;
char* inFn = nullptr;
unsigned batchN = 0;
unsigned cacheByteN = 128;
unsigned shuffleFl = rdr::kShuffleFl;
handle_t h;
enum {
@ -1893,13 +2097,13 @@ cw::rc_t cw::dataset::adapter::test( const object_t* cfg )
};
// read the cfg args
if((rc = cfg->getv("inFn",inFn,"batchN",batchN)) != kOkRC )
if((rc = cfg->getv("inFn",inFn,"batchN",batchN,"cacheByteN",cacheByteN)) != kOkRC )
return cwLogError(rc,"adapter test failed. Argument parse failed.");
inFn = filesys::expandPath(inFn);
// create the adapter
if((rc = create(h, inFn, batchN)) != kOkRC )
if((rc = create(h, inFn, batchN, cacheByteN, shuffleFl)) != kOkRC )
{
rc = cwLogError(rc,"Unable to create dataset adapter for '%s'.",inFn);
goto errLabel;
@ -2065,7 +2269,6 @@ namespace cw
// read each example
for(unsigned i=0; i<exampleN; ++i)
{
// read the digit image label
if((rc = read(fH, labelV[i])) != kOkRC )
{
@ -2187,6 +2390,7 @@ unsigned cw::dataset::mnist::record_count( handle_t h )
return p->exampleN;
}
cw::rc_t cw::dataset::mnist::seek( handle_t h, unsigned exampleIdx )
{
rc_t rc = kOkRC;
@ -2239,7 +2443,7 @@ cw::rc_t cw::dataset::mnist::write( handle_t h, const char* fn )
enum { kImagId, kNumbId };
unsigned numbDimV[] = {1};
unsigned imagDimV[] = {28,28};
unsigned imagDimV[] = {kPixelRowN,kPixelColN};
unsigned imagEleN = imagDimV[0]*imagDimV[1];
if((rc = define_columns( wtrH, "numb", kNumbId, cwCountOf(numbDimV), numbDimV )) != kOkRC )
@ -2312,35 +2516,9 @@ cw::rc_t cw::dataset::mnist::test( const object_t* cfg )
rc = cwLogError(rc,"SVG Test failed on create.");
else
{
//const mtx::f_t* m = train(h);
/*
unsigned zn = 0;
unsigned i = 1;
for(; i<m->dimV[1]; ++i)
{
const float* v0 = m->base + (28*28+1) * (i-1) + 1;
const float* v1 = m->base + (28*28+1) * (i-0) + 1;
float d = 0;
for(unsigned j=0; j<28*28; ++j)
d += fabs(v0[j]-v1[j]);
if( d==0 )
++zn;
else
{
printf("%i %i %f\n",i,zn,d);
zn = 0;
}
}
printf("i:%i n:%i zn:%i\n",i,m->dimV[1],zn);
*/
const float* dataM = nullptr;
const unsigned* labelV = nullptr;
unsigned exampleN = 10;
unsigned exampleN = 100;
unsigned actualExampleN = 0;
//mnist::seek( h, 10 );
@ -2349,13 +2527,13 @@ cw::rc_t cw::dataset::mnist::test( const object_t* cfg )
for(unsigned i=0; i<actualExampleN; ++i)
{
printf("label: %i\n", labelV[i] );
printf("label: %i ", labelV[i] );
svg::offset(svgH, 0, i*30*5 );
svg::image(svgH, dataM + (28*28)*i, 28, 28, 5, svg::kInvGrayScaleColorMapId);
svg::image(svgH, dataM + (kPixelRowN*kPixelColN)*i, kPixelRowN, kPixelColN, 5, svg::kInvGrayScaleColorMapId);
}
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10);
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl, 10,10,10,10);
svg::destroy(svgH);
@ -2381,12 +2559,14 @@ cw::rc_t cw::dataset::test( const object_t* cfg )
char* dsFn = nullptr;
char* outHtmlFn = nullptr;
mnist::handle_t mniH;
adapter::handle_t rdrH;
adapter::handle_t adpH;
svg::handle_t svgH;
unsigned batchN = 10;
unsigned batchN = 100;
unsigned cacheByteN = 4096 * 10;
unsigned shuffleFl = rdr::kShuffleFl;
if((rc = cfg->getv("inDir",inDir,"dsFn",dsFn,"outHtmlFn",outHtmlFn,"batchN",batchN)) != kOkRC )
if((rc = cfg->getv("inDir",inDir,"dsFn",dsFn,"outHtmlFn",outHtmlFn,"batchN",batchN,"cacheByteN",cacheByteN)) != kOkRC )
return cwLogError(rc,"MNIST test failed. Argument parse failed.");
inDir = filesys::expandPath(inDir);
@ -2413,7 +2593,7 @@ cw::rc_t cw::dataset::test( const object_t* cfg )
}
// open a dataset adapter
if((rc = adapter::create(rdrH,dsFn,batchN)) != kOkRC )
if((rc = adapter::create(adpH,dsFn,batchN,cacheByteN,shuffleFl)) != kOkRC )
{
cwLogError(rc,"Dataset reader create failed.");
goto errLabel;
@ -2428,49 +2608,64 @@ cw::rc_t cw::dataset::test( const object_t* cfg )
enum { kImagId, kNumbId };
// create dataset fields
if((rc = create_field( rdrH, kImagId, adapter::kFloatFl, "imag" )) != kOkRC )
// create a field for the image data
if((rc = create_field( adpH, kImagId, adapter::kFloatFl, "imag" )) != kOkRC )
{
cwLogError(rc,"Dataset rdr column define failed.");
goto errLabel;
}
if((rc = create_field( rdrH, kNumbId, adapter::kIntFl, "numb" )) != kOkRC )
// create a field for the image lable
if((rc = create_field( adpH, kNumbId, adapter::kIntFl, "numb" )) != kOkRC )
{
cwLogError(rc,"Dataset rdr column define failed.");
goto errLabel;
}
// read a batch of data
if((rc = adapter::read( rdrH, batchN)) != kOkRC )
for(unsigned j=0,imageN=0; true; ++j )
{
cwLogError(rc,"Batch read failed.");
goto errLabel;
}
else
{
const int* numbV = nullptr;
const unsigned* numbNV = nullptr;
const float* imagV = nullptr;
const unsigned* imagNV = nullptr;
adapter::get(rdrH, kNumbId, numbV, numbNV ); // get the labels
adapter::get(rdrH, kImagId, imagV, imagNV ); // get the image data
for(unsigned i=0; i<batchN; ++i)
// read a batch of data
if((rc = adapter::read( adpH, batchN)) != kOkRC )
{
printf("label: %i\n", numbV[i] );
svg::offset(svgH, 0, i*30*5 );
svg::image(svgH, imagV + (28*28)*i, 28, 28, 5, svg::kInvGrayScaleColorMapId);
if( rc == kEofRC )
cwLogInfo("Done!.");
else
cwLogError(rc,"Batch read failed.");
goto errLabel;
}
else
{
const int* numbV = nullptr;
const unsigned* numbNV = nullptr;
const float* imagV = nullptr;
const unsigned* imagNV = nullptr;
const unsigned kPixelSize = 5;
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10);
adapter::get(adpH, kNumbId, numbV, numbNV ); // get the labels
adapter::get(adpH, kImagId, imagV, imagNV ); // get the image data
printf("%3i : ",j);
// print the first 5 images from each batch to an SVG file
for(unsigned i=0; i<0; ++i,++imageN)
{
printf("%i ", numbV[i] );
// offset the image vertically
svg::offset(svgH, 0, imageN*30*kPixelSize );
svg::image(svgH, imagV + (mnist::kPixelRowN*mnist::kPixelColN)*i, mnist::kPixelRowN, mnist::kPixelColN, kPixelSize, svg::kInvGrayScaleColorMapId);
}
printf("\n");
}
}
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl, 10,10,10,10);
}
}
errLabel:
adapter::destroy(rdrH);
adapter::destroy(adpH);
svg::destroy(svgH);
mem::release(inDir);
mem::release(dsFn);

View File

@ -153,7 +153,12 @@ namespace cw
unsigned byteN; // Size of this field in bytes.
} col_t;
rc_t create( handle_t& h, const char* fn );
enum {
kShuffleFl = 0x01
};
rc_t create( handle_t& h, const char* fn, unsigned cacheBufByteN, unsigned flags=kShuffleFl );
rc_t destroy( handle_t& h );
unsigned column_count( handle_t h );
@ -178,7 +183,7 @@ namespace cw
// Read the next record.
rc_t read( handle_t h, unsigned recordIdx=kInvalidIdx );
// Read a column value.
// Get a column value from the last record returned by 'read()'.
//
// vRef = Pointer to the value vector.
// nRef = Count of elements in value vector.
@ -222,8 +227,7 @@ namespace cw
unsigned rankN; // dimV[ rankN ] Rank of this column
} colMap_t;
rc_t create( handle_t& hRef, const char* fn, unsigned maxBatchN );
rc_t create( handle_t& hRef, const char* fn, unsigned maxBatchN, unsigned cacheByteN, unsigned flags=rdr::kShuffleFl );
rc_t destroy( handle_t& hRef );
// Create a field and assign it a column.
@ -249,8 +253,7 @@ namespace cw
rc_t get( handle_t h, unsigned fieldId, const float*& fV_Ref, const unsigned*& fNV_Ref );
rc_t get( handle_t h, unsigned fieldId, const double*& fV_Ref, const unsigned*& fNV_Ref );
// Returns col position and geometry data from each record returned by the last
// call to read().
// Returns col position and geometry data from each record returned by the last call to read().
// Returns colMapV_Ref[batchN][columnN].
rc_t column_map( handle_t h, unsigned fieldId, colMap_t const * const *& colMapV_Ref );
@ -269,6 +272,11 @@ namespace cw
{
typedef handle<struct mnist_str> handle_t;
enum {
kPixelRowN = 28,
kPixelColN = 28
};
rc_t create( handle_t& h, const char* inDir );
rc_t destroy( handle_t& h );