forked from btgraham/SparseConvNet-archived
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmnist.cpp
More file actions
58 lines (51 loc) · 1.93 KB
/
mnist.cpp
File metadata and controls
58 lines (51 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include "SparseConvNet.h"
#include "NetworkArchitectures.h"
#include "SpatiallySparseDatasetMnist.h"
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <string>
// Run configuration read by main().
int epoch{0};        // Epoch to resume from; weights are loaded only when this is > 0.
int cudaDevice{-1};  // PCI bus ID, -1 for default GPU
int batchSize{100};  // Minibatch size used for both the training and the test passes.
// Data-augmentation hook for MNIST: intentionally disabled.
// Every batch type receives a plain heap-allocated copy of this picture;
// NOTE(review): the caller presumably owns and frees the returned pointer — confirm.
//
// To re-enable augmentation, the previously sketched variant was:
//   OpenCVPicture* pic = new OpenCVPicture(*this);
//   pic->loadData();
//   if (type == TRAINBATCH)
//     pic->jiggle(rng, 2);
//   return pic;
Picture* OpenCVPicture::distort(RNG& rng, batchType type) {
  (void)rng;   // unused while augmentation is disabled
  (void)type;  // unused while augmentation is disabled
  return new OpenCVPicture(*this);
}
/// MNIST network built on SparseConvNet.
///
/// The network has six addLeNetLayerPOFMP layers, then two addLeNetLayerMP
/// layers, then a softmax layer. Layer k (1-based) has 32*k feature maps,
/// so the counts run from 32 up to 256.
class CNN : public SparseConvNet {
public:
  /// @param dimension       forwarded to SparseConvNet
  /// @param nInputFeatures  forwarded to SparseConvNet
  /// @param nClasses        forwarded to SparseConvNet
  /// @param p               NOTE(review): accepted but never read by the constructor
  /// @param cudaDevice      forwarded to SparseConvNet (-1 for the default device)
  /// @param nTop            forwarded to SparseConvNet
  CNN (int dimension, int nInputFeatures, int nClasses, float p=0.0f, int cudaDevice=-1, int nTop=1);
};

CNN::CNN
(int dimension, int nInputFeatures, int nClasses, float p, int cudaDevice, int nTop)
  : SparseConvNet(dimension,nInputFeatures, nClasses, cudaDevice, nTop) {
  // sqrt(2), passed unchanged to every POFMP layer; presumably the
  // fractional pooling ratio — confirm against NetworkArchitectures.h.
  const float poolRatio = powf(2,0.5);

  // Final argument for each POFMP layer, in layer order. It presumably is a
  // dropout probability — confirm against NetworkArchitectures.h.
  const double pofmpLastArg[] = {0, 0, 0, 0.1, 0.2, 0.3};

  int depth = 0;  // 1-based layer index; layer k gets 32*k features
  for (const double lastArg : pofmpLastArg) {
    ++depth;
    addLeNetLayerPOFMP(32*depth,2,1,2,poolRatio,VLEAKYRELU,lastArg);
  }

  ++depth;
  addLeNetLayerMP(32*depth,2,1,1,1,VLEAKYRELU,0.4);
  ++depth;
  addLeNetLayerMP(32*depth,1,1,1,1,VLEAKYRELU,0.5);

  addSoftmaxLayer();
}
/// Trains CNN on the MNIST training set indefinitely (until the process is
/// stopped).
///
/// After epoch e, for every e that is a multiple of 10, the weights are saved
/// under "weights/mnist" and the test set is run through the network.
///
/// Usage: mnist [resume_epoch]
///   resume_epoch  Optional non-negative integer. It overrides the global
///                 `epoch`. When it is > 0, the weights saved at that epoch
///                 are loaded and training continues with the next epoch.
///                 Without the argument, behaviour is unchanged.
/// @return EXIT_FAILURE if resume_epoch is malformed. Otherwise the function
///         does not return, because the training loop is unbounded.
int main(int argc, char* argv[]) {
  // Allow resuming from a checkpoint without editing and recompiling the
  // global `epoch`.
  if (argc > 1) {
    char* end = nullptr;
    errno = 0;
    const long requested = std::strtol(argv[1], &end, 10);
    // Reject empty or trailing-garbage input, overflow, negatives, and values
    // that would make the epoch++ below overflow int.
    if (end == argv[1] || *end != '\0' || errno == ERANGE ||
        requested < 0 || requested >= INT_MAX) {
      std::cerr << "usage: " << argv[0] << " [resume_epoch]\n"
                << "  resume_epoch must be a non-negative integer\n";
      return EXIT_FAILURE;
    }
    epoch = static_cast<int>(requested);
  }

  std::string baseName="weights/mnist";
  SpatiallySparseDataset trainSet=MnistTrainSet();
  SpatiallySparseDataset testSet=MnistTestSet();
  trainSet.summary();
  testSet.summary();
  CNN cnn(2,trainSet.nFeatures,trainSet.nClasses,0.0f,cudaDevice);
  //DeepCNet cnn(2,5,32,VLEAKYRELU,trainSet.nFeatures,trainSet.nClasses,0.0f,cudaDevice);
  if (epoch>0)
    cnn.loadWeights(baseName,epoch);
  // Unbounded training loop; the learning rate decays as 0.003*exp(-0.01*epoch).
  for (epoch++;;epoch++) {
    // Flush so the epoch label appears before processDataset's own output.
    std::cout <<"epoch: " << epoch << " " << std::flush;
    cnn.processDataset(trainSet, batchSize,0.003*exp(-0.01 * epoch));
    if (epoch%10==0) {
      cnn.saveWeights(baseName,epoch);
      cnn.processDataset(testSet, batchSize);
    }
  }
}