2020-08-20 00:12:51 +00:00
|
|
|
#ifndef cwNN_H
|
|
|
|
#define cwNN_H
|
|
|
|
|
|
|
|
namespace cw
|
|
|
|
{
|
|
|
|
namespace nn
|
|
|
|
{
|
|
|
|
typedef handle<struct nn_str> handle_t;
|
|
|
|
|
|
|
|
enum
|
|
|
|
{
|
|
|
|
kSigmoidActId,
|
|
|
|
kReluActId
|
|
|
|
};
|
|
|
|
|
|
|
|
enum
|
|
|
|
{
|
2020-10-30 13:40:39 +00:00
|
|
|
kInputLayerTId,
|
|
|
|
kDenseLayerTId,
|
|
|
|
kConv1DConvTId
|
2020-08-20 00:12:51 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
enum
|
|
|
|
{
|
|
|
|
kZeroInitId,
|
|
|
|
kUniformInitId,
|
|
|
|
kNormalInitId
|
|
|
|
};
|
|
|
|
|
2020-10-30 13:40:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
typedef struct train_args_str
|
2020-08-20 00:12:51 +00:00
|
|
|
{
|
2020-10-30 13:40:39 +00:00
|
|
|
unsigned epochN;
|
|
|
|
unsigned batchN;
|
|
|
|
double eta;
|
|
|
|
double lambda;
|
|
|
|
|
|
|
|
} train_args_t;
|
2020-08-20 00:12:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
2020-10-30 13:40:39 +00:00
|
|
|
rc_t create( handle_t& h, const object_t& cfg );
|
2020-08-20 00:12:51 +00:00
|
|
|
rc_t destroy( handle_t& h );
|
|
|
|
|
2020-10-30 13:40:39 +00:00
|
|
|
rc_t train( handle_t h, dataset::handle_t dsH, const train_args_t& args );
|
|
|
|
|
|
|
|
rc_t test( handle_t h, dataset::handle_t dsH );
|
|
|
|
|
2020-08-20 00:12:51 +00:00
|
|
|
|
2020-10-30 13:40:39 +00:00
|
|
|
rc_t test( const char* mnistDir );
|
2020-08-20 00:12:51 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|