|
|
|
@ -48,20 +48,17 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
|
|
|
|
|
if (useGpu_) {
|
|
|
|
|
forward_ = FunctionBase::funcRegistrar_.createByType(
|
|
|
|
|
FUNC_NAME(CrossMapNormal, GPU));
|
|
|
|
|
backward_ = FunctionBase::funcRegistrar_.createByType(
|
|
|
|
|
FUNC_NAME(CrossMapNormalGrad, GPU));
|
|
|
|
|
} else {
|
|
|
|
|
forward_ = FunctionBase::funcRegistrar_.createByType(
|
|
|
|
|
FUNC_NAME(CrossMapNormal, CPU));
|
|
|
|
|
backward_ = FunctionBase::funcRegistrar_.createByType(
|
|
|
|
|
FUNC_NAME(CrossMapNormalGrad, CPU));
|
|
|
|
|
}
|
|
|
|
|
forward_->init(
|
|
|
|
|
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
|
|
|
|
|
|
|
|
|
|
if (useGpu_) {
|
|
|
|
|
backward_ = FunctionBase::funcRegistrar_.createByType(
|
|
|
|
|
FUNC_NAME(CrossMapNormalGrad, GPU));
|
|
|
|
|
} else {
|
|
|
|
|
backward_ = FunctionBase::funcRegistrar_.createByType(
|
|
|
|
|
FUNC_NAME(CrossMapNormalGrad, CPU));
|
|
|
|
|
}
|
|
|
|
|
backward_->init(
|
|
|
|
|
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
|
|
|
|
|
|
|
|
|
@ -74,7 +71,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
|
|
|
|
|
/* malloc memory for the output_ if necessary */
|
|
|
|
|
/* note: one sample correspond to one row */
|
|
|
|
|
MatrixPtr input = inputLayers_[0]->getOutputValue();
|
|
|
|
|
int batchSize = input->getHeight();
|
|
|
|
|
size_t batchSize = input->getHeight();
|
|
|
|
|
int size = getSize();
|
|
|
|
|
resetOutput(batchSize, size);
|
|
|
|
|
|
|
|
|
@ -82,10 +79,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
|
|
|
|
|
|
|
|
|
|
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
|
|
|
|
|
|
|
|
|
|
dims_ = {(size_t)batchSize,
|
|
|
|
|
(size_t)channels_,
|
|
|
|
|
(size_t)imgSizeH_,
|
|
|
|
|
(size_t)imgSizeW_};
|
|
|
|
|
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
|
|
|
|
|
forward_->calc(
|
|
|
|
|
{Tensor(input->getData(), dims_)},
|
|
|
|
|
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
|
|
|
|
|