nhwc optimization for batchnorm (#21090)

Jie Fang 5 years ago committed by gongweibao
parent fce24315fb
commit 5e813b53c5

@@ -141,6 +141,10 @@ class GradOpDescMakerBase {
return (fwd_op_.Inputs().count(name) > 0);
}
bool HasOutput(const std::string& name) const {
return (fwd_op_.Outputs().count(name) > 0);
}
private:
const OpDesc& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_;

@@ -107,6 +107,12 @@ class GradOpBaseMakerBase {
return it != var_base_map_in_.end();
}
bool HasOutput(const std::string name) const {
auto it = var_base_map_out_.find(name);
return it != var_base_map_out_.end();
}
private:
std::vector<std::shared_ptr<VarBase>> GetVarBaseList(const std::string& name,
bool is_grad,

@@ -25,27 +25,42 @@ namespace paddle {
namespace operators {
void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true,
platform::errors::InvalidArgument(
"Input(Scale) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Bias"), true,
platform::errors::InvalidArgument(
"Input(Bias) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Mean"), true,
platform::errors::InvalidArgument(
"Input(Mean) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Variance"), true,
platform::errors::InvalidArgument(
"Input(Variance) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
platform::errors::InvalidArgument(
"Output(Y) of BatchNormOp should not be null."));
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
PADDLE_ENFORCE(ctx->HasOutput("MeanOut"),
"Output(MeanOut) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"),
"Output(VarianceOut) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("SavedMean"),
"Output(SavedMean) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"),
"Output(SavedVariance) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MeanOut"), true,
platform::errors::InvalidArgument(
"Output(MeanOut) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VarianceOut"), true,
platform::errors::InvalidArgument(
"Output(VarianceOut) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SavedMean"), true,
platform::errors::InvalidArgument(
"Output(SavedMean) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SavedVariance"), true,
platform::errors::InvalidArgument(
"Output(SavedVariance) of BatchNormOp should not be null."));
}
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
@@ -200,6 +215,10 @@ void BatchNormOpMaker::Make() {
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("ReserveSpace",
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel")
.AsDispensable();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
@@ -643,6 +662,9 @@ std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("SavedMean", this->Output("SavedMean"));
op->SetInput("SavedVariance", this->Output("SavedVariance"));
if (this->HasOutput("ReserveSpace")) {
op->SetInput("ReserveSpace", this->Output("ReserveSpace"));
}
// used when setting use_global_stats True during training
if (boost::get<bool>(this->GetAttr("use_global_stats"))) {

File diff suppressed because it is too large.

@@ -16,8 +16,10 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.h"
namespace paddle {
@@ -39,24 +41,109 @@ template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[3];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 3, 1, 2};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
} else if (dim == 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 1};
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, *input, transformed_input, axis);
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 1};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
} else if (dim == 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 1};
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, *input, transformed_input, axis);
}
}
class BatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
const framework::ExecutionContext& ctx) const override;
};
class BatchNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
const framework::ExecutionContext& ctx) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -85,13 +172,13 @@ class BatchNormOpInferVarType
template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
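
For reference, here is a short NumPy sketch (not part of the patch) of the axis permutations that the new TransToChannelFirst/TransToChannelLast helpers perform; ResizeToChannelFirst only computes the corresponding channel-first shape. The tensor shapes below are illustrative assumptions.

import numpy as np

# 4-D case (dim == 2 above): NHWC <-> NCHW.
x_nhwc = np.random.rand(2, 8, 8, 3).astype(np.float32)      # N, H, W, C
x_nchw = x_nhwc.transpose(0, 3, 1, 2)                        # axis {0, 3, 1, 2}, as in TransToChannelFirst
assert x_nchw.shape == (2, 3, 8, 8)
assert np.array_equal(x_nchw.transpose(0, 2, 3, 1), x_nhwc)  # axis {0, 2, 3, 1}, as in TransToChannelLast

# 5-D case (dim == 3 above): NDHWC <-> NCDHW use {0, 4, 1, 2, 3} and {0, 2, 3, 4, 1}.
x_ndhwc = np.random.rand(2, 4, 8, 8, 3).astype(np.float32)
assert x_ndhwc.transpose(0, 4, 1, 2, 3).shape == (2, 3, 4, 8, 8)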

@@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7
CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP);
#endif
#ifdef PADDLE_USE_DSO
bool HasCUDNN() {
std::call_once(cudnn_dso_flag,

@@ -189,6 +189,15 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 7401
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \
__macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \
__macro(cudnnBatchNormalizationForwardTrainingEx); \
__macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize); \
__macro(cudnnBatchNormalizationBackwardEx); \
__macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize);
CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle

@@ -2523,6 +2523,13 @@ def batch_norm(input,
    check_type_and_dtype(input, 'input', Variable,
                         ['float16', 'float32', 'float64'], 'batch_norm')
    dtype = helper.input_dtype()

    has_reserve_space = False
    if data_layout == 'NHWC':
        flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
        if flag is not None and flag.lower() in ['true', '1']:
            has_reserve_space = True

    # use fp32 for bn parameter
    if dtype == core.VarDesc.VarType.FP16:
        dtype = core.VarDesc.VarType.FP32
@@ -2577,6 +2584,11 @@ saved_variance = helper.create_variable_for_type_inference(
    saved_variance = helper.create_variable_for_type_inference(
        dtype=dtype, stop_gradient=True)

    reserve_space = None
    if has_reserve_space:
        reserve_space = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.FP16, stop_gradient=True)

    batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
        dtype)
@@ -2599,17 +2611,19 @@ inputs['MomemtumTensor'] = momentum
        inputs['MomemtumTensor'] = momentum
    else:
        attrs['momentum'] = momentum

    outputs = {
        "Y": batch_norm_out,
        "MeanOut": mean_out,
        "VarianceOut": variance_out,
        "SavedMean": saved_mean,
        "SavedVariance": saved_variance
    }
    if reserve_space is not None:
        outputs["ReserveSpace"] = reserve_space

    helper.append_op(
        type="batch_norm",
        inputs=inputs,
        outputs={
            "Y": batch_norm_out,
            "MeanOut": mean_out,
            "VarianceOut": variance_out,
            "SavedMean": saved_mean,
            "SavedVariance": saved_variance
        },
        attrs=attrs)
        type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)

    return helper.append_activation(batch_norm_out)
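
A minimal usage sketch of the Python-side change (the flag name and data_layout value come from the hunk above; the input shape and variable names are illustrative assumptions): the ReserveSpace output is only created when data_layout is 'NHWC' and FLAGS_cudnn_batchnorm_spatial_persistent is set in the environment before the layer is built.

import os
# Must be set before fluid.layers.batch_norm is called, since the layer reads
# the flag at graph-construction time.
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'

import paddle.fluid as fluid

# Assumed NHWC input: height x width x channels (the batch dimension is implicit).
x = fluid.layers.data(name='x', shape=[32, 32, 16], dtype='float16')
y = fluid.layers.batch_norm(x, data_layout='NHWC')
# With the flag set and an NHWC layout, batch_norm additionally creates a
# float16 ReserveSpace variable and wires it into the op's outputs, which lets
# the cuDNN semi-persistent kernel be selected on GPU.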

@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid.core as core
@@ -413,16 +414,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
                    inputs['MomentumTensor'] = block.var('momentum_var')
                else:
                    attrs['momentum'] = momentum

                outputs = {
                    "Y": block.var('y'),
                    "MeanOut": block.var('mean'),  # share memory
                    "VarianceOut": block.var('variance'),  # share memory
                    "SavedMean": block.var('saved_mean'),
                    "SavedVariance": block.var('saved_variance')
                }

                has_reserve_space = False
                if data_format == 'NHWC':
                    flag = os.environ.get(
                        'FLAGS_cudnn_batchnorm_spatial_persistent')
                    if flag is not None and flag.lower() in ['true', '1']:
                        has_reserve_space = True

                if has_reserve_space:
                    block.create_var(name="reserve_space", dtype='float16')
                    outputs["ReserveSpace"] = block.var('reserve_space')
                    del os.environ['FLAGS_cudnn_batchnorm_spatial_persistent']

                bn_op = block.append_op(
                    type="batch_norm",
                    inputs=inputs,
                    outputs={
                        "Y": block.var('y'),
                        "MeanOut": block.var('mean'),  # share memory
                        "VarianceOut": block.var('variance'),  # share memory
                        "SavedMean": block.var('saved_mean'),
                        "SavedVariance": block.var('saved_variance')
                    },
                    outputs=outputs,
                    attrs=attrs)

                block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
@@ -479,6 +492,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
        self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']


class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
    def init_test_case(self):
        self.use_global_stats = False
        self.no_grad_set = set()
        self.fetch_list = [
            'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
            'scale@GRAD', 'bias@GRAD'
        ]
        os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1"


class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
    def init_test_case(self):
        self.use_momentum_variable = True
