|
|
|
@ -85,7 +85,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
@ -108,8 +107,8 @@ public:
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
BlasGemm<Device, real>::compute(false,
|
|
|
|
|
false,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
@ -188,8 +187,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Col2ImFunctor<kCFO, Device, real> col2im;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
@ -205,8 +202,8 @@ public:
|
|
|
|
|
colData = inputGrad + g * inputOffset;
|
|
|
|
|
scale = 1.0f;
|
|
|
|
|
}
|
|
|
|
|
gemm(CblasTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
BlasGemm<Device, real>::compute(true,
|
|
|
|
|
false,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
@ -299,7 +296,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
@ -321,8 +317,8 @@ public:
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int K = outputHeight * outputWidth;
|
|
|
|
|
int N = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
|
CblasTrans,
|
|
|
|
|
BlasGemm<Device, real>::compute(false,
|
|
|
|
|
true,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|