|
|
|
@ -36,6 +36,16 @@ MatrixPtr CrossChannelNormLayer::createSpatialMatrix(MatrixPtr data,
|
|
|
|
|
data->getData() + iter * spatialDim, 1, spatialDim, false, useGpu_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CrossChannelNormLayer::init(const LayerMap& layerMap,
|
|
|
|
|
const ParameterMap& parameterMap) {
|
|
|
|
|
Layer::init(layerMap, parameterMap);
|
|
|
|
|
CHECK(parameters_[0]);
|
|
|
|
|
const NormConfig& conf = config_.inputs(0).norm_conf();
|
|
|
|
|
channels_ = conf.channels();
|
|
|
|
|
scale_.reset(new Weight(channels_, 1, parameters_[0]));
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CrossChannelNormLayer::forward(PassType passType) {
|
|
|
|
|
Layer::forward(passType);
|
|
|
|
|
MatrixPtr inV = getInputValue(0);
|
|
|
|
@ -63,6 +73,7 @@ void CrossChannelNormLayer::forward(PassType passType) {
|
|
|
|
|
|
|
|
|
|
// compute norm.
|
|
|
|
|
spatialBuffer_->sumCols(*dataTmp, 1, 0);
|
|
|
|
|
spatialBuffer_->add(*normTmp);
|
|
|
|
|
spatialBuffer_->sqrt2(*spatialBuffer_);
|
|
|
|
|
normTmp->copyFrom(*spatialBuffer_);
|
|
|
|
|
outVTmp->copyFrom(*inVTmp);
|
|
|
|
@ -82,6 +93,9 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
size_t dataDim = inG->getWidth();
|
|
|
|
|
size_t spatialDim = dataDim / channels_;
|
|
|
|
|
|
|
|
|
|
MatrixPtr inGBuffer;
|
|
|
|
|
Matrix::resizeOrCreate(inGBuffer, channels_, spatialDim, false, useGpu_);
|
|
|
|
|
|
|
|
|
|
dataBuffer_->dotMul(*outG, *outV);
|
|
|
|
|
Matrix::resizeOrCreate(scaleDiff_, channels_, 1, false, useGpu_);
|
|
|
|
|
Matrix::resizeOrCreate(channelBuffer_, channels_, 1, false, useGpu_);
|
|
|
|
@ -100,22 +114,24 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
scaleDiff_->add(*channelBuffer_, 1.);
|
|
|
|
|
|
|
|
|
|
sampleBuffer_->dotMul(*inVTmp, *outGTmp);
|
|
|
|
|
spatialBuffer_->sumCols(*sampleBuffer_, 1., 1.);
|
|
|
|
|
spatialBuffer_->sumCols(*sampleBuffer_, 1., 0.);
|
|
|
|
|
// scale the grad
|
|
|
|
|
inGTmp->copyFrom(*inVTmp);
|
|
|
|
|
inGTmp->mulRowVector(*spatialBuffer_);
|
|
|
|
|
inGBuffer->copyFrom(*inVTmp);
|
|
|
|
|
inGBuffer->mulRowVector(*spatialBuffer_);
|
|
|
|
|
// divide by square of norm
|
|
|
|
|
spatialBuffer_->dotMul(*normTmp, *normTmp);
|
|
|
|
|
inGTmp->divRowVector(*spatialBuffer_);
|
|
|
|
|
inGBuffer->divRowVector(*spatialBuffer_);
|
|
|
|
|
// subtract
|
|
|
|
|
inGTmp->add(*outGTmp, -1, 1);
|
|
|
|
|
inGBuffer->add(*outGTmp, -1, 1);
|
|
|
|
|
// divide by norm
|
|
|
|
|
inGTmp->divRowVector(*normTmp);
|
|
|
|
|
inGBuffer->divRowVector(*normTmp);
|
|
|
|
|
// scale the diff
|
|
|
|
|
inGTmp->mulColVector(*scale_->getW());
|
|
|
|
|
inGBuffer->mulColVector(*scale_->getW());
|
|
|
|
|
|
|
|
|
|
inGTmp->add(*inGBuffer);
|
|
|
|
|
}
|
|
|
|
|
// update scale
|
|
|
|
|
if (scale_->getWGrad()) scale_->getWGrad()->copyFrom(*scaleDiff_);
|
|
|
|
|
if (scale_->getWGrad()) scale_->getWGrad()->add(*scaleDiff_);
|
|
|
|
|
scale_->getParameterPtr()->incUpdate(callback);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|