add getSize method for PoolProjection

avx_docs
qijun 8 years ago
parent bdc9d10ae5
commit 70e04683dd

@ -32,6 +32,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
}
void MaxPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,
@ -55,6 +57,8 @@ void MaxPoolProjection::backward(const UpdateCallback& callback) {
}
void AvgPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
outV->avgPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,

@ -51,6 +51,26 @@ public:
static PoolProjection* create(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu);
const std::string& getPoolType() const { return poolType_; }
size_t getSize() {
imgSizeY_ = in_->getFrameHeight();
imgSize_ = in_->getFrameWidth();
const PoolConfig& conf = config_.pool_conf();
if (imgSizeY_ == 0) {
imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
}
if (imgSize_ == 0) {
imgSize_ = conf.img_size();
}
outputY_ = outputSize(imgSizeY_, sizeY_, confPaddingY_, strideY_,
/* caffeMode */ false);
outputX_ = outputSize(imgSize_, sizeX_, confPadding_, stride_,
/* caffeMode */ false);
const_cast<Argument*>(out_)->setFrameHeight(outputY_);
const_cast<Argument*>(out_)->setFrameWidth(outputX_);
return outputY_ * outputX_ * channels_;
}
};
class MaxPoolProjection : public PoolProjection {

@ -38,8 +38,6 @@ size_t PoolProjectionLayer::getSize() {
layerSize = outputH_ * outputW_ * channels_;
getOutput().setFrameHeight(outputH_);
getOutput().setFrameWidth(outputW_);
return layerSize;
}

@ -70,10 +70,6 @@ size_t SpatialPyramidPoolLayer::getSize() {
size_t outputW = (std::pow(4, pyramidHeight_) - 1) / (4 - 1);
layerSize = outputH * outputW * channels_;
getOutput().setFrameHeight(outputH);
getOutput().setFrameWidth(outputW);
return layerSize;
}

Loading…
Cancel
Save