|
|
@ -206,8 +206,7 @@ public:
|
|
|
|
colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<kCFO, Device, real> im2col;
|
|
|
|
Im2ColMobileFunctor<real> im2col;
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
size_t inputOffset = imShape.getElements();
|
|
|
|
size_t outputOffset =
|
|
|
|
size_t outputOffset =
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
@ -241,19 +240,20 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
// gemm
|
|
|
|
// gemm
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
BlasGemm<Device, real>::compute(
|
|
|
|
CblasNoTrans,
|
|
|
|
false,
|
|
|
|
M,
|
|
|
|
false,
|
|
|
|
N,
|
|
|
|
M,
|
|
|
|
K,
|
|
|
|
N,
|
|
|
|
1.0f,
|
|
|
|
K,
|
|
|
|
filterData + g * filterOffset + colHeightStart,
|
|
|
|
1.0f,
|
|
|
|
kStride,
|
|
|
|
filterData + g * filterOffset + colHeightStart,
|
|
|
|
colData,
|
|
|
|
kStride,
|
|
|
|
N,
|
|
|
|
colData,
|
|
|
|
beta_,
|
|
|
|
N,
|
|
|
|
outputData + g * outputOffset + colWidthStart,
|
|
|
|
beta_,
|
|
|
|
nStride);
|
|
|
|
outputData + g * outputOffset + colWidthStart,
|
|
|
|
|
|
|
|
nStride);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
beta_ = 1.0;
|
|
|
|
beta_ = 1.0;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -261,19 +261,19 @@ public:
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
BlasGemm<Device, real>::compute(false,
|
|
|
|
CblasNoTrans,
|
|
|
|
false,
|
|
|
|
M,
|
|
|
|
M,
|
|
|
|
N,
|
|
|
|
N,
|
|
|
|
K,
|
|
|
|
K,
|
|
|
|
1.0f,
|
|
|
|
1.0f,
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
filterData + g * filterOffset,
|
|
|
|
K,
|
|
|
|
K,
|
|
|
|
inputData + g * inputOffset,
|
|
|
|
inputData + g * inputOffset,
|
|
|
|
N,
|
|
|
|
N,
|
|
|
|
beta,
|
|
|
|
beta,
|
|
|
|
outputData + g * outputOffset,
|
|
|
|
outputData + g * outputOffset,
|
|
|
|
N);
|
|
|
|
N);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|