|
|
|
@ -101,8 +101,6 @@ public:
|
|
|
|
|
size_t outputHeight = outputs[0].shape()[2];
|
|
|
|
|
size_t outputWidth = outputs[0].shape()[3];
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
|
|
|
|
|
|
|
|
|
|
real* inputData = inputs[0].data<real>();
|
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
|
real* outputData = outputs[0].data<real>();
|
|
|
|
@ -134,9 +132,9 @@ public:
|
|
|
|
|
outputWidth,
|
|
|
|
|
colData);
|
|
|
|
|
|
|
|
|
|
int M = outputChannels;
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
|
int K = inputChannels * filterHeight * filterWidth;
|
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|