|
|
|
@ -101,49 +101,57 @@ public:
|
|
|
|
|
size_t outputHeight = outputs[0].shape()[2];
|
|
|
|
|
size_t outputWidth = outputs[0].shape()[3];
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
|
|
|
|
|
|
|
|
|
|
real* inputData = inputs[0].data<real>();
|
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
|
real* outputData = outputs[0].data<real>();
|
|
|
|
|
|
|
|
|
|
size_t size =
|
|
|
|
|
inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
|
|
|
|
|
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
|
|
|
|
|
outputHeight * outputWidth;
|
|
|
|
|
resizeBuffer(size);
|
|
|
|
|
real* colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<real> im2col;
|
|
|
|
|
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
|
size_t filterOffset = inputs[1].shape().getElements() / groups_;
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
im2col(inputData,
|
|
|
|
|
inputChannels,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
filterHeight,
|
|
|
|
|
filterWidth,
|
|
|
|
|
strideH(),
|
|
|
|
|
strideW(),
|
|
|
|
|
paddingH(),
|
|
|
|
|
paddingW(),
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
colData);
|
|
|
|
|
|
|
|
|
|
int M = outputChannels;
|
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
|
int K = inputChannels * filterHeight * filterWidth;
|
|
|
|
|
gemm<real>(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterData,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
N,
|
|
|
|
|
0.0f,
|
|
|
|
|
outputData,
|
|
|
|
|
N);
|
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|
|
outputData += outputChannels * outputHeight * outputWidth;
|
|
|
|
|
for (int g = 0; g < groups_; g++) {
|
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
|
inputChannels / groups_,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
filterHeight,
|
|
|
|
|
filterWidth,
|
|
|
|
|
strideH(),
|
|
|
|
|
strideW(),
|
|
|
|
|
paddingH(),
|
|
|
|
|
paddingW(),
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
colData);
|
|
|
|
|
|
|
|
|
|
int M = outputChannels;
|
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
|
int K = inputChannels * filterHeight * filterWidth;
|
|
|
|
|
gemm<real>(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
N,
|
|
|
|
|
0.0f,
|
|
|
|
|
outputData + g * outputOffset,
|
|
|
|
|
N);
|
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|
|
outputData += outputChannels * outputHeight * outputWidth;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|