|
|
@ -66,16 +66,23 @@ public:
|
|
|
|
real* inputData = inputs[0].data<real>();
|
|
|
|
real* inputData = inputs[0].data<real>();
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
real* outputData = outputs[0].data<real>();
|
|
|
|
real* outputData = outputs[0].data<real>();
|
|
|
|
|
|
|
|
bool needIm2col = isNeedIm2col(filter);
|
|
|
|
|
|
|
|
|
|
|
|
TensorShape imShape =
|
|
|
|
TensorShape imShape =
|
|
|
|
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
|
|
|
|
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
|
|
|
|
TensorShape colShape = TensorShape({inputChannels / groups_,
|
|
|
|
|
|
|
|
filterHeight,
|
|
|
|
|
|
|
|
filterWidth,
|
|
|
|
|
|
|
|
outputHeight,
|
|
|
|
|
|
|
|
outputWidth});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
resizeBuffer<Device>(colShape.getElements());
|
|
|
|
TensorShape colShape;
|
|
|
|
real* colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
real* colData = NULL;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (needIm2col) {
|
|
|
|
|
|
|
|
colShape = TensorShape({inputChannels / groups_,
|
|
|
|
|
|
|
|
filterHeight,
|
|
|
|
|
|
|
|
filterWidth,
|
|
|
|
|
|
|
|
outputHeight,
|
|
|
|
|
|
|
|
outputWidth});
|
|
|
|
|
|
|
|
resizeBuffer<Device>(colShape.getElements());
|
|
|
|
|
|
|
|
colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
@ -86,15 +93,18 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
if (needIm2col) {
|
|
|
|
imShape,
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
colData,
|
|
|
|
imShape,
|
|
|
|
colShape,
|
|
|
|
colData,
|
|
|
|
strideH(),
|
|
|
|
colShape,
|
|
|
|
strideW(),
|
|
|
|
strideH(),
|
|
|
|
paddingH(),
|
|
|
|
strideW(),
|
|
|
|
paddingW());
|
|
|
|
paddingH(),
|
|
|
|
|
|
|
|
paddingW());
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
colData = inputData + g * inputOffset;
|
|
|
|
|
|
|
|
}
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
@ -159,19 +169,27 @@ public:
|
|
|
|
real* outputGrad = inputs[0].data<real>();
|
|
|
|
real* outputGrad = inputs[0].data<real>();
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
real* inputGrad = outputs[0].data<real>();
|
|
|
|
real* inputGrad = outputs[0].data<real>();
|
|
|
|
|
|
|
|
bool needIm2col = isNeedIm2col(filter);
|
|
|
|
|
|
|
|
|
|
|
|
TensorShape imShape =
|
|
|
|
TensorShape imShape =
|
|
|
|
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
|
|
|
|
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
|
|
|
|
TensorShape colShape = TensorShape({inputChannels / groups_,
|
|
|
|
|
|
|
|
filterHeight,
|
|
|
|
|
|
|
|
filterWidth,
|
|
|
|
|
|
|
|
outputHeight,
|
|
|
|
|
|
|
|
outputWidth});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
resizeBuffer<Device>(colShape.getElements());
|
|
|
|
TensorShape colShape;
|
|
|
|
real* colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
real* colData = NULL;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (needIm2col) {
|
|
|
|
|
|
|
|
colShape = TensorShape({inputChannels / groups_,
|
|
|
|
|
|
|
|
filterHeight,
|
|
|
|
|
|
|
|
filterWidth,
|
|
|
|
|
|
|
|
outputHeight,
|
|
|
|
|
|
|
|
outputWidth});
|
|
|
|
|
|
|
|
resizeBuffer<Device>(colShape.getElements());
|
|
|
|
|
|
|
|
colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Col2ImFunctor<kCFO, Device, real> col2im;
|
|
|
|
Col2ImFunctor<kCFO, Device, real> col2im;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
|
|
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
size_t outputOffset =
|
|
|
|
size_t outputOffset =
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
@ -182,6 +200,11 @@ public:
|
|
|
|
int K = outputChannels / groups_;
|
|
|
|
int K = outputChannels / groups_;
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
int M = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
int M = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
|
|
|
real scale = 0.0f;
|
|
|
|
|
|
|
|
if (!needIm2col) {
|
|
|
|
|
|
|
|
colData = inputGrad + g * inputOffset;
|
|
|
|
|
|
|
|
scale = 1.0f;
|
|
|
|
|
|
|
|
}
|
|
|
|
gemm(CblasTrans,
|
|
|
|
gemm(CblasTrans,
|
|
|
|
CblasNoTrans,
|
|
|
|
CblasNoTrans,
|
|
|
|
M,
|
|
|
|
M,
|
|
|
@ -192,17 +215,19 @@ public:
|
|
|
|
M,
|
|
|
|
M,
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
N,
|
|
|
|
N,
|
|
|
|
0.0f,
|
|
|
|
scale,
|
|
|
|
colData,
|
|
|
|
colData,
|
|
|
|
N);
|
|
|
|
N);
|
|
|
|
col2im(inputGrad + g * inputOffset,
|
|
|
|
if (needIm2col) {
|
|
|
|
imShape,
|
|
|
|
col2im(inputGrad + g * inputOffset,
|
|
|
|
colData,
|
|
|
|
imShape,
|
|
|
|
colShape,
|
|
|
|
colData,
|
|
|
|
strideH(),
|
|
|
|
colShape,
|
|
|
|
strideW(),
|
|
|
|
strideH(),
|
|
|
|
paddingH(),
|
|
|
|
strideW(),
|
|
|
|
paddingW());
|
|
|
|
paddingH(),
|
|
|
|
|
|
|
|
paddingW());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
inputGrad += inputChannels * inputHeight * inputWidth;
|
|
|
|
inputGrad += inputChannels * inputHeight * inputWidth;
|
|
|
|
outputGrad += outputChannels * outputHeight * outputWidth;
|
|
|
|
outputGrad += outputChannels * outputHeight * outputWidth;
|
|
|
@ -255,16 +280,23 @@ public:
|
|
|
|
real* outputGrad = inputs[0].data<real>();
|
|
|
|
real* outputGrad = inputs[0].data<real>();
|
|
|
|
real* inputData = inputs[1].data<real>();
|
|
|
|
real* inputData = inputs[1].data<real>();
|
|
|
|
real* filterGrad = outputs[0].data<real>();
|
|
|
|
real* filterGrad = outputs[0].data<real>();
|
|
|
|
|
|
|
|
bool needIm2col = isNeedIm2col(filter);
|
|
|
|
|
|
|
|
|
|
|
|
TensorShape imShape =
|
|
|
|
TensorShape imShape =
|
|
|
|
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
|
|
|
|
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
|
|
|
|
TensorShape colShape = TensorShape({inputChannels / groups_,
|
|
|
|
|
|
|
|
filterHeight,
|
|
|
|
|
|
|
|
filterWidth,
|
|
|
|
|
|
|
|
outputHeight,
|
|
|
|
|
|
|
|
outputWidth});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
resizeBuffer<Device>(colShape.getElements());
|
|
|
|
TensorShape colShape;
|
|
|
|
real* colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
real* colData = NULL;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (needIm2col) {
|
|
|
|
|
|
|
|
colShape = TensorShape({inputChannels / groups_,
|
|
|
|
|
|
|
|
filterHeight,
|
|
|
|
|
|
|
|
filterWidth,
|
|
|
|
|
|
|
|
outputHeight,
|
|
|
|
|
|
|
|
outputWidth});
|
|
|
|
|
|
|
|
resizeBuffer<Device>(colShape.getElements());
|
|
|
|
|
|
|
|
colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
@ -274,15 +306,18 @@ public:
|
|
|
|
size_t filterOffset = filter.getElements() / groups_;
|
|
|
|
size_t filterOffset = filter.getElements() / groups_;
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
if (needIm2col) {
|
|
|
|
imShape,
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
colData,
|
|
|
|
imShape,
|
|
|
|
colShape,
|
|
|
|
colData,
|
|
|
|
strideH(),
|
|
|
|
colShape,
|
|
|
|
strideW(),
|
|
|
|
strideH(),
|
|
|
|
paddingH(),
|
|
|
|
strideW(),
|
|
|
|
paddingW());
|
|
|
|
paddingH(),
|
|
|
|
|
|
|
|
paddingW());
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
colData = inputData + g * inputOffset;
|
|
|
|
|
|
|
|
}
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int K = outputHeight * outputWidth;
|
|
|
|
int K = outputHeight * outputWidth;
|
|
|
|
int N = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
int N = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|