From 7dc584f5c437f76383a9b83b6208c25bfaf4a82e Mon Sep 17 00:00:00 2001
From: xzl
Date: Mon, 20 Nov 2017 11:54:44 +0800
Subject: [PATCH 1/3] add upsample layer

---
 paddle/cuda/include/hl_cnn.h                |  42 ++++++
 paddle/cuda/include/stub/hl_cnn_stub.h      |  18 +++
 paddle/cuda/src/hl_cuda_cnn.cu              |  76 +++++++++++
 paddle/gserver/layers/UpsampleLayer.cpp     | 107 +++++++++++++++
 paddle/gserver/layers/UpsampleLayer.h       |  54 ++++++++
 paddle/math/Matrix.cpp                      | 126 ++++++++++++++++++
 paddle/math/Matrix.h                        |  52 ++++++++
 proto/ModelConfig.proto                     |  11 ++
 python/paddle/trainer/config_parser.py      |  44 ++++++
 .../paddle/trainer_config_helpers/layers.py |  77 +++++++++++
 10 files changed, 607 insertions(+)
 create mode 100644 paddle/gserver/layers/UpsampleLayer.cpp
 create mode 100644 paddle/gserver/layers/UpsampleLayer.h

diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h
index 89c1f48eda..c8dd3eb91e 100644
--- a/paddle/cuda/include/hl_cnn.h
+++ b/paddle/cuda/include/hl_cnn.h
@@ -366,4 +366,46 @@ extern void hl_maxout_backward(real* inGrad,
                                size_t featLen,
                                size_t groups);
 
+/**
+ * @brief Upsample forward.
+ * @param[in] inputData input data.
+ * @param[in] maskData the mask data from MaxPoolWithMaskLayer.
+ * @param[in] batchSize the batch size of the input.
+ * @param[in] imgSizeH image height.
+ * @param[in] imgSizeW image width.
+ * @param[in] channels the input channels.
+ * @param[in] outputH the output height.
+ * @param[in] outputW the output width.
+ * @param[out] outputData output data.
+ */
+extern void hl_upsample_forward(real *inputData, real *maskData,
+        size_t batchSize,
+        size_t imgSizeH,
+        size_t imgSizeW,
+        size_t channels,
+        size_t outputH,
+        size_t outputW,
+        real *outputData);
+
+/**
+ * @brief Upsample backward.
+ * @param[in] outputGradData the output grad data.
+ * @param[in] maskData the mask data from MaxPoolWithMaskLayer.
+ * @param[in] batchSize the batch size of the input.
+ * @param[in] imgSizeH image height.
+ * @param[in] imgSizeW image width.
+ * @param[in] channels the input channels.
+ * @param[in] outputH the output height.
+ * @param[in] outputW the output width.
+ * @param[out] inputGradData the input grad data.
+ */
+extern void hl_upsample_backward(real *outputGradData, real *maskData,
+        size_t batchSize,
+        size_t imgSizeH,
+        size_t imgSizeW,
+        size_t channels,
+        size_t outputH,
+        size_t outputW,
+        real *inputGradData);
+
 #endif // HL_CNN_H_
diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h
index 968ed4840f..ef1f67980e 100644
--- a/paddle/cuda/include/stub/hl_cnn_stub.h
+++ b/paddle/cuda/include/stub/hl_cnn_stub.h
@@ -222,4 +222,22 @@ inline void hl_maxout_backward(real* inGrad,
                                size_t featLen,
                                size_t group) {}
 
+inline void hl_upsample_forward(real *inputData, real *maskData,
+        size_t batchSize,
+        size_t imgSizeH,
+        size_t imgSizeW,
+        size_t channels,
+        size_t outputH,
+        size_t outputW,
+        real *outputData) {}
+
+inline void hl_upsample_backward(real *outputGradData, real *maskData,
+        size_t batchSize,
+        size_t imgSizeH,
+        size_t imgSizeW,
+        size_t channels,
+        size_t outputH,
+        size_t outputW,
+        real *inputGradData) {}
+
 #endif // HL_CNN_STUB_H_
diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu
index 3699b1e8ae..966c406a86 100644
--- a/paddle/cuda/src/hl_cuda_cnn.cu
+++ b/paddle/cuda/src/hl_cuda_cnn.cu
@@ -1020,3 +1020,79 @@ void hl_maxout_backward(real* inGrad,
       num_kernels, inGrad, outGrad, idData, size, featLen, groups);
   CHECK_SYNC("hl_maxout_backward failed");
 }
+
+__global__ void upsampleForwardCompute(real* input_data,
+                                       real* mask_data,
+                                       size_t nthreads,
+                                       size_t in_h,
+                                       size_t in_w,
+                                       size_t out_h,
+                                       size_t out_w,
+                                       real* output_data) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    int offset = index / (in_w * in_h) * out_h * out_w;
+    int upsample_idx = static_cast<int>(mask_data[index]);
+    output_data[offset + upsample_idx] = input_data[index];
+  }
+}
+
+__global__ void upsampleBackwardCompute(real* out_grad,
+                                        real* mask_data,
+                                        size_t nthreads,
+                                        size_t in_h,
+                                        size_t in_w,
+                                        size_t out_h,
+                                        size_t out_w,
+                                        real* input_grad) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    int offset = index / (in_w * in_h) * out_h * out_w;
+    int upsample_idx = static_cast<int>(mask_data[index]);
+    input_grad[index] = out_grad[offset + upsample_idx];
+  }
+}
+
+void hl_upsample_forward(real* inputData,
+                         real* maskData,
+                         size_t batchSize,
+                         size_t imgSizeH,
+                         size_t imgSizeW,
+                         size_t channels,
+                         size_t outputH,
+                         size_t outputW,
+                         real* outputData) {
+  int num_kernels = batchSize * imgSizeH * imgSizeW * channels;
+  int blocks = (num_kernels + 1024 - 1) / 1024;
+  upsampleForwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(inputData,
+                                                              maskData,
+                                                              num_kernels,
+                                                              imgSizeH,
+                                                              imgSizeW,
+                                                              outputH,
+                                                              outputW,
+                                                              outputData);
+  CHECK_SYNC("hl_upsample_forward failed");
+}
+
+void hl_upsample_backward(real* outputGradData,
+                          real* maskData,
+                          size_t batchSize,
+                          size_t imgSizeH,
+                          size_t imgSizeW,
+                          size_t channels,
+                          size_t outputH,
+                          size_t outputW,
+                          real* inputGradData) {
+  int num_kernels = batchSize * imgSizeH * imgSizeW * channels;
+  int blocks = (num_kernels + 1024 - 1) / 1024;
+  upsampleBackwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(outputGradData,
+                                                               maskData,
+                                                               num_kernels,
+                                                               imgSizeH,
+                                                               imgSizeW,
+                                                               outputH,
+                                                               outputW,
+                                                               inputGradData);
+  CHECK_SYNC("hl_upsample_backward failed");
+}
diff --git a/paddle/gserver/layers/UpsampleLayer.cpp b/paddle/gserver/layers/UpsampleLayer.cpp
new file mode 100644
index 0000000000..300bb82d68
--- /dev/null
+++ b/paddle/gserver/layers/UpsampleLayer.cpp
@@ -0,0 +1,107 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "UpsampleLayer.h"
+#include <iostream>
+
+namespace paddle {
+
+REGISTER_LAYER(upsample, UpsampleLayer);
+
+size_t UpsampleLayer::getOutputSize() {
+  if (upsampleSize_ == 0) {
+    upsampleSize_ = imgSize_ * scale_ - static_cast<int>(padOutX_);
+    upsampleSizeY_ = imgSizeY_ * scaleY_ - static_cast<int>(padOutY_);
+  }
+  return upsampleSize_ * upsampleSizeY_ * channels_;
+}
+
+bool UpsampleLayer::init(const LayerMap& layerMap,
+                         const ParameterMap& parameterMap) {
+  Layer::init(layerMap, parameterMap);
+  CHECK_EQ(inputLayers_.size(), 2U);
+  CHECK_EQ(config_.inputs_size(), 2);
+  const auto& conf = config_.inputs(0).upsample_conf();
+  const auto& img_conf = conf.image_conf();
+
+  imgSizeY_ =
+      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
+  imgSize_ = img_conf.img_size();
+  channels_ = img_conf.channels();
+
+  CHECK((conf.has_upsample_size()) || (conf.has_scale()))
+      << "scale or upsample_size is required.";
+
+  if (conf.has_upsample_size()) {
+    upsampleSize_ = conf.upsample_size();
+    upsampleSizeY_ = upsampleSize_;
+    if (conf.has_upsample_size_y()) {
+      upsampleSizeY_ = conf.upsample_size_y();
+    }
+  } else {
+    if (!conf.has_scale_y()) {
+      scale_ = scaleY_ = conf.scale();
+      CHECK_GT(static_cast<int>(scale_), 1);
+    } else {
+      scale_ = conf.scale();
+      scaleY_ = conf.scale_y();
+    }
+    padOutX_ = conf.pad_out_x();
+    padOutY_ = conf.pad_out_y();
+    CHECK(!padOutX_ || scale_ == 2)
+        << "Output width padding compensation requires scale_ == 2";
+    CHECK(!padOutY_ || scaleY_ == 2)
+        << "Output height padding compensation requires scaleY_ == 2";
+    upsampleSize_ = upsampleSizeY_ = 0;
+  }
+  return true;
+}
+
+void UpsampleLayer::forward(PassType passType) {
+  Layer::forward(passType);
+
+  MatrixPtr input = getInputValue(0);
+  MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;
+
+  size_t batchSize = input->getHeight();
+  size_t outSize = getOutputSize();
+
+  CHECK_EQ(input->getWidth(), mask->getWidth());
+  CHECK_EQ(mask->getHeight(), batchSize);
+  resetOutput(batchSize, outSize);
+
+  MatrixPtr output = getOutputValue();
+  output->upsampleForward(*input,
+                          *mask,
+                          imgSize_,
+                          imgSizeY_,
+                          channels_,
+                          upsampleSize_,
+                          upsampleSizeY_);
+}
+
+void UpsampleLayer::backward(const UpdateCallback& callback) {
+  MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;
+  MatrixPtr inputGrad = getInputGrad(0);
+  MatrixPtr outputGrad = getOutputGrad();
+  inputGrad->upsampleBackward(*outputGrad,
+                              *mask,
+                              imgSize_,
+                              imgSizeY_,
+                              channels_,
+                              upsampleSize_,
+                              upsampleSizeY_);
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/UpsampleLayer.h b/paddle/gserver/layers/UpsampleLayer.h
new file mode 100644
index 0000000000..2ae9363439
--- /dev/null
+++ b/paddle/gserver/layers/UpsampleLayer.h
@@ -0,0 +1,54 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "Layer.h"
+#include "paddle/math/Matrix.h"
+#include "paddle/utils/Logging.h"
+#include "paddle/utils/Stat.h"
+
+namespace paddle {
+
+/**
+ * This layer transposes the max-pooling process. It takes two inputs:
+ * the first is the input data to be upsampled, and the second is the
+ * mask data from the max-pool-with-mask layer.
+ */
+class UpsampleLayer : public Layer {
+public:
+  explicit UpsampleLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~UpsampleLayer() {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback) override;
+
+  size_t getOutputSize();
+
+protected:
+  size_t scale_, scaleY_;
+  size_t upsampleSize_, upsampleSizeY_;
+  size_t padOutX_, padOutY_;
+  size_t imgSize_, imgSizeY_;
+  size_t channels_;
+};
+
+}  // namespace paddle
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 88e9180690..1f6458a288 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1023,6 +1023,64 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) {
   LOG(INFO) << "the diffCnt is " << diffCnt;
 }
 
+void GpuMatrix::upsampleForward(Matrix& input,
+                    Matrix& mask,
+                    size_t imgSizeH,
+                    size_t imgSizeW,
+                    size_t channels,
+                    size_t outputH,
+                    size_t outputW) {
+  CHECK(input.useGpu_ == true) << "Matrix types are not equal";
+  CHECK(mask.useGpu_ == true) << "Matrix types are not equal";
+
+  real *inputData = input.getData();
+  real *maskData = mask.getData();
+  real *outData = data_;
+
+  size_t batch = input.getHeight();
+
+  CHECK(imgSizeH * imgSizeW * channels == input.getWidth());
+  CHECK(imgSizeH * imgSizeW * channels == mask.getWidth());
+  CHECK_EQ(batch, this->getHeight());
+  CHECK(width_ == outputH * outputW * channels);
+  hl_upsample_forward(inputData, maskData,
+          batch,
+          imgSizeH,
+          imgSizeW,
+          channels,
+          outputH,
+          outputW,
+          outData);
+}
+
+void GpuMatrix::upsampleBackward(Matrix& outputGrad,
+                    Matrix& mask,
+                    size_t imgSizeH,
+                    size_t imgSizeW,
+                    size_t channels,
+                    size_t outputH,
+                    size_t outputW) {
+  CHECK(outputGrad.useGpu_ == true) << "Matrix types are not equal";
+  CHECK(mask.useGpu_ == true) << "Matrix types are not equal";
+
+  real *outputGradData = outputGrad.getData();
+  real *maskData = mask.getData();
+  real *inputGradData = data_;
+  size_t batch = outputGrad.getHeight();
+
+  CHECK(imgSizeH * imgSizeW == this->getWidth()/channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outputH * outputW, outputGrad.getWidth());
+  hl_upsample_backward(outputGradData, maskData,
+          batch,
+          imgSizeH,
+          imgSizeW,
+          channels,
+          outputH,
+          outputW,
+          inputGradData);
+}
+
 void GpuMatrix::maxPoolForward(Matrix& inputMat,
                                size_t imgSizeH,
                                size_t imgSizeW,
@@ -1981,6 +2039,74 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
   CHECK_EQ(info, 0);
 }
 
+void CpuMatrix::upsampleForward(Matrix& input,
+                    Matrix& mask,
+                    size_t imgSizeH,
+                    size_t imgSizeW,
+                    size_t channels,
+                    size_t outputH,
+                    size_t outputW) {
+    real *inputData = input.getData();
+    real *maskData = mask.getData();
+    real *outData = data_;
+    size_t inLength = imgSizeH * imgSizeW;
+    size_t outLength = outputH * outputW;
+    size_t batch = input.getHeight();
+    CHECK(inLength == input.getWidth() / channels);
+    CHECK_EQ(batch, this->getHeight());
+    CHECK_EQ(channels * outLength, this->getWidth());
+
+    for (size_t k = 0; k < batch; k++) {
+        for (size_t c = 0; c < channels; c++) {
+            for (size_t i = 0; i < inLength; i++) {
+                size_t out_index = static_cast<size_t>(maskData[i]);
+                if (out_index >= outLength) {
+                    LOG(FATAL) << "upsample index " << out_index
+                               << " out of range.";
+                }
+                outData[out_index] = inputData[i];
+            }
+            inputData += inLength;
+            maskData += inLength;
+            outData += outLength;
+        }
+    }
+}
+
+void CpuMatrix::upsampleBackward(Matrix& outputGrad,
+                    Matrix& mask,
+                    size_t imgSizeH,
+                    size_t imgSizeW,
+                    size_t channels,
+                    size_t outputH,
+                    size_t outputW) {
+    real *outputGradData = outputGrad.getData();
+    real *maskData = mask.getData();
+    real *inputGradData = data_;
+    size_t inLength = imgSizeH * imgSizeW;
+    size_t outLength = outputH * outputW;
+    size_t batch = outputGrad.getHeight();
+    CHECK(inLength == this->getWidth()/channels);
+    CHECK_EQ(batch, this->getHeight());
+    CHECK_EQ(channels * outLength, outputGrad.getWidth());
+
+    for (size_t k = 0; k < batch; k++) {
+        for (size_t c = 0; c < channels; c++) {
+            for (size_t i = 0; i < inLength; i++) {
+                size_t out_index = static_cast<size_t>(maskData[i]);
+                if (out_index >= outLength) {
+                    LOG(FATAL) << "upsample index " << out_index
+                               << " out of range.";
+                }
+                inputGradData[i] = outputGradData[out_index];
+            }
+            inputGradData += inLength;
+            maskData += inLength;
+            outputGradData += outLength;
+        }
+    }
+}
+
 void CpuMatrix::maxPoolForward(Matrix& inputMat,
                                size_t imgSizeH,
                                size_t imgSizeW,
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index e273f11236..b4fcf05cb2 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -859,6 +859,26 @@ public:
     LOG(FATAL) << "Not implemented";
   }
 
+  virtual void upsampleForward(Matrix& input,
+                Matrix& mask,
+                size_t imgSizeH,
+                size_t imgSizeW,
+                size_t channels,
+                size_t outputH,
+                size_t outputW) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+  virtual void upsampleBackward(Matrix& outputGrad,
+                Matrix& mask,
+                size_t imgSizeH,
+                size_t imgSizeW,
+                size_t channels,
+                size_t outputH,
+                size_t outputW) {
+    LOG(FATAL) << "Not implemented";
+  }
+
   /**
    * Pooling forward operation, pick out the largest element
    * in the sizeX of value, if the maskMatP is not NULL, it will
@@ -1417,6 +1437,22 @@ public:
   void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
 
+  void upsampleForward(Matrix& input,
+              Matrix& mask,
+              size_t imgSizeH,
+              size_t imgSizeW,
+              size_t channels,
+              size_t outputH,
+              size_t outputW);
+
+  void upsampleBackward(Matrix& outputGrad,
+              Matrix& mask,
+              size_t imgSizeH,
+              size_t imgSizeW,
+              size_t channels,
+              size_t outputH,
+              size_t outputW);
+
   void maxPoolForward(Matrix& inputMat,
                       size_t imgSizeH,
                       size_t imgSizeW,
@@ -1689,6 +1725,22 @@ public:
   MatrixPtr clone(size_t height, size_t width, bool useGpu = false);
 
+  void upsampleForward(Matrix& input,
+              Matrix& mask,
+              size_t imgSizeH,
+              size_t imgSizeW,
+              size_t channels,
+              size_t outputH,
+              size_t outputW);
+
+  void upsampleBackward(Matrix& outputGrad,
+              Matrix& mask,
+              size_t imgSizeH,
+              size_t imgSizeW,
+              size_t channels,
+              size_t outputH,
+              size_t outputW);
+
   void maxPoolForward(Matrix& inputMat,
                       size_t imgSizeH,
                       size_t imgSizeW,
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 2c2cc62459..2cff25d095 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -321,6 +321,16 @@ message ClipConfig {
   required double max = 2;
 }
 
+message UpsampleConfig {
+  required ImageConfig image_conf = 1;
+  optional uint32 scale = 2 [ default = 2 ];
+  optional uint32 scale_y = 3 [ default = 2 ];
+  optional bool pad_out_x = 4 [ default = false ];
+  optional bool pad_out_y = 5 [ default = false ];
+  optional uint32 upsample_size = 6;
+  optional uint32 upsample_size_y = 7;
+}
+
 message ROIPoolConfig {
   required uint32 pooled_width = 1;
   required uint32 pooled_height = 2;
@@ -357,6 +367,7 @@ message LayerInputConfig {
   optional ClipConfig clip_conf = 18;
   optional ScaleSubRegionConfig scale_sub_region_conf = 19;
   optional ROIPoolConfig roi_pool_conf = 20;
+  optional UpsampleConfig upsample_conf = 21;
 }
 
 message LayerConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 5bd68e211a..067ca21d32 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -466,6 +466,7 @@ class Input(Cfg):
                  maxout=None,
                  spp=None,
                  pad=None,
+                 upsample=None,
                  format=None,
                  nnz=None,
                  is_static=None,
@@ -977,6 +978,11 @@ class Pad(Cfg):
     def __init__(self, channels, pad_c, pad_h, pad_w):
         self.add_keys(locals())
 
+@config_class
+class Upsample(Cfg):
+    def __init__(self, scale, scale_y, pad_out_x, pad_out_y, upsample_size,
+                 upsample_size_y):
+        self.add_keys(locals())
 
 @config_class
 class Norm(Cfg):
@@ -2387,6 +2393,44 @@ class SpatialPyramidPoolLayer(LayerBase):
         output_x = (pow(4, spp_conf.pyramid_height) - 1) / (4 - 1)
         self.set_cnn_layer(name, 1, output_x, spp_conf.image_conf.channels)
 
+@config_layer('upsample')
+class UpsampleLayer(LayerBase):
+    def __init__(self, name, inputs, **xargs):
+        super(UpsampleLayer, self).__init__(
+            name, 'upsample', 0, inputs=inputs, **xargs)
+
+        input_layer = self.get_input_layer(0)
+        image_conf = self.config.inputs[0].upsample_conf.image_conf
+        image_conf.img_size = input_layer.width
+        image_conf.img_size_y = input_layer.height
+        image_conf.channels = input_layer.size / (input_layer.width *
+                                                  input_layer.height)
+
+        upsample = self.inputs[0].upsample
+        output_x = 0
+        output_y = 0
+        output_size = 0
+        if upsample.scale:
+            self.config.inputs[0].upsample_conf.scale = upsample.scale
+            self.config.inputs[0].upsample_conf.scale_y = upsample.scale_y
+            output_x = input_layer.width * upsample.scale
+            output_y = input_layer.height * upsample.scale_y
+        self.config.inputs[0].upsample_conf.pad_out_x = upsample.pad_out_x
+        self.config.inputs[0].upsample_conf.pad_out_y = upsample.pad_out_y
+        if upsample.upsample_size:
+            self.config.inputs[
+                0].upsample_conf.upsample_size = upsample.upsample_size
+            self.config.inputs[
+                0].upsample_conf.upsample_size_y = upsample.upsample_size_y
+            output_x = upsample.upsample_size
+            output_y = upsample.upsample_size_y
+
+        output_size = image_conf.channels * output_x * output_y
+
+
+        self.set_layer_height_width(output_y, output_x)
+        self.set_layer_depth(input_layer.depth)
+        self.set_layer_size(output_size)
 
 @config_layer('pad')
 class PadLayer(LayerBase):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 5de1c18950..95369000bb 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -146,6 +146,7 @@ __all__ = [
     'resize_layer',
     'sub_seq_layer',
     'scale_sub_region_layer',
+    'upsample_layer',
 ]
 
 
@@ -163,6 +164,7 @@ class LayerType(object):
     SEQUENCE_RESHAPE = 'seqreshape'
     POOLING_MAX = 'max'
     POOLING_AVG = 'average'
+    UPSAMPLE_LAYER = 'upsample'
     FC_LAYER = 'fc'
     COST = 'cost'
     COSINE_SIM_VEC = 'cos_vm'
@@ -2879,6 +2881,81 @@ def img_pool3d_layer(input,
         num_filters=num_channels,
         size=l.config.size)
 
+@wrap_name_default("upsample")
+@layer_support()
+def upsample_layer(input,
+                   name=None,
+                   scale=None,
+                   scale_y=None,
+                   upsample_size=None,
+                   upsample_size_y=None,
+                   pad_out_x=False,
+                   pad_out_y=False,
+                   layer_attr=None):
+    """
+    The depooling (unpooling) process: it reverses a max-pool-with-mask
+    step. The input should be a list of length 2: the first element is
+    the layer to be upsampled, and the second is the max-pool-with-mask
+    layer that provides the unpooling mask.
+
+    The example usage is:
+
+    .. code-block:: python
+
+       pool1 = paddle.v2.layer.img_pool(input=input, pool_size=2, stride=2,
+                                        pool_type=paddle.pooling.MaxWithMask())
+       upsample = paddle.v2.layer.upsample(input=[layer1, pool1])
+
+    :param name: The name of this layer. It is optional.
+    :type name: basestring
+    :param input: contains an input layer and a max-pool-with-mask layer.
+    :type input: list | tuple | collections.Sequence
+    :param scale: outputSize = scale * inputSize.
+    :type scale: int | None
+    :param scale_y: scale_y will be equal to scale if its value is None.
+    :type scale_y: int | None
+    :param upsample_size: specify the output size explicitly.
+    :type upsample_size: int | None
+    :param upsample_size_y: specify the output size in the y dimension.
+    :type upsample_size_y: int | None
+    :param pad_out_x: shrink the output by one in the x dimension so an odd
+                      pre-pooling width can be recovered exactly. This
+                      parameter only works when scale is 2.
+    :type pad_out_x: bool
+    :param pad_out_y: shrink the output by one in the y dimension so an odd
+                      pre-pooling height can be recovered exactly. This
+                      parameter only works when scale_y is 2.
+    :type pad_out_y: bool
+    :param layer_attr: Extra Layer Attribute.
+    :type layer_attr: ExtraLayerAttribute
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+
+    assert (scale is not None) or (upsample_size is not None), \
+        'either scale or upsample_size must be specified'
+
+    assert len(input) == 2, 'layer input size must be 2'
+    assert input[1].layer_type == LayerType.POOL_LAYER, \
+        'the second input should be the MaxPoolWithMaskLayer'
+
+    scale_y = scale \
+        if scale_y is None else scale_y
+    upsample_size_y = upsample_size \
+        if upsample_size_y is None else upsample_size_y
+
+    layer_type = LayerType.UPSAMPLE_LAYER
+
+    layer = Layer(
+        name=name,
+        type=layer_type,
+        inputs=[
+            Input(
+                input[0].name,
+                upsample=Upsample(scale, scale_y, pad_out_x, pad_out_y,
+                                  upsample_size, upsample_size_y)),
+            Input(input[1].name)
+        ],
+        **ExtraLayerAttribute.to_kwargs(layer_attr))
+
+    sz = layer.config.size
+
+    return LayerOutput(name, layer_type=layer_type, parents=input, size=sz)
 
 @wrap_name_default("spp")
 @layer_support()

From 6da00da7b5b2449d6668a84728708e43ec030433 Mon Sep 17 00:00:00 2001
From: xzl
Date: Mon, 20 Nov 2017 11:58:42 +0800
Subject: [PATCH 2/3] code format check

---
 paddle/cuda/include/hl_cnn.h                |  34 +--
 paddle/cuda/include/stub/hl_cnn_stub.h      |  36 +--
 paddle/gserver/layers/UpsampleLayer.cpp     |   1 +
 paddle/gserver/layers/UpsampleLayer.h       |   1 -
 paddle/math/Matrix.cpp                      | 220 +++++++++---------
 paddle/math/Matrix.h                        |  72 +++---
 python/paddle/trainer/config_parser.py      |   8 +-
 .../paddle/trainer_config_helpers/layers.py |   2 +
 8 files changed, 192 insertions(+), 182 deletions(-)

diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h
index c8dd3eb91e..8d0822471b 100644
--- a/paddle/cuda/include/hl_cnn.h
+++ b/paddle/cuda/include/hl_cnn.h
@@ -378,14 +378,15 @@ extern void hl_maxout_backward(real* inGrad,
  * @param[in] outputW the output width.
  * @param[out] outputData output data.
  */
-extern void hl_upsample_forward(real *inputData, real *maskData,
-        size_t batchSize,
-        size_t imgSizeH,
-        size_t imgSizeW,
-        size_t channels,
-        size_t outputH,
-        size_t outputW,
-        real *outputData);
+extern void hl_upsample_forward(real* inputData,
+                                real* maskData,
+                                size_t batchSize,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW,
+                                real* outputData);
 
 /**
  * @brief Upsample backward.
  * @param[in] outputGradData the output grad data.
  * @param[in] maskData the mask data from MaxPoolWithMaskLayer.
  * @param[in] batchSize the batch size of the input.
  * @param[in] imgSizeH image height.
  * @param[in] imgSizeW image width.
  * @param[in] channels the input channels.
  * @param[in] outputH the output height.
  * @param[in] outputW the output width.
  * @param[out] inputGradData the input grad data.
  */
-extern void hl_upsample_backward(real *outputGradData, real *maskData,
-        size_t batchSize,
-        size_t imgSizeH,
-        size_t imgSizeW,
-        size_t channels,
-        size_t outputH,
-        size_t outputW,
-        real *inputGradData);
+extern void hl_upsample_backward(real* outputGradData,
+                                 real* maskData,
+                                 size_t batchSize,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW,
+                                 real* inputGradData);
 
 #endif // HL_CNN_H_
diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h
index ef1f67980e..e83db71bb7 100644
--- a/paddle/cuda/include/stub/hl_cnn_stub.h
+++ b/paddle/cuda/include/stub/hl_cnn_stub.h
@@ -222,22 +222,24 @@ inline void hl_maxout_backward(real* inGrad,
                                size_t featLen,
                                size_t group) {}
 
-inline void hl_upsample_forward(real *inputData, real *maskData,
-        size_t batchSize,
-        size_t imgSizeH,
-        size_t imgSizeW,
-        size_t channels,
-        size_t outputH,
-        size_t outputW,
-        real *outputData) {}
-
-inline void hl_upsample_backward(real *outputGradData, real *maskData,
-        size_t batchSize,
-        size_t imgSizeH,
-        size_t imgSizeW,
-        size_t channels,
-        size_t outputH,
-        size_t outputW,
-        real *inputGradData) {}
+inline void hl_upsample_forward(real* inputData,
+                                real* maskData,
+                                size_t batchSize,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW,
+                                real* outputData) {}
+
+inline void hl_upsample_backward(real* outputGradData,
+                                 real* maskData,
+                                 size_t batchSize,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW,
+                                 real* inputGradData) {}
 
 #endif // HL_CNN_STUB_H_
diff --git a/paddle/gserver/layers/UpsampleLayer.cpp b/paddle/gserver/layers/UpsampleLayer.cpp
index 300bb82d68..3ff5332e64 100644
--- a/paddle/gserver/layers/UpsampleLayer.cpp
+++ b/paddle/gserver/layers/UpsampleLayer.cpp
@@ -30,6 +30,7 @@ size_t UpsampleLayer::getOutputSize() {
 bool UpsampleLayer::init(const LayerMap& layerMap,
                          const ParameterMap& parameterMap) {
   Layer::init(layerMap, parameterMap);
+
   CHECK_EQ(inputLayers_.size(), 2U);
   CHECK_EQ(config_.inputs_size(), 2);
   const auto& conf = config_.inputs(0).upsample_conf();
diff --git a/paddle/gserver/layers/UpsampleLayer.h b/paddle/gserver/layers/UpsampleLayer.h
index 2ae9363439..25efbac5e9 100644
--- a/paddle/gserver/layers/UpsampleLayer.h
+++ b/paddle/gserver/layers/UpsampleLayer.h
@@ -32,7 +32,6 @@ namespace paddle {
 class UpsampleLayer : public Layer {
 public:
   explicit UpsampleLayer(const LayerConfig& config) : Layer(config) {}
-
   ~UpsampleLayer() {}
 
   bool init(const LayerMap& layerMap,
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 1f6458a288..ad9a73a2bf 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1024,61 +1024,63 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) {
 }
 
 void GpuMatrix::upsampleForward(Matrix& input,
-                    Matrix& mask,
-                    size_t imgSizeH,
-                    size_t imgSizeW,
-                    size_t channels,
-                    size_t outputH,
-                    size_t outputW) {
-  CHECK(input.useGpu_ == true) << "Matrix types are not equal";
-  CHECK(mask.useGpu_ == true) << "Matrix types are not equal";
-
-  real *inputData = input.getData();
-  real *maskData = mask.getData();
-  real *outData = data_;
-
-  size_t batch = input.getHeight();
-
-  CHECK(imgSizeH * imgSizeW * channels == input.getWidth());
-  CHECK(imgSizeH * imgSizeW * channels == mask.getWidth());
-  CHECK_EQ(batch, this->getHeight());
-  CHECK(width_ == outputH * outputW * channels);
-  hl_upsample_forward(inputData, maskData,
-          batch,
-          imgSizeH,
-          imgSizeW,
-          channels,
-          outputH,
-          outputW,
-          outData);
+                                Matrix& mask,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW) {
+  CHECK(input.useGpu_ == true) << "Matrix types are not equal";
+  CHECK(mask.useGpu_ == true) << "Matrix types are not equal";
+
+  real* inputData = input.getData();
+  real* maskData = mask.getData();
+  real* outData = data_;
+
+  size_t batch = input.getHeight();
+
+  CHECK(imgSizeH * imgSizeW * channels == input.getWidth());
+  CHECK(imgSizeH * imgSizeW * channels == mask.getWidth());
+  CHECK_EQ(batch, this->getHeight());
+  CHECK(width_ == outputH * outputW * channels);
+  hl_upsample_forward(inputData,
+                      maskData,
+                      batch,
+                      imgSizeH,
+                      imgSizeW,
+                      channels,
+                      outputH,
+                      outputW,
+                      outData);
 }
 
 void GpuMatrix::upsampleBackward(Matrix& outputGrad,
-                    Matrix& mask,
-                    size_t imgSizeH,
-                    size_t imgSizeW,
-                    size_t channels,
-                    size_t outputH,
-                    size_t outputW) {
-  CHECK(outputGrad.useGpu_ == true) << "Matrix types are not equal";
-  CHECK(mask.useGpu_ == true) << "Matrix types are not equal";
-
-  real *outputGradData = outputGrad.getData();
-  real *maskData = mask.getData();
-  real *inputGradData = data_;
-  size_t batch = outputGrad.getHeight();
-
-  CHECK(imgSizeH * imgSizeW == this->getWidth()/channels);
-  CHECK_EQ(batch, this->getHeight());
-  CHECK_EQ(channels * outputH * outputW, outputGrad.getWidth());
-  hl_upsample_backward(outputGradData, maskData,
-          batch,
-          imgSizeH,
-          imgSizeW,
-          channels,
-          outputH,
-          outputW,
-          inputGradData);
+                                 Matrix& mask,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW) {
+  CHECK(outputGrad.useGpu_ == true) << "Matrix types are not equal";
+  CHECK(mask.useGpu_ == true) << "Matrix types are not equal";
+
+  real* outputGradData = outputGrad.getData();
+  real* maskData = mask.getData();
+  real* inputGradData = data_;
+  size_t batch = outputGrad.getHeight();
+
+  CHECK(imgSizeH * imgSizeW == this->getWidth() / channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outputH * outputW, outputGrad.getWidth());
+  hl_upsample_backward(outputGradData,
+                       maskData,
+                       batch,
+                       imgSizeH,
+                       imgSizeW,
+                       channels,
+                       outputH,
+                       outputW,
+                       inputGradData);
 }
 
 void GpuMatrix::maxPoolForward(Matrix& inputMat,
@@ -2040,71 +2042,69 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
 }
 
 void CpuMatrix::upsampleForward(Matrix& input,
-                    Matrix& mask,
-                    size_t imgSizeH,
-                    size_t imgSizeW,
-                    size_t channels,
-                    size_t outputH,
-                    size_t outputW) {
-    real *inputData = input.getData();
-    real *maskData = mask.getData();
-    real *outData = data_;
-    size_t inLength = imgSizeH * imgSizeW;
-    size_t outLength = outputH * outputW;
-    size_t batch = input.getHeight();
-    CHECK(inLength == input.getWidth() / channels);
-    CHECK_EQ(batch, this->getHeight());
-    CHECK_EQ(channels * outLength, this->getWidth());
-
-    for (size_t k = 0; k < batch; k++) {
-        for (size_t c = 0; c < channels; c++) {
-            for (size_t i = 0; i < inLength; i++) {
-                size_t out_index = static_cast<size_t>(maskData[i]);
-                if (out_index >= outLength) {
-                    LOG(FATAL) << "upsample index " << out_index
-                               << " out of range.";
-                }
-                outData[out_index] = inputData[i];
-            }
-            inputData += inLength;
-            maskData += inLength;
-            outData += outLength;
-        }
-    }
-}
+                                Matrix& mask,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW) {
+  real* inputData = input.getData();
+  real* maskData = mask.getData();
+  real* outData = data_;
+  size_t inLength = imgSizeH * imgSizeW;
+  size_t outLength = outputH * outputW;
+  size_t batch = input.getHeight();
+  CHECK(inLength == input.getWidth() / channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outLength, this->getWidth());
+
+  for (size_t k = 0; k < batch; k++) {
+    for (size_t c = 0; c < channels; c++) {
+      for (size_t i = 0; i < inLength; i++) {
+        size_t out_index = static_cast<size_t>(maskData[i]);
+        if (out_index >= outLength) {
+          LOG(FATAL) << "upsample index " << out_index << " out of range.";
+        }
+        outData[out_index] = inputData[i];
+      }
+      inputData += inLength;
+      maskData += inLength;
+      outData += outLength;
+    }
+  }
+}
 
 void CpuMatrix::upsampleBackward(Matrix& outputGrad,
-                    Matrix& mask,
-                    size_t imgSizeH,
-                    size_t imgSizeW,
-                    size_t channels,
-                    size_t outputH,
-                    size_t outputW) {
-    real *outputGradData = outputGrad.getData();
-    real *maskData = mask.getData();
-    real *inputGradData = data_;
-    size_t inLength = imgSizeH * imgSizeW;
-    size_t outLength = outputH * outputW;
-    size_t batch = outputGrad.getHeight();
-    CHECK(inLength == this->getWidth()/channels);
-    CHECK_EQ(batch, this->getHeight());
-    CHECK_EQ(channels * outLength, outputGrad.getWidth());
-
-    for (size_t k = 0; k < batch; k++) {
-        for (size_t c = 0; c < channels; c++) {
-            for (size_t i = 0; i < inLength; i++) {
-                size_t out_index = static_cast<size_t>(maskData[i]);
-                if (out_index >= outLength) {
-                    LOG(FATAL) << "upsample index " << out_index
-                               << " out of range.";
-                }
-                inputGradData[i] = outputGradData[out_index];
-            }
-            inputGradData += inLength;
-            maskData += inLength;
-            outputGradData += outLength;
-        }
-    }
-}
+                                 Matrix& mask,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW) {
+  real* outputGradData = outputGrad.getData();
+  real* maskData = mask.getData();
+  real* inputGradData = data_;
+  size_t inLength = imgSizeH * imgSizeW;
+  size_t outLength = outputH * outputW;
+  size_t batch = outputGrad.getHeight();
+  CHECK(inLength == this->getWidth() / channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outLength, outputGrad.getWidth());
+
+  for (size_t k = 0; k < batch; k++) {
+    for (size_t c = 0; c < channels; c++) {
+      for (size_t i = 0; i < inLength; i++) {
+        size_t out_index = static_cast<size_t>(maskData[i]);
+        if (out_index >= outLength) {
+          LOG(FATAL) << "upsample index " << out_index << " out of range.";
+        }
+        inputGradData[i] = outputGradData[out_index];
+      }
+      inputGradData += inLength;
+      maskData += inLength;
+      outputGradData += outLength;
+    }
+  }
+}
 
 void CpuMatrix::maxPoolForward(Matrix& inputMat,
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index b4fcf05cb2..6e9ea04d66 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -860,22 +860,22 @@ public:
   }
 
   virtual void upsampleForward(Matrix& input,
-                Matrix& mask,
-                size_t imgSizeH,
-                size_t imgSizeW,
-                size_t channels,
-                size_t outputH,
-                size_t outputW) {
+                               Matrix& mask,
+                               size_t imgSizeH,
+                               size_t imgSizeW,
+                               size_t channels,
+                               size_t outputH,
+                               size_t outputW) {
LOG(FATAL) << "Not implemeted"; } virtual void upsampleBackward(Matrix& outputGrad, - Matrix& mask, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t outputH, - size_t outputW) { + Matrix& mask, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t outputH, + size_t outputW) { LOG(FATAL) << "Not implemeted"; } @@ -1438,20 +1438,20 @@ public: void classificationError(Matrix& output, IVector& label, size_t topkSize = 1); void upsampleForward(Matrix& input, - Matrix& mask, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t outputH, - size_t outputW); + Matrix& mask, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t outputH, + size_t outputW); void upsampleBackward(Matrix& outputGrad, - Matrix& mask, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t outputH, - size_t outputW); + Matrix& mask, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t outputH, + size_t outputW); void maxPoolForward(Matrix& inputMat, size_t imgSizeH, @@ -1726,20 +1726,20 @@ public: MatrixPtr clone(size_t height, size_t width, bool useGpu = false); void upsampleForward(Matrix& input, - Matrix& mask, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t outputH, - size_t outputW); + Matrix& mask, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t outputH, + size_t outputW); void upsampleBackward(Matrix& outputGrad, - Matrix& mask, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t outputH, - size_t outputW); + Matrix& mask, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t outputH, + size_t outputW); void maxPoolForward(Matrix& inputMat, size_t imgSizeH, diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 067ca21d32..7563368ad7 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -978,12 +978,14 @@ class Pad(Cfg): def __init__(self, channels, pad_c, pad_h, pad_w): self.add_keys(locals()) + @config_class class Upsample(Cfg): def __init__(self, scale, scale_y, pad_out_x, pad_out_y, upsample_size, upsample_size_y): self.add_keys(locals()) + @config_class class Norm(Cfg): def __init__(self, @@ -2393,6 +2395,7 @@ class SpatialPyramidPoolLayer(LayerBase): output_x = (pow(4, spp_conf.pyramid_height) - 1) / (4 - 1) self.set_cnn_layer(name, 1, output_x, spp_conf.image_conf.channels) + @config_layer('upsample') class UpsampleLayer(LayerBase): def __init__(self, name, inputs, **xargs): @@ -2407,9 +2410,10 @@ class UpsampleLayer(LayerBase): input_layer.height) upsample = self.inputs[0].upsample - output_x = 0 + output_x = 0 output_y = 0 output_size = 0 + if upsample.scale: self.config.inputs[0].upsample_conf.scale = upsample.scale self.config.inputs[0].upsample_conf.scale_y = upsample.scale_y @@ -2427,11 +2431,11 @@ class UpsampleLayer(LayerBase): output_size = image_conf.channels * output_x * output_y - self.set_layer_height_width(output_y, output_x) self.set_layer_depth(input_layer.depth) self.set_layer_size(output_size) + @config_layer('pad') class PadLayer(LayerBase): def __init__(self, name, inputs, **xargs): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 95369000bb..1ce603389d 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2881,6 +2881,7 @@ def img_pool3d_layer(input, num_filters=num_channels, size=l.config.size) + @wrap_name_default("upsample") 
@layer_support() def upsample_layer(input, @@ -2930,6 +2931,7 @@ def upsample_layer(input, 'scale or upsample_size, there must be one to be designated' assert len(input) == 2, 'layer input size must be 2' + assert input[1].layer_type == LayerType.POOL_LAYER, \ 'the second input should be the MaxPoolWithMaskLayer' From 76941d90b1c38b121d711a6e4455f73dfba8f14f Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 13 Dec 2017 16:31:52 +0800 Subject: [PATCH 3/3] add upsample cpu&gpu forward&backward compare test --- paddle/gserver/tests/CMakeLists.txt | 1 + paddle/gserver/tests/test_Upsample.cpp | 152 +++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 paddle/gserver/tests/test_Upsample.cpp diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index c295ea19c9..5ef2726764 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -28,6 +28,7 @@ gserver_test(test_BatchNorm) gserver_test(test_KmaxSeqScore) gserver_test(test_Expand) gserver_test(test_MaxPoolingWithMaskOutput) +gserver_test(test_Upsample) ########## test_MKLDNN layers and activations ########## if(WITH_MKLDNN) diff --git a/paddle/gserver/tests/test_Upsample.cpp b/paddle/gserver/tests/test_Upsample.cpp new file mode 100644 index 0000000000..9d6fa1d130 --- /dev/null +++ b/paddle/gserver/tests/test_Upsample.cpp @@ -0,0 +1,152 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#include <gtest/gtest.h>
+#include <string>
+#include <vector>
+
+#include "LayerGradUtil.h"
+#include "paddle/math/MathUtils.h"
+#include "paddle/testing/TestUtil.h"
+
+using namespace paddle;
+
+void setPoolConfig(TestConfig* config,
+                   PoolConfig* pool,
+                   const string& poolType) {
+  (*config).biasSize = 0;
+  (*config).layerConfig.set_type("pool");
+  (*config).layerConfig.set_num_filters(1);
+
+  int kw = 2, kh = 2;
+  int pw = 0, ph = 0;
+  int sw = 2, sh = 2;
+  pool->set_pool_type(poolType);
+  pool->set_channels(2);
+  pool->set_size_x(kw);
+  pool->set_size_y(kh);
+  pool->set_start(0);
+  pool->set_padding(pw);
+  pool->set_padding_y(ph);
+  pool->set_stride(sw);
+  pool->set_stride_y(sh);
+
+  int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false);
+  int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false);
+  pool->set_output_x(ow);
+  pool->set_output_y(oh);
+}
+
+LayerPtr doOneUpsampleTest(MatrixPtr& inputMat,
+                           const string& poolType,
+                           bool use_gpu,
+                           real* tempGradData) {
+  /* prepare maxPoolWithMaskLayer */
+  TestConfig config;
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 128, 0});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  PoolConfig* pool = input->mutable_pool_conf();
+
+  pool->set_img_size(8);
+  pool->set_img_size_y(8);
+  setPoolConfig(&config, pool, "max-pool-with-mask");
+  config.layerConfig.set_size(pool->output_x() * pool->output_y() *
+                              pool->channels());
+
+  config.layerConfig.set_name("MaxPoolWithMask");
+
+  std::vector<DataLayerPtr> dataLayers;
+  LayerMap layerMap;
+  vector<Argument> datas;
+
+  initDataLayer(config,
+                &dataLayers,
+                &datas,
+                &layerMap,
+                "MaxPoolWithMask",
+                1,
+                false,
+                use_gpu);
+
+  dataLayers[0]->getOutputValue()->copyFrom(*inputMat);
+
+  FLAGS_use_gpu = use_gpu;
+  std::vector<ParameterPtr> parameters;
+  LayerPtr maxPoolingWithMaskOutputLayer;
+  initTestLayer(config, &layerMap, &parameters, &maxPoolingWithMaskOutputLayer);
+  maxPoolingWithMaskOutputLayer->forward(PASS_GC);
+
+  /* prepare the upsample layer */
+  LayerConfig upsampleLayerConfig;
+  upsampleLayerConfig.set_type("upsample");
+  LayerInputConfig* input1 = upsampleLayerConfig.add_inputs();
+  upsampleLayerConfig.add_inputs();
+
+  UpsampleConfig* upsampleConfig = input1->mutable_upsample_conf();
+  upsampleConfig->set_scale(2);
+  ImageConfig* imageConfig = upsampleConfig->mutable_image_conf();
+  imageConfig->set_channels(2);
+  imageConfig->set_img_size(4);
+  imageConfig->set_img_size_y(4);
+  upsampleLayerConfig.set_size(2 * 8 * 8);
+  upsampleLayerConfig.set_name("upsample");
+
+  for (size_t i = 0; i < 2; i++) {
+    LayerInputConfig& inputTemp = *(upsampleLayerConfig.mutable_inputs(i));
+    inputTemp.set_input_layer_name("MaxPoolWithMask");
+  }
+
+  LayerPtr upsampleLayer;
+  ParameterMap parameterMap;
+  upsampleLayer = Layer::create(upsampleLayerConfig);
+  layerMap[upsampleLayerConfig.name()] = upsampleLayer;
+  upsampleLayer->init(layerMap, parameterMap);
+  upsampleLayer->setNeedGradient(true);
+  upsampleLayer->forward(PASS_GC);
+  upsampleLayer->getOutputGrad()->copyFrom(tempGradData, 128);
+  upsampleLayer->backward();
+
+  return upsampleLayer;
+}
+
+TEST(Layer, upsampleLayerCpuGpuCompare) {
+  bool useGpu = false;
+  MatrixPtr inputMat;
+  MatrixPtr inputGPUMat;
+  MatrixPtr tempGradMat;
+
+  inputMat = Matrix::create(1, 128, false, useGpu);
+  inputMat->randomizeUniform();
+
+  tempGradMat = Matrix::create(1, 128, false, useGpu);
+  tempGradMat->randomizeUniform();
+  real* data = inputMat->getData();
+  real* tempGradData = tempGradMat->getData();
+
+  LayerPtr upsampleLayerCPU =
"max-pool-with-mask", useGpu, tempGradData); + +#ifdef PADDLE_WITH_CUDA + useGpu = true; + inputGPUMat = Matrix::create(1, 128, false, useGpu); + inputGPUMat->copyFrom(data, 128); + LayerPtr upsampleLayerGPU = doOneUpsampleTest( + inputGPUMat, "max-pool-with-mask", useGpu, tempGradData); + checkMatrixEqual(upsampleLayerCPU->getOutput("").value, + upsampleLayerGPU->getOutput("").value); + + checkMatrixEqual(upsampleLayerCPU->getPrev(0)->getOutputGrad(), + upsampleLayerGPU->getPrev(0)->getOutputGrad()); +#endif +}