forked from btgraham/SparseConvNet-archived
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNetworkInNetworkLayer.cu
More file actions
269 lines (254 loc) · 11.3 KB
/
NetworkInNetworkLayer.cu
File metadata and controls
269 lines (254 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
#include "NetworkInNetworkLayer.h"
#include "cudaUtilities.h"
#include "SigmoidLayer.h"
#include <iostream>
#include <cassert>
// Gather the dropout-surviving sub-matrix of a row-major weight matrix.
//   m                  : full matrix, row stride nOut
//   md                 : destination sub-matrix, row stride nOutDropout
//   inFeaturesPresent  : full-matrix row to copy into destination row blockIdx.x
//   outFeaturesPresent : full-matrix column for each of the nOutDropout destination columns
// Launch: one block per destination row. The column loop strides by KERNELBLOCKSIZE,
// so blockDim.x is expected to equal KERNELBLOCKSIZE (as at the call site in forwards).
__global__ void dShrinkMatrixForDropout
(float* m, float* md,
 int* inFeaturesPresent, int* outFeaturesPresent,
 int nOut, int nOutDropout) {
  const float* srcRow = m + inFeaturesPresent[blockIdx.x]*nOut;
  float* dstRow = md + blockIdx.x*nOutDropout;
  for (int col = threadIdx.x; col < nOutDropout; col += KERNELBLOCKSIZE)
    dstRow[col] = srcRow[outFeaturesPresent[col]];
}
// Gather the surviving entries of a vector: md[k] = m[outFeaturesPresent[k]]
// for every k < nOutDropout. Single-block kernel whose loop strides by NTHREADS,
// so it expects a <<<1, NTHREADS>>> launch. nOut is not used.
__global__ void dShrinkVectorForDropout(float* m, float* md, int* outFeaturesPresent, int nOut, int nOutDropout) {
  int k = threadIdx.x;
  while (k < nOutDropout) {
    md[k] = m[outFeaturesPresent[k]];
    k += NTHREADS;
  }
}
// In-place momentum update of row blockIdx.x of a row-major matrix with nOut columns.
// Threads stride across the row by KERNELBLOCKSIZE (blockDim.x is expected to match).
// Per element, in this exact order (the same "NAG light" scheme as the shrunk kernels):
//   w -= momentum * v
//   v  = momentum * v - learningRate * (1 - momentum) * delta
//   w += (1 + momentum) * v
__global__ void dGradientDescent
(float* d_delta, float* d_momentum, float* d_weights, int nOut, float learningRate,
 float momentum) {
  const int rowStart = blockIdx.x*nOut;
  float* grad = d_delta + rowStart;
  float* vel = d_momentum + rowStart;
  float* wts = d_weights + rowStart;
  for (int c = threadIdx.x; c < nOut; c += KERNELBLOCKSIZE) {
    wts[c] -= vel[c]*momentum;
    vel[c] = momentum*vel[c]-learningRate*(1-momentum)*grad[c];
    wts[c] = wts[c]+vel[c]*(1+momentum);
  }
}
// Momentum update of the full weight/momentum matrices (row stride nOut), driven by a
// gradient d_delta computed on the dropout-shrunk matrix (row stride nOutDropout).
// Block blockIdx.x handles shrunk row blockIdx.x, which maps to full row
// inFeaturesPresent[blockIdx.x]; shrunk column c maps to full column outFeaturesPresent[c].
// Threads stride over the shrunk columns by KERNELBLOCKSIZE. Entries of dropped
// rows/columns are left untouched.
__global__ void dGradientDescentShrunkMatrix
(float* d_delta, float* d_momentum, float* d_weights,
 int nOut, int nOutDropout,
 int* inFeaturesPresent, int* outFeaturesPresent,
 float learningRate,
 float momentum) {
  const float* grad = d_delta + blockIdx.x*nOutDropout;
  const int fullRowStart = inFeaturesPresent[blockIdx.x]*nOut;
  float* vel = d_momentum + fullRowStart;
  float* wts = d_weights + fullRowStart;
  for (int c = threadIdx.x; c < nOutDropout; c += KERNELBLOCKSIZE) {
    const int col = outFeaturesPresent[c];
    //NAG light
    wts[col] -= vel[col]*momentum;
    vel[col] = momentum*vel[col]-learningRate*(1-momentum)*grad[c];
    wts[col] = wts[col]+vel[col]*(1+momentum);
  }
}
// Momentum update of a full vector (e.g. the bias) from a gradient on its surviving
// entries: shrunk index k updates full index outFeaturesPresent[k]. Single-block kernel
// whose loop strides by NTHREADS, so it expects a <<<1, NTHREADS>>> launch. nOut is unused.
__global__ void dGradientDescentShrunkVector
(float* d_delta, float* d_momentum, float* d_weights,
 int nOut, int nOutDropout,
 int* outFeaturesPresent,
 float learningRate,
 float momentum) {
  int k = threadIdx.x;
  while (k < nOutDropout) {
    const int full = outFeaturesPresent[k];
    //NAG light
    d_weights[full] -= d_momentum[full]*momentum;
    d_momentum[full] = momentum*d_momentum[full]-learningRate*(1-momentum)*d_delta[k];
    d_weights[full] = d_weights[full]+d_momentum[full]*(1+momentum);
    k += NTHREADS;
  }
}
// Accumulate column sums of a row-major matrix (row stride nColumns) into target
// with atomicAdd; target is added to, not overwritten, so zero it beforehand.
// Each thread owns one column, i = blockIdx.x*blockDim.x + threadIdx.x. The blocks
// along grid.y split the rows: block y sums rows y, y+gridDim.y, y+2*gridDim.y, ...
// and contributes one atomicAdd per column.
// There is no bounds check on i: the launch must not create threads past the last
// column (columnSum sizes its launches accordingly).
__global__ void dColumnSum
(float* matrix, float* target, int nRows, int nColumns) {
  // Strides come from the actual launch geometry rather than assuming
  // blockDim.x == gridDim.y == KERNELBLOCKSIZE, so any grid shape visits each row once.
  int i=blockIdx.x*blockDim.x+threadIdx.x;
  float t=0;
  for (int j=blockIdx.y;j<nRows;j+=gridDim.y)
    t+=matrix[j*nColumns+i];
  atomicAdd(&target[i],t);
}
// Host wrapper: target[c] += sum over rows r of matrix[r*nColumns + c], for every column c.
// target is accumulated into (via atomicAdd in dColumnSum), so callers zero it first.
// Full KERNELBLOCKSIZE-wide groups of columns go in one launch. Any leftover columns
// (fewer than KERNELBLOCKSIZE) go in a second single-block launch with a narrower block.
// Both launches use KERNELBLOCKSIZE row slices along grid.y and are queued on cnnMemStream.
void columnSum(float* matrix, float* target, int nRows, int nColumns) {
  const int fullTiles = nColumns/KERNELBLOCKSIZE;
  const int tailWidth = nColumns%KERNELBLOCKSIZE;
  if (fullTiles > 0) {
    dColumnSum<<<dim3(fullTiles,KERNELBLOCKSIZE),KERNELBLOCKSIZE,0,cnnMemStream->stream>>>
      (matrix, target, nRows, nColumns);
  }
  if (tailWidth > 0) {
    const int firstTailColumn = fullTiles*KERNELBLOCKSIZE;
    dColumnSum<<<dim3(1,KERNELBLOCKSIZE),tailWidth,0,cnnMemStream->stream>>>
      (matrix+firstTailColumn, target+firstTailColumn, nRows, nColumns);
  }
  cudaCheckError();
}
// Copy the nColumns-long array src into row blockIdx.x of dst (row stride nColumns).
// One block per destination row. Threads stride across the row by KERNELBLOCKSIZE.
__global__ void dReplicateArray
(float* src, float* dst, int nColumns) {
  float* row = dst + blockIdx.x*nColumns;
  for (int c = threadIdx.x; c < nColumns; c += KERNELBLOCKSIZE)
    row[c] = src[c];
}
// Host wrapper: write src (nColumns floats) into each of the nRows rows of dst.
// Uses one block per row, issued in launches of at most 1024 blocks, all on cnnMemStream.
void replicateArray(float* src, float* dst, int nRows, int nColumns) {
  const int maxRowsPerLaunch = 1024;
  for (int firstRow = 0; firstRow < nRows; ) {
    const int remaining = nRows - firstRow;
    const int rows = (remaining < maxRowsPerLaunch) ? remaining : maxRowsPerLaunch;
    dReplicateArray<<<rows,KERNELBLOCKSIZE,0,cnnMemStream->stream>>>
      (src, dst+firstRow*nColumns, nColumns);
    firstRow += rows;
  }
  cudaCheckError();
}
// Build an nFeaturesIn -> nFeaturesOut layer.
// W is nFeaturesIn*nFeaturesOut and B is nFeaturesOut, each with a momentum buffer
// of the same size. W is drawn uniformly from [-bound, bound] with
// bound = sqrt(6 / (nFeaturesIn + alpha*nFeaturesOut)); B and both momentum buffers
// start at zero. Prints a one-line summary of the layer.
NetworkInNetworkLayer::NetworkInNetworkLayer(int nFeaturesIn, int nFeaturesOut,
                                             float dropout,ActivationFunction fn,
                                             float alpha//used only to choose the initialization range of W
                                             ) :
  nFeaturesIn(nFeaturesIn), nFeaturesOut(nFeaturesOut),
  dropout(dropout), fn(fn),
  W(true,nFeaturesIn*nFeaturesOut), MW(true,nFeaturesIn*nFeaturesOut),
  B(true,nFeaturesOut), MB(true,nFeaturesOut) {
  const float fanSum = nFeaturesIn+nFeaturesOut*alpha;
  const float bound = pow(6.0f/fanSum,0.5f);
  W.setUniform(-bound,bound);
  MW.setZero();
  B.setZero();
  MB.setZero();
  std::cout << "Learn " << nFeaturesIn << "->" << nFeaturesOut << " dropout=" << dropout << " " << sigmoidNames[fn] << std::endl;
}
// Describe the output interface for this batch. The spatial layout (size, site
// count, grids) is copied from the input; the feature count becomes nFeaturesOut.
// On training batches, int(nFeaturesOut*(1-dropout)) output feature indices are
// selected via rng.NchooseM. Other batches keep all nFeaturesOut features.
// Error backpropagation into this layer's output is always enabled.
void NetworkInNetworkLayer::preprocess
(SpatiallySparseBatch &batch,
 SpatiallySparseBatchInterface &input,
 SpatiallySparseBatchInterface &output) {
  assert(input.nFeatures==nFeaturesIn);
  output.nFeatures=nFeaturesOut;
  output.spatialSize=input.spatialSize;
  output.nSpatialSites=input.nSpatialSites;
  output.grids=input.grids;
  const float keepFraction = (batch.type==TRAINBATCH) ? (1.0f-dropout) : 1.0f;
  const int nKept = nFeaturesOut*keepFraction; // truncates toward zero, as before
  output.featuresPresent.hVector()=rng.NchooseM(nFeaturesOut,nKept);
  output.backpropErrors=true;
}
// Forward pass: a per-site dense layer. Each of the output.nSpatialSites rows gets
// output = fn(bias + input * W), restricted to the feature indices in
// input/output.featuresPresent.
// There are two paths:
//  * Training batch where dropout removed some input and/or output features:
//    the present rows/columns of W and entries of B are gathered into the scratch
//    buffers w and b, and the GEMM runs on those shrunk operands with
//    alpha = beta = 1.
//  * Otherwise: the full W and B are used, and both GEMM scale factors are
//    (1-dropout), so bias + input*W is scaled by (1-dropout). This is presumably
//    the test-time compensation for features dropped during training.
// NOTE(review): d_rowMajorSGEMM_alphaAB_betaC is assumed (from its name) to compute
// C = alpha*A*B + beta*C for row-major A (m x k), B (k x n), C (m x n) — confirm.
void NetworkInNetworkLayer::forwards
(SpatiallySparseBatch &batch,
SpatiallySparseBatchInterface &input,
SpatiallySparseBatchInterface &output) {
// One row per spatial site, one column per present output feature.
output.sub->features.resize(output.nSpatialSites*output.featuresPresent.size());
// Shrunk path: fewer features present than the layer's full width means dropout was applied.
if (batch.type==TRAINBATCH and
nFeaturesIn+nFeaturesOut>input.featuresPresent.size()+output.featuresPresent.size()) {
// w = W restricted to present input rows x present output columns.
w.resize(input.featuresPresent.size()*output.featuresPresent.size());
dShrinkMatrixForDropout<<<input.featuresPresent.size(),KERNELBLOCKSIZE,0,cnnMemStream->stream>>>
(W.dPtr(), w.dPtr(),
input.featuresPresent.dPtr(),
output.featuresPresent.dPtr(),
output.nFeatures,
output.featuresPresent.size());
cudaCheckError();
// b = B restricted to the present output features.
b.resize(output.featuresPresent.size());
dShrinkVectorForDropout<<<1,NTHREADS,0,cnnMemStream->stream>>>(B.dPtr(), b.dPtr(),
output.featuresPresent.dPtr(),
output.nFeatures,
output.featuresPresent.size());
cudaCheckError();
// Seed every output row with b; the GEMM below (beta = 1) adds input * w on top.
replicateArray(b.dPtr(), output.sub->features.dPtr(), output.nSpatialSites, output.featuresPresent.size());
cudaCheckError();
d_rowMajorSGEMM_alphaAB_betaC(cublasHandle,
input.sub->features.dPtr(), w.dPtr(), output.sub->features.dPtr(),
output.nSpatialSites, input.featuresPresent.size(), output.featuresPresent.size(),
1.0f, 1.0f,__FILE__,__LINE__);
cudaCheckError();
} else {
// Full path. NOTE(review): the bias copy uses output.featuresPresent.size(), while the
// GEMM uses input.nFeatures/output.nFeatures. This assumes every input and output
// feature is present whenever this branch runs.
replicateArray(B.dPtr(), output.sub->features.dPtr(), output.nSpatialSites, output.featuresPresent.size());
d_rowMajorSGEMM_alphaAB_betaC(cublasHandle,
input.sub->features.dPtr(), W.dPtr(), output.sub->features.dPtr(),
output.nSpatialSites, input.nFeatures, output.nFeatures,
1.0f-dropout, 1.0f-dropout,__FILE__,__LINE__);
cudaCheckError();
}
// Bookkeeping: multiply-adds performed by this layer's GEMM.
multiplyAddCount+=(__int128_t)output.nSpatialSites*input.featuresPresent.size()*output.featuresPresent.size();
// Apply the activation fn; the same interface is passed as both source and destination.
applySigmoid(output, output, fn);
cudaCheckError();
}
// Rescale this layer's parameters from the mean absolute value of its current outputs.
// Must be called after forwards(...) has filled output.sub->features.
//   scalingUnderneath : in/out. On entry, the factor the layer below applied to its
//                       parameters. W and MW are additionally divided by it, presumably
//                       to undo its effect on this layer's input. On exit, it holds the
//                       factor applied here, for use by the next layer up.
//   topLayer          : when true, this layer's own factor is 1, so only the
//                       compensation for the layer below is applied.
void NetworkInNetworkLayer::scaleWeights
(SpatiallySparseBatchInterface &input,
 SpatiallySparseBatchInterface &output,
 float& scalingUnderneath,
 bool topLayer) {
  assert(input.sub->features.size()>0);
  assert(output.sub->features.size()>0); //call after forwards(...)
  // NOTE(review): the meanAbs argument (3 for VLEAKYRELU, 100 otherwise) is interpreted
  // by the vector class; its meaning is not visible in this file.
  float scale=output.sub->features.meanAbs( (fn==VLEAKYRELU) ? 3 : 100 );
  std::cout << "featureScale:" << scale << std::endl;
  if (topLayer || !(scale>0.0f)) {
    // Top layer: leave the output magnitude alone. This branch is also taken when the
    // measured scale is zero or NaN; there, powf(scale,-0.1f) would be inf/NaN and
    // would poison every weight and momentum entry.
    scale=1;
  } else {
    // Damped correction: a factor of featureScale^(-0.1) only partially normalizes
    // the output magnitude in one call.
    scale=powf(scale,-0.1f);
  }
  W.multiplicativeRescale(scale/scalingUnderneath);
  B.multiplicativeRescale(scale);
  MW.multiplicativeRescale(scale/scalingUnderneath);
  MB.multiplicativeRescale(scale);
  scalingUnderneath=scale;
}
// Backward pass and momentum update for this layer:
//  1. applySigmoidBackProp(output, output, fn) presumably converts
//     output.sub->dfeatures into the gradient w.r.t. the pre-activation, in place
//     (TODO confirm).
//  2. dw = input^T * dOutput (by the helper's name) over the present features, and
//     db = column sums of dOutput.
//  3. If input.backpropErrors, input.sub->dfeatures = dOutput * w^T (shrunk path)
//     or dOutput * W^T (full path).
//  4. Momentum update of W and B, restricted to the present features on the shrunk path.
// On the full path, step 3 must precede step 4: it reads W, which step 4 modifies.
// The shrunk path reads w, the gathered copy made in forwards(). The branch condition
// here omits forwards()'s TRAINBATCH test, so this assumes backwards() only runs on
// training batches that went through forwards().
void NetworkInNetworkLayer::backwards
(SpatiallySparseBatch &batch,
SpatiallySparseBatchInterface &input,
SpatiallySparseBatchInterface &output,
float learningRate,
float momentum) {
applySigmoidBackProp(output, output, fn);
dw.resize(input.featuresPresent.size()*output.featuresPresent.size());
db.resize(output.featuresPresent.size());
d_rowMajorSGEMM_alphaAtB_betaC(cublasHandle,
input.sub->features.dPtr(), output.sub->dfeatures.dPtr(), dw.dPtr(),
input.featuresPresent.size(), output.nSpatialSites, output.featuresPresent.size(),
1.0, 0.0);
multiplyAddCount+=(__int128_t)output.nSpatialSites*input.featuresPresent.size()*output.featuresPresent.size();
cudaCheckError();
// columnSum accumulates with atomicAdd, so db must start at zero.
db.setZero(*cnnMemStream);
columnSum(output.sub->dfeatures.dPtr(), db.dPtr(), output.nSpatialSites, output.featuresPresent.size());
// Shrunk path: same dropout test as forwards(), without the TRAINBATCH check.
if (nFeaturesIn+nFeaturesOut>input.featuresPresent.size()+output.featuresPresent.size()) {
if (input.backpropErrors) {
input.sub->dfeatures.resize(input.nSpatialSites*input.featuresPresent.size());
d_rowMajorSGEMM_alphaABt_betaC(cublasHandle,
output.sub->dfeatures.dPtr(), w.dPtr(), input.sub->dfeatures.dPtr(),
output.nSpatialSites,output.featuresPresent.size(),input.featuresPresent.size(),
1.0, 0.0);
multiplyAddCount+=(__int128_t)output.nSpatialSites*input.featuresPresent.size()*output.featuresPresent.size();
cudaCheckError();
}
// Scatter the shrunk gradients back into the full W/MW and B/MB.
dGradientDescentShrunkMatrix<<<input.featuresPresent.size(),KERNELBLOCKSIZE,0,cnnMemStream->stream>>>
(dw.dPtr(), MW.dPtr(), W.dPtr(),
output.nFeatures, output.featuresPresent.size(),
input.featuresPresent.dPtr(), output.featuresPresent.dPtr(),
learningRate,momentum);
cudaCheckError();
dGradientDescentShrunkVector<<<1,NTHREADS,0,cnnMemStream->stream>>>
(db.dPtr(), MB.dPtr(), B.dPtr(),
output.nFeatures, output.featuresPresent.size(),
output.featuresPresent.dPtr(),
learningRate,momentum);
cudaCheckError();
} else {
// Full path: backprop through W before the update below modifies it.
if (input.backpropErrors) {
input.sub->dfeatures.resize(input.nSpatialSites*input.featuresPresent.size());
d_rowMajorSGEMM_alphaABt_betaC(cublasHandle,
output.sub->dfeatures.dPtr(), W.dPtr(), input.sub->dfeatures.dPtr(),
output.nSpatialSites,nFeaturesOut,nFeaturesIn,
1.0, 0.0);
multiplyAddCount+=(__int128_t)output.nSpatialSites*input.featuresPresent.size()*output.featuresPresent.size();
cudaCheckError();
}
// One block per row of W (nFeaturesIn rows of nFeaturesOut); the bias is a single row.
dGradientDescent<<<nFeaturesIn,KERNELBLOCKSIZE,0,cnnMemStream->stream>>>
(dw.dPtr(), MW.dPtr(), W.dPtr(), nFeaturesOut, learningRate,momentum);
cudaCheckError();
dGradientDescent<<<1,KERNELBLOCKSIZE,0,cnnMemStream->stream>>>
(db.dPtr(), MB.dPtr(), B.dPtr(), nFeaturesOut, learningRate,momentum);
cudaCheckError();
}
// std::cout << __LINE__ << " "<<input.sub->dfeatures.meanAbs() << "\n";
// std::cout << __LINE__ << " "<<W.meanAbs() << "\n";
}
// Read W and then B from f as raw float bytes (the layout putWeightsToStream writes),
// then reset both momentum buffers to zero.
// A short or failed read used to go unnoticed and leave partially loaded weights. It is
// now reported on std::cerr. Loading still continues, so callers that check f's state
// themselves keep working.
void NetworkInNetworkLayer::loadWeightsFromStream(std::ifstream &f) {
  f.read((char*)&W.hVector()[0],sizeof(float)*W.size());
  f.read((char*)&B.hVector()[0],sizeof(float)*B.size());
  if (!f)
    std::cerr << "NetworkInNetworkLayer " << nFeaturesIn << "->" << nFeaturesOut
              << ": weight stream ended early or could not be read" << std::endl;
  MW.setZero();
  MB.setZero();
}
// Write W and then B to f as raw float bytes; loadWeightsFromStream reads the same layout.
// Momentum buffers are not saved.
void NetworkInNetworkLayer::putWeightsToStream(std::ofstream &f) {
  char* weightBytes = reinterpret_cast<char*>(&W.hVector()[0]);
  f.write(weightBytes, sizeof(float)*W.size());
  char* biasBytes = reinterpret_cast<char*>(&B.hVector()[0]);
  f.write(biasBytes, sizeof(float)*B.size());
}
// This layer transforms features per site and keeps the spatial layout (preprocess
// copies spatialSize from input to output), so the input size equals the output size.
int NetworkInNetworkLayer::calculateInputSpatialSize(int outputSpatialSize) {
  const int inputSpatialSize = outputSpatialSize;
  return inputSpatialSize;
}