From 5d98b6f217f8c59ae32f7dabefb69037d80f9cb2 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 21 Aug 2017 16:32:29 +0800 Subject: [PATCH] Adapting to the BatchNorm structure to support 3D data --- paddle/gserver/layers/BatchNormBaseLayer.cpp | 6 ++- paddle/gserver/layers/BatchNormBaseLayer.h | 1 + paddle/gserver/tests/test_LayerGrad.cpp | 49 ++++++++++++++++++++ paddle/parameter/Argument.cpp | 2 + paddle/parameter/Argument.h | 8 ++-- proto/ModelConfig.proto | 13 ++++++ 6 files changed, 75 insertions(+), 4 deletions(-) diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp index 1ceaaaa206..f7a80e23e1 100644 --- a/paddle/gserver/layers/BatchNormBaseLayer.cpp +++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp @@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() { const ImageConfig& conf = config_.inputs(0).image_conf(); imageH_ = inputLayers_[0]->getOutput().getFrameHeight(); imageW_ = inputLayers_[0]->getOutput().getFrameWidth(); + imageD_ = inputLayers_[0]->getOutput().getFrameDepth(); + + if (0 == imageD_) imageD_ = conf.img_size_z(); if (imageH_ == 0 && imageW_ == 0) { imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size(); imageW_ = conf.img_size(); } else { getOutput().setFrameHeight(imageH_); getOutput().setFrameWidth(imageW_); + getOutput().setFrameDepth(imageD_); } - imgPixels_ = imageH_ * imageW_; + imgPixels_ = imageH_ * imageW_ * imageD_; } } // namespace paddle diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h index 230bafc31d..e721d2d267 100644 --- a/paddle/gserver/layers/BatchNormBaseLayer.h +++ b/paddle/gserver/layers/BatchNormBaseLayer.h @@ -80,6 +80,7 @@ protected: /// Height or width of input image feature. /// Both of them are 1 if the input is fully-connected layer. + int imageD_; int imageH_; int imageW_; /// Height * Width. diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0f312b6ca5..6418772584 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1594,6 +1594,55 @@ TEST(Layer, BatchNormalizationLayer) { #endif } +void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) { + TestConfig config; + const int CHANNELS = 10; + const int IMG_SIZE = 16; + const int IMG_SIZE_Y = 8; + const int IMG_SIZE_Z = 8; + size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z; + config.layerConfig.set_type(type); + config.layerConfig.set_size(size); + config.layerConfig.set_active_type("sigmoid"); + config.biasSize = CHANNELS; + config.inputDefs.push_back({INPUT_DATA, + "layer_0", + /* dim= */ size, + /* paraSize= */ CHANNELS}); + + config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS}); + config.inputDefs.back().isStatic = true; + config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS}); + config.inputDefs.back().isStatic = true; + + LayerInputConfig* input = config.layerConfig.add_inputs(); + config.layerConfig.add_inputs(); + config.layerConfig.add_inputs(); + + ImageConfig* img_conf = input->mutable_image_conf(); + img_conf->set_channels(CHANNELS); + img_conf->set_img_size(IMG_SIZE); + img_conf->set_img_size_y(IMG_SIZE_Y); + img_conf->set_img_size_z(IMG_SIZE_Z); + + testLayerGrad(config, + "batch_norm", + 64, + /* trans= */ trans, + useGpu, + /* useWeight */ true); +} + +TEST(Layer, testBatchNorm3DLayer) { + testBatchNorm3DLayer("batch_norm", false, false); +#ifndef PADDLE_ONLY_CPU + testBatchNorm3DLayer("batch_norm", false, true); + if (hl_get_cudnn_lib_version() >= int(4000)) { + testBatchNorm3DLayer("cudnn_batch_norm", false, true); + } +#endif +} + void testConvOperator(bool isDeconv) { TestConfig config; const int NUM_FILTERS = 16; diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index 0547ac93cd..77fd0c5890 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -186,6 +186,7 @@ void Argument::resizeAndCopyFrom(const Argument& src, resizeAndCopy(strs, src.strs, useGpu, stream); frameWidth = src.frameWidth; frameHeight = src.frameHeight; + frameDepth = src.frameDepth; } int32_t Argument::resizeAndCopyFrom(const Argument& src, @@ -206,6 +207,7 @@ int32_t Argument::resizeAndCopyFrom(const Argument& src, dataId = src.dataId; frameWidth = src.frameWidth; frameHeight = src.frameHeight; + frameDepth = src.frameDepth; if (!src.sequenceStartPositions) { // non-sequence input, copy samples directly diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h index d8d7a4398f..ba3ad2fd4d 100644 --- a/paddle/parameter/Argument.h +++ b/paddle/parameter/Argument.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -35,6 +32,7 @@ struct Argument { strs(nullptr), frameHeight(0), frameWidth(0), + frameDepth(0), sequenceStartPositions(nullptr), subSequenceStartPositions(nullptr), cpuSequenceDims(nullptr), @@ -64,6 +62,7 @@ struct Argument { allCount = argument.allCount; frameHeight = argument.frameHeight; frameWidth = argument.frameWidth; + frameDepth = argument.frameDepth; dataId = argument.dataId; } @@ -76,6 +75,7 @@ struct Argument { // A dataBatch includes batchSize frames, one frame maybe not only vector size_t frameHeight; size_t frameWidth; + size_t frameDepth; // If NULL, each position is treated independently. // Otherwise, its size should be #NumberOfSequences + 1. @@ -136,8 +136,10 @@ struct Argument { } size_t getFrameHeight() const { return frameHeight; } size_t getFrameWidth() const { return frameWidth; } + size_t getFrameDepth() const { return frameDepth; } void setFrameHeight(size_t h) { frameHeight = h; } void setFrameWidth(size_t w) { frameWidth = w; } + void setFrameDepth(size_t d) { frameDepth = d; } int64_t getNumSequences() const { return sequenceStartPositions ? sequenceStartPositions->getSize() - 1 diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 4f3d5bf3f6..ef2b076c33 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -82,6 +82,12 @@ message ConvConfig { // if not set, use img_size optional uint32 img_size_y = 14; + + optional uint32 filter_size_z = 15 [ default = 1 ]; + optional uint32 padding_z = 16 [ default = 1 ]; + optional uint32 stride_z = 17 [ default = 1 ]; + optional uint32 output_z = 18 [ default = 1 ]; + optional uint32 img_size_z = 19 [ default = 1 ]; } message PoolConfig { @@ -124,6 +130,12 @@ message PoolConfig { // if not set, use padding optional uint32 padding_y = 13; + + optional uint32 size_z = 14 [ default = 1 ]; + optional uint32 stride_z = 15 [ default = 1 ]; + optional uint32 output_z = 16 [ default = 1 ]; + optional uint32 img_size_z = 17 [ default = 1 ]; + optional uint32 padding_z = 18 [ default = 1 ]; } message SppConfig { @@ -256,6 +268,7 @@ message ImageConfig { // The size of input feature map. required uint32 img_size = 8; optional uint32 img_size_y = 9; + optional uint32 img_size_z = 10 [ default = 1 ]; } message PriorBoxConfig {