forked from btgraham/SparseConvNet-archived
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: SoftmaxClassifier.cu
More file actions
57 lines (49 loc) · 2.04 KB
/
SoftmaxClassifier.cu
File metadata and controls
57 lines (49 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include "SoftmaxClassifier.h"
#include "cudaUtilities.h"
#include <iostream>
#include <vector>
#include <cassert>
#include "utilities.h"
// Top-layer gradient for a softmax + negative-log-likelihood cost:
//   d Cost / d (pre-softmax input)[k*N+i] = topGrid[k*N+i] - 1{i == labels[k]}
// where topGrid holds the softmax probabilities for batchSize samples of N
// classes each, and topDelta receives the gradient.
// Launch layout: grid-strides over samples (k) and block-strides over classes
// (i), so the kernel is correct for ANY <<<grid, block>>> configuration; the
// caller currently launches <<<1, NTHREADS>>>, which this reproduces exactly.
// Adjacent threads touch adjacent elements of topGrid/topDelta -> coalesced.
__global__ void dDerivativeOfCostWRTpreSoftmaxTopLevelWeights
(int batchSize, float* topDelta, float* topGrid, int* labels, int N) {
  for (int k = blockIdx.x; k < batchSize; k += gridDim.x) {
    int label = labels[k];  // hoist: one global read per sample, not per class
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
      topDelta[k * N + i] = topGrid[k * N + i] - (i == label ? 1.0f : 0.0f);
    }
  }
}
// NOTE(review): outCtr is never referenced in this file; it may be used from
// another translation unit (it has external linkage) — confirm before removing.
int outCtr=0;
// Softmax output layer.
// Treats input.sub->features as per-sample softmax probabilities (assumes no
// dropout in the output layer, so nClasses == input.nFeatures). Appends each
// sample's probability vector and its top-nTop predicted classes to `batch`;
// for labelled batches it also accumulates the mistake count and the negative
// log-likelihood. For training batches it first launches the kernel that
// writes d Cost / d (pre-softmax activations) into input.sub->dfeatures to
// seed backpropagation.
void SoftmaxClassifier(SpatiallySparseBatchInterface& input, SpatiallySparseBatch& batch, int nTop) {
  assert(batch.batchSize == input.nSpatialSites);
  assert(input.nFeatures == input.featuresPresent.size());
  const int nClasses = input.nFeatures;

  if (batch.type == TRAINBATCH) {
    // Begin backprop. Top layer: gradient = softmax(probs) - one_hot(label).
    input.sub->dfeatures.resize(input.nSpatialSites * input.featuresPresent.size());
    dDerivativeOfCostWRTpreSoftmaxTopLevelWeights<<<1, NTHREADS, 0, cnnMemStream->stream>>>
      (batch.batchSize, input.sub->dfeatures.dPtr(), input.sub->features.dPtr(),
       batch.labels.dPtr(), nClasses);
  }

  // Bring probabilities and labels back to the host.
  // NOTE(review): the hVector() reads below assume these async copies have
  // completed by the time they run (i.e. copyToCPUAsync or hVector blocks on
  // the stream) — confirm against the MemoryManager implementation.
  input.sub->features.copyToCPUAsync(*cnnMemStream);
  batch.labels.copyToCPUAsync(*cnnMemStream);

  float* probs = &input.sub->features.hVector()[0];
  for (int sample = 0; sample < batch.batchSize; ++sample) {
    float* row = probs + sample * nClasses;
    batch.probabilities.push_back(std::vector<float>(row, row + nClasses));
    batch.predictions.push_back(vectorTopIndices(batch.probabilities[sample], nTop));
  }

  if (batch.type != UNLABELEDBATCH) {
    // Assume every sample is wrong, then credit back each top-nTop hit.
    batch.mistakes += batch.batchSize;
    for (int sample = 0; sample < batch.batchSize; ++sample) {
      int label = batch.labels.hVector()[sample];
      // Clamp away from zero so log() cannot produce -inf.
      batch.negativeLogLikelihood -= log(max(batch.probabilities[sample][label], 1.0e-15));
      for (int rank = 0; rank < nTop; ++rank) {
        if (batch.predictions[sample][rank] == label) {
          --batch.mistakes;
        }
      }
    }
  }
  // std::cout << batch.mistakes << " " << std::flush;
  input.sub->features.copyToGPUAsync(*cnnMemStream);
  cudaCheckError();
  // input.sub->features.check();
  // input.sub->dfeatures.check();
}