parent
0f3a3e9894
commit
2377d71947
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,198 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Pool3DLayer.h"
|
||||
#include "PoolProjectionLayer.h"
|
||||
#include "paddle/utils/Logging.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
REGISTER_LAYER(pool3d, Pool3DLayer);
|
||||
|
||||
bool Pool3DLayer::init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) {
|
||||
Layer::init(layerMap, parameterMap);
|
||||
|
||||
/* the size of inputs for pool-layer is 1 */
|
||||
CHECK_EQ(config_.inputs_size(), 1);
|
||||
|
||||
const PoolConfig& conf = config_.inputs(0).pool_conf();
|
||||
poolType_ = conf.pool_type();
|
||||
channels_ = conf.channels();
|
||||
|
||||
sizeX_ = conf.size_x();
|
||||
sizeY_ = conf.size_y();
|
||||
sizeZ_ = conf.size_z();
|
||||
|
||||
strideW_ = conf.stride();
|
||||
strideH_ = conf.stride_y();
|
||||
strideD_ = conf.stride_z();
|
||||
|
||||
imgSizeW_ = conf.img_size();
|
||||
imgSizeH_ = conf.img_size_y();
|
||||
imgSizeD_ = conf.img_size_z();
|
||||
|
||||
paddingW_ = conf.padding();
|
||||
paddingH_ = conf.padding_y();
|
||||
paddingD_ = conf.padding_z();
|
||||
|
||||
outputW_ = conf.output_x();
|
||||
outputH_ = conf.output_y();
|
||||
outputD_ = conf.output_z();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t Pool3DLayer::getSize() {
|
||||
CHECK_EQ(inputLayers_.size(), 1UL);
|
||||
|
||||
size_t layerSize = 0;
|
||||
// imgSizeD_ = inputLayers_[0]->getOutput().getFrameDepth();
|
||||
// imgSizeH_ = inputLayers_[0]->getOutput().getFrameHeight();
|
||||
// imgSizeW_ = inputLayers_[0]->getOutput().getFrameWidth();
|
||||
if (imgSizeH_ == 0) {
|
||||
// imgSizeH_ = imgSizeY_;
|
||||
}
|
||||
if (imgSizeW_ == 0) {
|
||||
// imgSizeW_ = imgSize_;
|
||||
}
|
||||
outputD_ = outputSize(imgSizeD_,
|
||||
sizeZ_,
|
||||
paddingD_,
|
||||
strideD_,
|
||||
/* caffeMode */ false);
|
||||
outputH_ = outputSize(imgSizeH_,
|
||||
sizeY_,
|
||||
paddingH_,
|
||||
strideH_,
|
||||
/* caffeMode */ false);
|
||||
outputW_ = outputSize(imgSizeW_,
|
||||
sizeX_,
|
||||
paddingW_,
|
||||
strideW_,
|
||||
/* caffeMode */ false);
|
||||
|
||||
layerSize = outputD_ * outputH_ * outputW_ * channels_;
|
||||
getOutput().setFrameHeight(outputH_);
|
||||
getOutput().setFrameWidth(outputW_);
|
||||
getOutput().setFrameDepth(outputD_);
|
||||
return layerSize;
|
||||
}
|
||||
|
||||
void Pool3DLayer::forward(PassType passType) {
|
||||
Layer::forward(passType);
|
||||
const MatrixPtr& inMat = inputLayers_[0]->getOutputValue();
|
||||
int batchSize = inMat->getHeight();
|
||||
int outWidth = getSize();
|
||||
resetOutput(batchSize, outWidth);
|
||||
const MatrixPtr outMat = getOutputValue();
|
||||
|
||||
if (poolType_ == "avg") {
|
||||
outMat->avgPool3DForward(*inMat,
|
||||
imgSizeD_,
|
||||
imgSizeH_,
|
||||
imgSizeW_,
|
||||
channels_,
|
||||
sizeZ_,
|
||||
sizeY_,
|
||||
sizeX_,
|
||||
strideD_,
|
||||
strideH_,
|
||||
strideW_,
|
||||
outputD_,
|
||||
outputH_,
|
||||
outputW_,
|
||||
paddingD_,
|
||||
paddingH_,
|
||||
paddingW_);
|
||||
} else if (poolType_ == "max") {
|
||||
outMat->maxPool3DForward(*inMat,
|
||||
imgSizeD_,
|
||||
imgSizeH_,
|
||||
imgSizeW_,
|
||||
channels_,
|
||||
sizeZ_,
|
||||
sizeY_,
|
||||
sizeX_,
|
||||
strideD_,
|
||||
strideH_,
|
||||
strideW_,
|
||||
outputD_,
|
||||
outputH_,
|
||||
outputW_,
|
||||
paddingD_,
|
||||
paddingH_,
|
||||
paddingW_);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown pool type: " << poolType_;
|
||||
}
|
||||
forwardActivation();
|
||||
}
|
||||
|
||||
void Pool3DLayer::backward(const UpdateCallback& callback) {
|
||||
backwardActivation();
|
||||
|
||||
(void)callback;
|
||||
if (NULL == getInputGrad(0)) return;
|
||||
MatrixPtr inMat = inputLayers_[0]->getOutputValue();
|
||||
MatrixPtr inGradMat = inputLayers_[0]->getOutputGrad();
|
||||
MatrixPtr outMat = getOutputValue();
|
||||
MatrixPtr outGradMat = getOutputGrad();
|
||||
|
||||
if (poolType_ == "avg") {
|
||||
inGradMat->avgPool3DBackward(*outGradMat,
|
||||
imgSizeD_,
|
||||
imgSizeH_,
|
||||
imgSizeW_,
|
||||
sizeZ_,
|
||||
sizeY_,
|
||||
sizeZ_,
|
||||
strideD_,
|
||||
strideH_,
|
||||
strideW_,
|
||||
outputD_,
|
||||
outputH_,
|
||||
outputW_,
|
||||
1,
|
||||
1,
|
||||
paddingD_,
|
||||
paddingH_,
|
||||
paddingW_);
|
||||
} else if (poolType_ == "max") {
|
||||
inGradMat->maxPool3DBackward(*inMat,
|
||||
imgSizeD_,
|
||||
imgSizeH_,
|
||||
imgSizeW_,
|
||||
*outGradMat,
|
||||
*outMat,
|
||||
sizeZ_,
|
||||
sizeY_,
|
||||
sizeZ_,
|
||||
strideD_,
|
||||
strideH_,
|
||||
strideW_,
|
||||
outputD_,
|
||||
outputH_,
|
||||
outputW_,
|
||||
1,
|
||||
1,
|
||||
paddingD_,
|
||||
paddingH_,
|
||||
paddingW_);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown pool type: " << poolType_;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,48 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "Layer.h"
|
||||
#include "paddle/math/MathUtils.h"
|
||||
#include "paddle/math/Matrix.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* @brief Basic parent layer of pooling
|
||||
* Pools the input within regions
|
||||
*/
|
||||
class Pool3DLayer : public Layer {
|
||||
public:
|
||||
explicit Pool3DLayer(const LayerConfig& config) : Layer(config) {}
|
||||
~Pool3DLayer() {}
|
||||
|
||||
bool init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) override;
|
||||
void forward(PassType passType) override;
|
||||
void backward(const UpdateCallback& callback) override;
|
||||
size_t getSize();
|
||||
|
||||
protected:
|
||||
int channels_;
|
||||
int sizeX_, sizeY_, sizeZ_;
|
||||
int strideW_, strideH_, strideD_;
|
||||
int paddingW_, paddingH_, paddingD_;
|
||||
int imgSizeW_, imgSizeH_, imgSizeD_;
|
||||
int outputW_, outputH_, outputD_;
|
||||
std::string poolType_;
|
||||
};
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue