|
|
|
@ -85,7 +85,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
@ -108,19 +107,19 @@ public:
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
N,
|
|
|
|
|
beta,
|
|
|
|
|
outputData + g * outputOffset,
|
|
|
|
|
N);
|
|
|
|
|
BlasGemm<Device, real>::compute(false,
|
|
|
|
|
false,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
N,
|
|
|
|
|
beta,
|
|
|
|
|
outputData + g * outputOffset,
|
|
|
|
|
N);
|
|
|
|
|
}
|
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|
|
outputData += outputChannels * outputHeight * outputWidth;
|
|
|
|
@ -188,8 +187,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Col2ImFunctor<kCFO, Device, real> col2im;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
@ -205,19 +202,19 @@ public:
|
|
|
|
|
colData = inputGrad + g * inputOffset;
|
|
|
|
|
scale = 1.0f;
|
|
|
|
|
}
|
|
|
|
|
gemm(CblasTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
|
M,
|
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
|
N,
|
|
|
|
|
scale,
|
|
|
|
|
colData,
|
|
|
|
|
N);
|
|
|
|
|
BlasGemm<Device, real>::compute(true,
|
|
|
|
|
false,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
|
M,
|
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
|
N,
|
|
|
|
|
scale,
|
|
|
|
|
colData,
|
|
|
|
|
N);
|
|
|
|
|
if (needIm2col) {
|
|
|
|
|
col2im(inputGrad + g * inputOffset,
|
|
|
|
|
imShape,
|
|
|
|
@ -299,7 +296,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
@ -321,19 +317,19 @@ public:
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int K = outputHeight * outputWidth;
|
|
|
|
|
int N = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
|
CblasTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
K,
|
|
|
|
|
i == 0 ? beta : 1.0f,
|
|
|
|
|
filterGrad + g * filterOffset,
|
|
|
|
|
N);
|
|
|
|
|
BlasGemm<Device, real>::compute(false,
|
|
|
|
|
true,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
K,
|
|
|
|
|
i == 0 ? beta : 1.0f,
|
|
|
|
|
filterGrad + g * filterOffset,
|
|
|
|
|
N);
|
|
|
|
|
}
|
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|
|
outputGrad += outputChannels * outputHeight * outputWidth;
|
|
|
|
|