cwDataSets.h/cpp : Add use of cache_t file cache to rdr to improve file read performance and add shuffling option.
This commit is contained in:
parent
46a7633e00
commit
cbf4870410
553
cwDataSets.cpp
553
cwDataSets.cpp
@ -12,7 +12,7 @@
|
||||
#include "cwSvg.h"
|
||||
#include "cwTime.h"
|
||||
#include "cwText.h"
|
||||
|
||||
#include "cwMath.h"
|
||||
|
||||
//----------------------------------------------------------------------------------------------------------------------------
|
||||
//----------------------------------------------------------------------------------------------------------------------------
|
||||
@ -26,7 +26,7 @@ namespace cw
|
||||
{
|
||||
typedef struct col_str
|
||||
{
|
||||
rdr::col_t col; // Public fields.
|
||||
rdr::col_t col; // Public fields - See rdr::col_t.
|
||||
unsigned char* cur; // Cache of the current column data contents.
|
||||
unsigned curByteN; // Count of bytes in cur[].
|
||||
unsigned* curDimV; // Cache of the current column dimensions.
|
||||
@ -159,17 +159,17 @@ namespace cw
|
||||
++col_count;
|
||||
|
||||
if((rc = file::write( p->fH, p->record_count )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::write( p->fH, col_count )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::write( p->fH, col_count )) != kOkRC ) goto errLabel;
|
||||
|
||||
for(c=p->colL; c!=nullptr; c=c->link)
|
||||
{
|
||||
if((rc = file::writeStr( p->fH, c->col.label )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::writeStr( p->fH, c->col.label )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::write( p->fH, c->col.id )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::write( p->fH, c->col.varDimN )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::write( p->fH, c->col.rankN )) != kOkRC ) goto errLabel;
|
||||
if((rc = file::write( p->fH, c->col.maxEleN )) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::write( p->fH, c->col.max)) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::write( p->fH, c->col.min )) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::write( p->fH, c->col.max)) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::write( p->fH, c->col.min )) != kOkRC ) goto errLabel;
|
||||
|
||||
for(unsigned i=0; i<c->col.rankN; ++i)
|
||||
{
|
||||
@ -639,6 +639,265 @@ namespace cw
|
||||
{
|
||||
kSizeofRecordHeader = sizeof(unsigned)
|
||||
};
|
||||
|
||||
typedef struct cache_str
|
||||
{
|
||||
file::handle_t fH;
|
||||
unsigned totalRecdN; // Total count of records in the file
|
||||
std::uint8_t * buf; // File buffer memory
|
||||
unsigned bufMaxByteN; // Allocated size of buf[]
|
||||
unsigned bufByteN; // Bytes in buf[]
|
||||
|
||||
unsigned baseFileOffs; // Offset of the first record in the file
|
||||
|
||||
unsigned* tocV; // tocV[tocN] Cached record byte offsets
|
||||
unsigned tocN; // Count of records in the cache
|
||||
unsigned tocBaseIdx; // Record index of the first record in the cache
|
||||
unsigned tocIdx; // Record index of next record to return
|
||||
unsigned state; // See rdr::k???State
|
||||
bool shuffleFl; // shuffle the file buffer each time it is filled
|
||||
|
||||
} cache_t;
|
||||
|
||||
// Backup the file position to the beginning of the last (partial) record in the cache.
|
||||
// Note that the last record overlaps the end of the cache is is therefore incomplete.
|
||||
rc_t _cache_backup( cache_t* p, unsigned actByteN, unsigned cacheByteN )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
if( p->state == kEofRC )
|
||||
return kEofRC;
|
||||
|
||||
assert( actByteN >= cacheByteN);
|
||||
|
||||
if((rc = file::seek(p->fH,file::kCurFl, -(int)(actByteN-cacheByteN) )) != kOkRC )
|
||||
return cwLogError(rc,"Dataset rdr cache align failed.");
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
// Count the records in the case and re-align the current file position to the last (partial) record in the cache
|
||||
rc_t _cache_count_and_align( cache_t* p, unsigned actByteN )
|
||||
{
|
||||
p->bufByteN = 0;
|
||||
for(p->tocN=0; p->bufByteN < actByteN; ++p->tocN )
|
||||
{
|
||||
if( p->bufByteN + kSizeofRecordHeader >= actByteN )
|
||||
{
|
||||
_cache_backup( p, actByteN, p->bufByteN);
|
||||
break;
|
||||
}
|
||||
|
||||
unsigned recdByteN = *reinterpret_cast<unsigned*>(p->buf + p->bufByteN);
|
||||
|
||||
// TODO: handle case where the whole buffer has less than one record
|
||||
if( p->tocN==0 && actByteN < kSizeofRecordHeader + recdByteN )
|
||||
{
|
||||
assert(0);
|
||||
}
|
||||
|
||||
if( p->bufByteN + recdByteN > actByteN )
|
||||
{
|
||||
_cache_backup( p, actByteN, p->bufByteN);
|
||||
break;
|
||||
}
|
||||
|
||||
p->bufByteN += kSizeofRecordHeader + recdByteN;
|
||||
}
|
||||
|
||||
return kOkRC;
|
||||
}
|
||||
|
||||
void _cache_shuffle_toc( cache_t* p )
|
||||
{
|
||||
// for each record address in tocV[]
|
||||
for(unsigned i=0; i<p->tocN; ++i)
|
||||
{
|
||||
// generate a random index into tocV[]
|
||||
unsigned idx = math::randUInt(0,p->tocN-1);
|
||||
|
||||
// swap location i with a random location
|
||||
unsigned tmp = p->tocV[i];
|
||||
p->tocV[i] = p->tocV[idx];
|
||||
p->tocV[idx] = tmp;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void _cache_fill_toc( cache_t* p )
|
||||
{
|
||||
unsigned cacheByteOffs = 0;
|
||||
|
||||
for(unsigned i=0; i<p->tocN; ++i)
|
||||
{
|
||||
p->tocV[i] = cacheByteOffs;
|
||||
|
||||
unsigned recdByteN = *reinterpret_cast<unsigned*>(p->buf + cacheByteOffs);
|
||||
|
||||
cacheByteOffs += kSizeofRecordHeader + recdByteN;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
rc_t _cache_fill( cache_t* p )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
unsigned actByteN = 0;
|
||||
|
||||
// Note that his function is always called when the file is pointing to the
|
||||
// record length at the start of a record
|
||||
|
||||
// Fill the cache with as much data as possible from the file
|
||||
if((rc = file::read(p->fH, p->buf, p->bufMaxByteN, &actByteN)) != kOkRC )
|
||||
{
|
||||
if(rc == kEofRC)
|
||||
p->state = kEofState;
|
||||
else
|
||||
return cwLogError(rc,"dataset rdr cache fill failed.");
|
||||
}
|
||||
|
||||
// Get a count of the records in the cache (p->tocN) and adjust the file position such that
|
||||
// it is left pointing to the beginning of the first record after the cache.
|
||||
if((rc = _cache_count_and_align(p,actByteN)) != kOkRC )
|
||||
return rc;
|
||||
|
||||
// alllocate the TOC
|
||||
p->tocV = mem::resize<unsigned>(p->tocV,p->tocN);
|
||||
|
||||
// Fill the p->tocV[]
|
||||
_cache_fill_toc(p);
|
||||
|
||||
if( p->shuffleFl)
|
||||
_cache_shuffle_toc(p);
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t _cache_rewind( cache_t* p )
|
||||
{
|
||||
rc_t rc;
|
||||
|
||||
// rewind the file to the beginning of the
|
||||
if(( rc = file::seek(p->fH,file::kBeginFl,p->baseFileOffs)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"rdr cache file seek failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
if((rc = _cache_fill(p)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
p->tocBaseIdx = 0;
|
||||
p->tocIdx = 0;
|
||||
|
||||
errLabel:
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t _cache_advance( cache_t* p )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
unsigned n = p->tocN;
|
||||
if((rc = _cache_fill(p)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
p->tocBaseIdx += n;
|
||||
|
||||
errLabel:
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t cache_setup( cache_t* p, file::handle_t fH, unsigned bufMaxByteN, unsigned baseFileOffs, unsigned totalRecordN, bool shuffleFl )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
p->fH = fH;
|
||||
p->buf = mem::alloc<uint8_t>( bufMaxByteN );
|
||||
p->bufMaxByteN = bufMaxByteN;
|
||||
p->bufByteN = 0;
|
||||
p->baseFileOffs = baseFileOffs;
|
||||
p->state = kOkState;
|
||||
p->totalRecdN = totalRecordN;
|
||||
p->shuffleFl = shuffleFl;
|
||||
rc = _cache_rewind(p);
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t cache_close( cache_t* p )
|
||||
{
|
||||
mem::release(p->tocV);
|
||||
mem::release(p->buf);
|
||||
return kOkRC;
|
||||
}
|
||||
|
||||
rc_t cache_read( cache_t* p, const std::uint8_t*& recdRef, unsigned& recdByteN )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
unsigned tocIdx;
|
||||
|
||||
if( p->tocIdx == p->totalRecdN )
|
||||
{
|
||||
rc = kEofRC;
|
||||
p->state = kEofState;
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
if( p->tocIdx == p->tocBaseIdx + p->tocN )
|
||||
if((rc = _cache_advance(p)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
|
||||
tocIdx = p->tocIdx - p->tocBaseIdx;
|
||||
|
||||
recdByteN = *reinterpret_cast<unsigned*>(p->buf + p->tocV[ tocIdx ]);
|
||||
|
||||
recdRef = p->buf + (kSizeofRecordHeader + p->tocV[ tocIdx ]);
|
||||
|
||||
p->tocIdx += 1;
|
||||
|
||||
errLabel:
|
||||
return rc;
|
||||
|
||||
}
|
||||
|
||||
rc_t cache_seek( cache_t* p, unsigned recordIdx )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
if( recordIdx >= p->totalRecdN )
|
||||
return cwLogError(kSeekFailRC,"rdr cache seek index %i greater than last index: %i.",recordIdx,p->totalRecdN-1);
|
||||
|
||||
// if the requested record index is inside the cache
|
||||
if( p->tocBaseIdx <= recordIdx && recordIdx < p->tocBaseIdx + p->tocN )
|
||||
p->tocIdx = recordIdx;
|
||||
else
|
||||
{
|
||||
// if the requested record index is prior to the cache
|
||||
if( recordIdx < p->tocBaseIdx )
|
||||
if((rc = _cache_rewind(p)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
// recordIdx now must be past the beginning of the cache
|
||||
assert( recordIdx >= p->tocBaseIdx );
|
||||
|
||||
// advance the cache until recordIdx is inside of it
|
||||
while( recordIdx >= p->tocBaseIdx + p->tocN )
|
||||
{
|
||||
if((rc = _cache_advance(p)) != kOkRC )
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
assert( p->tocBaseIdx <= recordIdx && recordIdx < p->tocBaseIdx + p->tocN );
|
||||
|
||||
p->tocIdx = recordIdx;
|
||||
|
||||
}
|
||||
|
||||
errLabel:
|
||||
return rc;
|
||||
|
||||
}
|
||||
|
||||
typedef struct
|
||||
{
|
||||
@ -653,14 +912,17 @@ namespace cw
|
||||
unsigned column_count; // Count of elements in colA[].
|
||||
unsigned record_count; // Count of total examples.
|
||||
file::handle_t fH; // Backing data file handle.
|
||||
std::uint8_t* buf; // buf[ bufMaxByteN ] File read buffer
|
||||
const std::uint8_t* buf; // buf[ bufMaxByteN ] File read buffer
|
||||
unsigned bufMaxByteN; // Allocated size of buf[] in bytes. (also sizeof fixed size records)
|
||||
unsigned bufCurByteN; // Current count of bytes used in buf[].
|
||||
bool isFixedSizeFl; // True if all fields are fixed size
|
||||
|
||||
unsigned flags; // kShuffleFl
|
||||
unsigned curRecordIdx; // Index of record in buf[].
|
||||
unsigned nextRecordIdx; // Index of the next record to read.
|
||||
long baseFileByteOffs; // File byte offset of the first data record
|
||||
|
||||
cache_t* cache;
|
||||
|
||||
unsigned state; // See k???State enum
|
||||
|
||||
@ -751,16 +1013,18 @@ namespace cw
|
||||
mem::release( p->colA[i].varDimIdxV);
|
||||
mem::free( const_cast<char*>(p->colA[i].col.label) );
|
||||
}
|
||||
|
||||
|
||||
cache_close(p->cache);
|
||||
mem::release(p->cache);
|
||||
file::close(p->fH);
|
||||
mem::release(p->colA);
|
||||
mem::release(p->buf);
|
||||
//mem::free(const_cast<std::uint8_t*>(p->buf));
|
||||
mem::release(p);
|
||||
|
||||
return kOkRC;
|
||||
}
|
||||
|
||||
rc_t _readHdr( rdr_t* p )
|
||||
rc_t _readHdr( rdr_t* p, unsigned cacheByteN, unsigned flags )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
unsigned bufOffsByteN = 0;
|
||||
@ -783,8 +1047,8 @@ namespace cw
|
||||
if((rc = read(p->fH,c->col.varDimN)) != kOkRC ) goto errLabel;
|
||||
if((rc = read(p->fH,c->col.rankN )) != kOkRC ) goto errLabel;
|
||||
if((rc = read(p->fH,c->col.maxEleN )) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::read( p->fH, c->col.max)) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::read( p->fH, c->col.min )) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::read( p->fH, c->col.max)) != kOkRC ) goto errLabel;
|
||||
if((rc = variant::read( p->fH, c->col.min )) != kOkRC ) goto errLabel;
|
||||
|
||||
|
||||
c->col.dimV = mem::allocZ<unsigned>( c->col.rankN );
|
||||
@ -827,10 +1091,18 @@ namespace cw
|
||||
bufOffsByteN = p->bufMaxByteN;
|
||||
}
|
||||
|
||||
p->buf = mem::alloc<std::uint8_t>(p->bufMaxByteN);
|
||||
p->buf = nullptr; //mem::alloc<std::uint8_t>(p->bufMaxByteN);
|
||||
p->cache = mem::allocZ<cache_t>(1);
|
||||
|
||||
// store the file offset to the first data record
|
||||
rc = tell(p->fH,&p->baseFileByteOffs);
|
||||
if((rc = tell(p->fH,&p->baseFileByteOffs)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"rdr dataset tell file position failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
rc = cache_setup( p->cache, p->fH, cacheByteN, p->baseFileByteOffs, p->record_count, cwIsFlag(flags,kShuffleFl) );
|
||||
|
||||
|
||||
errLabel:
|
||||
if( rc != kOkRC )
|
||||
@ -842,75 +1114,14 @@ namespace cw
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t _rewind( rdr_t* p )
|
||||
{
|
||||
rc_t rc;
|
||||
if((rc = file::seek( p->fH, file::kBeginFl, p->baseFileByteOffs)) != kOkRC )
|
||||
p->state = kErrorState;
|
||||
else
|
||||
{
|
||||
p->curRecordIdx = kInvalidIdx;
|
||||
p->nextRecordIdx = 0;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t _var_seek( rdr_t* p, unsigned recdIdx )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
if( recdIdx < p->nextRecordIdx )
|
||||
if((rc = _rewind(p)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
for(; recdIdx < p->nextRecordIdx; ++recdIdx )
|
||||
{
|
||||
unsigned recdByteN;
|
||||
if((rc = file::read(p->fH,recdByteN)) != kOkRC )
|
||||
{
|
||||
p->state = kErrorState;
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
if((rc = file::seek(p->fH, file::kCurFl, recdByteN )) != kOkRC )
|
||||
{
|
||||
p->state = kErrorState;
|
||||
goto errLabel;
|
||||
}
|
||||
}
|
||||
|
||||
errLabel:
|
||||
return rc;
|
||||
}
|
||||
|
||||
// Seek to the a record, but don't actually read it.
|
||||
rc_t _seek( rdr_t* p, unsigned recdIdx )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
if( p->nextRecordIdx == recdIdx )
|
||||
return rc;
|
||||
|
||||
if( recdIdx >= p->record_count )
|
||||
{
|
||||
rc = cwLogError(kInvalidArgRC,"The seek index %i is invalid. Record Count=%i", recdIdx, p->record_count);
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
if( p->isFixedSizeFl )
|
||||
rc = _var_seek(p,recdIdx);
|
||||
else
|
||||
{
|
||||
// fixed size recds offset = baseOffset + (recd_index * (sizeof(recd_byte_cnt) + sizeof(data_record)))
|
||||
rc = file::seek( p->fH, file::kBeginFl, p->baseFileByteOffs + recdIdx * (kSizeofRecordHeader + p->bufMaxByteN));
|
||||
}
|
||||
|
||||
if( rc == kOkRC )
|
||||
p->nextRecordIdx = recdIdx;
|
||||
|
||||
errLabel:
|
||||
rc_t rc;
|
||||
if((rc = cache_seek(p->cache,recdIdx)) != kOkRC )
|
||||
p->state = p->cache->state;
|
||||
return rc;
|
||||
|
||||
}
|
||||
|
||||
rc_t _parse_var_record( rdr_t* p )
|
||||
@ -926,7 +1137,7 @@ namespace cw
|
||||
// if this is a variabled sized column
|
||||
if( c->col.varDimN != 0 )
|
||||
{
|
||||
unsigned* varDimV = reinterpret_cast<unsigned*>(p->buf + p->bufCurByteN );
|
||||
const unsigned* varDimV = reinterpret_cast<const unsigned*>(p->buf + p->bufCurByteN );
|
||||
unsigned eleN = c->col.rankN==0 ? 0 : 1;
|
||||
|
||||
// for each dim. of this column
|
||||
@ -960,27 +1171,15 @@ namespace cw
|
||||
rc_t _read_record( rdr_t* p )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
|
||||
unsigned recordByteN;
|
||||
|
||||
// Read the byte length of this record
|
||||
if((rc = file::read(p->fH, recordByteN )) != kOkRC )
|
||||
if((rc = cache_read( p->cache, p->buf, recordByteN )) != kOkRC )
|
||||
{
|
||||
if( file::eof(p->fH) )
|
||||
{
|
||||
p->state = kEofState;
|
||||
return kEofRC;
|
||||
}
|
||||
p->state = p->cache->state;
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
assert( recordByteN <= p->bufMaxByteN );
|
||||
|
||||
// read the record data into p->buf[]
|
||||
if((rc = file::read( p->fH, p->buf, recordByteN )) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
|
||||
// if all columns in the record do not have a fixed size then update
|
||||
// the column pointers into the data record
|
||||
if( !p->isFixedSizeFl )
|
||||
@ -993,7 +1192,7 @@ namespace cw
|
||||
return rc;
|
||||
}
|
||||
|
||||
rc_t _get( rdr_t* p, unsigned columnId, void*& vpRef, unsigned& nRef, const unsigned*& dimVRef, unsigned reqTypeId )
|
||||
rc_t _get( rdr_t* p, unsigned columnId, const void*& vpRef, unsigned& nRef, const unsigned*& dimVRef, unsigned reqTypeId )
|
||||
{
|
||||
const c_t* c;;
|
||||
|
||||
@ -1013,7 +1212,7 @@ namespace cw
|
||||
}
|
||||
}
|
||||
|
||||
cw::rc_t cw::dataset::rdr::create( handle_t& h, const char* fn )
|
||||
cw::rc_t cw::dataset::rdr::create( handle_t& h, const char* fn, unsigned cacheBufByteN, unsigned flags )
|
||||
{
|
||||
rc_t rc;
|
||||
if((rc = destroy(h)) != kOkRC )
|
||||
@ -1022,11 +1221,12 @@ cw::rc_t cw::dataset::rdr::create( handle_t& h, const char* fn )
|
||||
auto p = mem::allocZ<rdr_t>(1);
|
||||
|
||||
if((rc = file::open(p->fH, fn,file::kReadFl)) == kOkRC )
|
||||
if((rc = _readHdr(p)) != kOkRC )
|
||||
if((rc = _readHdr(p,cacheBufByteN,flags)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
p->state = kOkState;
|
||||
p->curRecordIdx = kInvalidIdx;
|
||||
p->flags = flags;
|
||||
h.set(p);
|
||||
|
||||
errLabel:
|
||||
@ -1126,7 +1326,7 @@ cw::rc_t cw::dataset::rdr::read( handle_t h, unsigned record_index )
|
||||
cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const int*& vRef, unsigned& nRef, const unsigned*& dimVRef )
|
||||
{
|
||||
rdr_t* p = _handleToPtr(h);
|
||||
void* vp = nullptr;
|
||||
const void* vp = nullptr;
|
||||
rc_t rc = _get(p, columnId, vp, nRef, dimVRef, kIntRdrFl );
|
||||
|
||||
vRef = rc!=kOkRC ? nullptr : static_cast<const int*>(vp);
|
||||
@ -1137,7 +1337,7 @@ cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const int*& vR
|
||||
cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const float*& vRef, unsigned& nRef, const unsigned*& dimVRef )
|
||||
{
|
||||
rdr_t* p = _handleToPtr(h);
|
||||
void* vp = nullptr;
|
||||
const void* vp = nullptr;
|
||||
rc_t rc = _get(p, columnId, vp, nRef, dimVRef, kFloatRdrFl );
|
||||
|
||||
vRef = rc!=kOkRC ? nullptr : static_cast<const float*>(vp);
|
||||
@ -1148,7 +1348,7 @@ cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const float*& vR
|
||||
cw::rc_t cw::dataset::rdr::get( handle_t h, unsigned columnId, const double*& vRef, unsigned& nRef, const unsigned*& dimVRef )
|
||||
{
|
||||
rdr_t* p = _handleToPtr(h);
|
||||
void* vp = nullptr;
|
||||
const void* vp = nullptr;
|
||||
rc_t rc = _get(p, columnId, vp, nRef, dimVRef, kDoubleRdrFl );
|
||||
|
||||
vRef = rc!=kOkRC ? nullptr : static_cast<const double*>(vp);
|
||||
@ -1189,14 +1389,15 @@ cw::rc_t cw::dataset::rdr::test( const object_t* cfg )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
char* inFn = nullptr;
|
||||
unsigned cacheByteN = 128;
|
||||
handle_t h;
|
||||
|
||||
if((rc = cfg->getv("inFn",inFn)) != kOkRC )
|
||||
if((rc = cfg->getv("inFn",inFn,"cacheByteN",cacheByteN)) != kOkRC )
|
||||
return cwLogError(rc,"rdr test failed. Argument parse failed.");
|
||||
|
||||
inFn = filesys::expandPath(inFn);
|
||||
|
||||
if((rc = create(h,inFn)) != kOkRC )
|
||||
if((rc = create(h,inFn,cacheByteN,kShuffleFl)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"rdr create failed.");
|
||||
}
|
||||
@ -1222,6 +1423,7 @@ cw::rc_t cw::dataset::rdr::test( const object_t* cfg )
|
||||
destroy(h);
|
||||
}
|
||||
|
||||
mem::release(inFn);
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -1650,7 +1852,7 @@ namespace cw {
|
||||
}
|
||||
}
|
||||
|
||||
cw::rc_t cw::dataset::adapter::create( handle_t& hRef, const char* fn, unsigned maxBatchN )
|
||||
cw::rc_t cw::dataset::adapter::create( handle_t& hRef, const char* fn, unsigned maxBatchN, unsigned cacheByteN, unsigned flags )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
|
||||
@ -1659,7 +1861,7 @@ cw::rc_t cw::dataset::adapter::create( handle_t& hRef, const char* fn, unsigned
|
||||
|
||||
adapter_t* p = mem::allocZ<adapter_t>(1);
|
||||
|
||||
if((rc = rdr::create(p->rdrH,fn)) != kOkRC )
|
||||
if((rc = rdr::create(p->rdrH,fn,cacheByteN,flags)) != kOkRC )
|
||||
goto errLabel;
|
||||
|
||||
p->maxBatchN = maxBatchN;
|
||||
@ -1882,9 +2084,11 @@ cw::rc_t cw::dataset::adapter::print_field( handle_t h, unsigned fieldId, const
|
||||
|
||||
cw::rc_t cw::dataset::adapter::test( const object_t* cfg )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
char* inFn = nullptr;
|
||||
unsigned batchN = 0;
|
||||
rc_t rc = kOkRC;
|
||||
char* inFn = nullptr;
|
||||
unsigned batchN = 0;
|
||||
unsigned cacheByteN = 128;
|
||||
unsigned shuffleFl = rdr::kShuffleFl;
|
||||
handle_t h;
|
||||
|
||||
enum {
|
||||
@ -1893,13 +2097,13 @@ cw::rc_t cw::dataset::adapter::test( const object_t* cfg )
|
||||
};
|
||||
|
||||
// read the cfg args
|
||||
if((rc = cfg->getv("inFn",inFn,"batchN",batchN)) != kOkRC )
|
||||
if((rc = cfg->getv("inFn",inFn,"batchN",batchN,"cacheByteN",cacheByteN)) != kOkRC )
|
||||
return cwLogError(rc,"adapter test failed. Argument parse failed.");
|
||||
|
||||
inFn = filesys::expandPath(inFn);
|
||||
|
||||
// create the adapter
|
||||
if((rc = create(h, inFn, batchN)) != kOkRC )
|
||||
if((rc = create(h, inFn, batchN, cacheByteN, shuffleFl)) != kOkRC )
|
||||
{
|
||||
rc = cwLogError(rc,"Unable to create dataset adapter for '%s'.",inFn);
|
||||
goto errLabel;
|
||||
@ -2065,7 +2269,6 @@ namespace cw
|
||||
// read each example
|
||||
for(unsigned i=0; i<exampleN; ++i)
|
||||
{
|
||||
|
||||
// read the digit image label
|
||||
if((rc = read(fH, labelV[i])) != kOkRC )
|
||||
{
|
||||
@ -2187,6 +2390,7 @@ unsigned cw::dataset::mnist::record_count( handle_t h )
|
||||
return p->exampleN;
|
||||
}
|
||||
|
||||
|
||||
cw::rc_t cw::dataset::mnist::seek( handle_t h, unsigned exampleIdx )
|
||||
{
|
||||
rc_t rc = kOkRC;
|
||||
@ -2239,7 +2443,7 @@ cw::rc_t cw::dataset::mnist::write( handle_t h, const char* fn )
|
||||
|
||||
enum { kImagId, kNumbId };
|
||||
unsigned numbDimV[] = {1};
|
||||
unsigned imagDimV[] = {28,28};
|
||||
unsigned imagDimV[] = {kPixelRowN,kPixelColN};
|
||||
unsigned imagEleN = imagDimV[0]*imagDimV[1];
|
||||
|
||||
if((rc = define_columns( wtrH, "numb", kNumbId, cwCountOf(numbDimV), numbDimV )) != kOkRC )
|
||||
@ -2312,35 +2516,9 @@ cw::rc_t cw::dataset::mnist::test( const object_t* cfg )
|
||||
rc = cwLogError(rc,"SVG Test failed on create.");
|
||||
else
|
||||
{
|
||||
//const mtx::f_t* m = train(h);
|
||||
|
||||
/*
|
||||
unsigned zn = 0;
|
||||
unsigned i = 1;
|
||||
for(; i<m->dimV[1]; ++i)
|
||||
{
|
||||
const float* v0 = m->base + (28*28+1) * (i-1) + 1;
|
||||
const float* v1 = m->base + (28*28+1) * (i-0) + 1;
|
||||
float d = 0;
|
||||
|
||||
for(unsigned j=0; j<28*28; ++j)
|
||||
d += fabs(v0[j]-v1[j]);
|
||||
|
||||
if( d==0 )
|
||||
++zn;
|
||||
else
|
||||
{
|
||||
printf("%i %i %f\n",i,zn,d);
|
||||
zn = 0;
|
||||
}
|
||||
}
|
||||
|
||||
printf("i:%i n:%i zn:%i\n",i,m->dimV[1],zn);
|
||||
*/
|
||||
|
||||
const float* dataM = nullptr;
|
||||
const unsigned* labelV = nullptr;
|
||||
unsigned exampleN = 10;
|
||||
unsigned exampleN = 100;
|
||||
unsigned actualExampleN = 0;
|
||||
|
||||
//mnist::seek( h, 10 );
|
||||
@ -2349,13 +2527,13 @@ cw::rc_t cw::dataset::mnist::test( const object_t* cfg )
|
||||
|
||||
for(unsigned i=0; i<actualExampleN; ++i)
|
||||
{
|
||||
printf("label: %i\n", labelV[i] );
|
||||
printf("label: %i ", labelV[i] );
|
||||
|
||||
svg::offset(svgH, 0, i*30*5 );
|
||||
svg::image(svgH, dataM + (28*28)*i, 28, 28, 5, svg::kInvGrayScaleColorMapId);
|
||||
svg::image(svgH, dataM + (kPixelRowN*kPixelColN)*i, kPixelRowN, kPixelColN, 5, svg::kInvGrayScaleColorMapId);
|
||||
}
|
||||
|
||||
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10);
|
||||
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl, 10,10,10,10);
|
||||
|
||||
|
||||
svg::destroy(svgH);
|
||||
@ -2381,12 +2559,14 @@ cw::rc_t cw::dataset::test( const object_t* cfg )
|
||||
char* dsFn = nullptr;
|
||||
char* outHtmlFn = nullptr;
|
||||
mnist::handle_t mniH;
|
||||
adapter::handle_t rdrH;
|
||||
adapter::handle_t adpH;
|
||||
svg::handle_t svgH;
|
||||
unsigned batchN = 10;
|
||||
unsigned batchN = 100;
|
||||
unsigned cacheByteN = 4096 * 10;
|
||||
unsigned shuffleFl = rdr::kShuffleFl;
|
||||
|
||||
|
||||
if((rc = cfg->getv("inDir",inDir,"dsFn",dsFn,"outHtmlFn",outHtmlFn,"batchN",batchN)) != kOkRC )
|
||||
if((rc = cfg->getv("inDir",inDir,"dsFn",dsFn,"outHtmlFn",outHtmlFn,"batchN",batchN,"cacheByteN",cacheByteN)) != kOkRC )
|
||||
return cwLogError(rc,"MNIST test failed. Argument parse failed.");
|
||||
|
||||
inDir = filesys::expandPath(inDir);
|
||||
@ -2413,7 +2593,7 @@ cw::rc_t cw::dataset::test( const object_t* cfg )
|
||||
}
|
||||
|
||||
// open a dataset adapter
|
||||
if((rc = adapter::create(rdrH,dsFn,batchN)) != kOkRC )
|
||||
if((rc = adapter::create(adpH,dsFn,batchN,cacheByteN,shuffleFl)) != kOkRC )
|
||||
{
|
||||
cwLogError(rc,"Dataset reader create failed.");
|
||||
goto errLabel;
|
||||
@ -2428,49 +2608,64 @@ cw::rc_t cw::dataset::test( const object_t* cfg )
|
||||
|
||||
enum { kImagId, kNumbId };
|
||||
|
||||
// create dataset fields
|
||||
if((rc = create_field( rdrH, kImagId, adapter::kFloatFl, "imag" )) != kOkRC )
|
||||
// create a field for the image data
|
||||
if((rc = create_field( adpH, kImagId, adapter::kFloatFl, "imag" )) != kOkRC )
|
||||
{
|
||||
cwLogError(rc,"Dataset rdr column define failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
if((rc = create_field( rdrH, kNumbId, adapter::kIntFl, "numb" )) != kOkRC )
|
||||
{
|
||||
cwLogError(rc,"Dataset rdr column define failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
// read a batch of data
|
||||
if((rc = adapter::read( rdrH, batchN)) != kOkRC )
|
||||
{
|
||||
cwLogError(rc,"Batch read failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int* numbV = nullptr;
|
||||
const unsigned* numbNV = nullptr;
|
||||
const float* imagV = nullptr;
|
||||
const unsigned* imagNV = nullptr;
|
||||
|
||||
adapter::get(rdrH, kNumbId, numbV, numbNV ); // get the labels
|
||||
adapter::get(rdrH, kImagId, imagV, imagNV ); // get the image data
|
||||
|
||||
for(unsigned i=0; i<batchN; ++i)
|
||||
{
|
||||
printf("label: %i\n", numbV[i] );
|
||||
|
||||
svg::offset(svgH, 0, i*30*5 );
|
||||
svg::image(svgH, imagV + (28*28)*i, 28, 28, 5, svg::kInvGrayScaleColorMapId);
|
||||
}
|
||||
|
||||
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl | svg::kGenInlineStyleFl, 10,10,10,10);
|
||||
// create a field for the image lable
|
||||
if((rc = create_field( adpH, kNumbId, adapter::kIntFl, "numb" )) != kOkRC )
|
||||
{
|
||||
cwLogError(rc,"Dataset rdr column define failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
|
||||
for(unsigned j=0,imageN=0; true; ++j )
|
||||
{
|
||||
// read a batch of data
|
||||
if((rc = adapter::read( adpH, batchN)) != kOkRC )
|
||||
{
|
||||
if( rc == kEofRC )
|
||||
cwLogInfo("Done!.");
|
||||
else
|
||||
cwLogError(rc,"Batch read failed.");
|
||||
goto errLabel;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int* numbV = nullptr;
|
||||
const unsigned* numbNV = nullptr;
|
||||
const float* imagV = nullptr;
|
||||
const unsigned* imagNV = nullptr;
|
||||
const unsigned kPixelSize = 5;
|
||||
|
||||
adapter::get(adpH, kNumbId, numbV, numbNV ); // get the labels
|
||||
adapter::get(adpH, kImagId, imagV, imagNV ); // get the image data
|
||||
|
||||
printf("%3i : ",j);
|
||||
|
||||
// print the first 5 images from each batch to an SVG file
|
||||
for(unsigned i=0; i<0; ++i,++imageN)
|
||||
{
|
||||
printf("%i ", numbV[i] );
|
||||
|
||||
// offset the image vertically
|
||||
svg::offset(svgH, 0, imageN*30*kPixelSize );
|
||||
svg::image(svgH, imagV + (mnist::kPixelRowN*mnist::kPixelColN)*i, mnist::kPixelRowN, mnist::kPixelColN, kPixelSize, svg::kInvGrayScaleColorMapId);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
svg::write(svgH, outHtmlFn, nullptr, svg::kStandAloneFl, 10,10,10,10);
|
||||
|
||||
}
|
||||
}
|
||||
errLabel:
|
||||
adapter::destroy(rdrH);
|
||||
adapter::destroy(adpH);
|
||||
svg::destroy(svgH);
|
||||
mem::release(inDir);
|
||||
mem::release(dsFn);
|
||||
|
20
cwDataSets.h
20
cwDataSets.h
@ -152,8 +152,13 @@ namespace cw
|
||||
unsigned byteOffset; // Byte offset of the value of this field in the current record buffer.
|
||||
unsigned byteN; // Size of this field in bytes.
|
||||
} col_t;
|
||||
|
||||
|
||||
enum {
|
||||
kShuffleFl = 0x01
|
||||
};
|
||||
|
||||
rc_t create( handle_t& h, const char* fn );
|
||||
rc_t create( handle_t& h, const char* fn, unsigned cacheBufByteN, unsigned flags=kShuffleFl );
|
||||
rc_t destroy( handle_t& h );
|
||||
|
||||
unsigned column_count( handle_t h );
|
||||
@ -178,7 +183,7 @@ namespace cw
|
||||
// Read the next record.
|
||||
rc_t read( handle_t h, unsigned recordIdx=kInvalidIdx );
|
||||
|
||||
// Read a column value.
|
||||
// Get a column value from the last record returned by 'read()'.
|
||||
//
|
||||
// vRef = Pointer to the value vector.
|
||||
// nRef = Count of elements in value vector.
|
||||
@ -222,8 +227,7 @@ namespace cw
|
||||
unsigned rankN; // dimV[ rankN ] Rank of this column
|
||||
} colMap_t;
|
||||
|
||||
|
||||
rc_t create( handle_t& hRef, const char* fn, unsigned maxBatchN );
|
||||
rc_t create( handle_t& hRef, const char* fn, unsigned maxBatchN, unsigned cacheByteN, unsigned flags=rdr::kShuffleFl );
|
||||
rc_t destroy( handle_t& hRef );
|
||||
|
||||
// Create a field and assign it a column.
|
||||
@ -249,8 +253,7 @@ namespace cw
|
||||
rc_t get( handle_t h, unsigned fieldId, const float*& fV_Ref, const unsigned*& fNV_Ref );
|
||||
rc_t get( handle_t h, unsigned fieldId, const double*& fV_Ref, const unsigned*& fNV_Ref );
|
||||
|
||||
// Returns col position and geometry data from each record returned by the last
|
||||
// call to read().
|
||||
// Returns col position and geometry data from each record returned by the last call to read().
|
||||
// Returns colMapV_Ref[batchN][columnN].
|
||||
rc_t column_map( handle_t h, unsigned fieldId, colMap_t const * const *& colMapV_Ref );
|
||||
|
||||
@ -269,6 +272,11 @@ namespace cw
|
||||
{
|
||||
typedef handle<struct mnist_str> handle_t;
|
||||
|
||||
enum {
|
||||
kPixelRowN = 28,
|
||||
kPixelColN = 28
|
||||
};
|
||||
|
||||
rc_t create( handle_t& h, const char* inDir );
|
||||
rc_t destroy( handle_t& h );
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user