add double grad compute for batch norm (#27296)

* add double grad compute for batch norm, test=develop

* fix unittest, test=develop

* remove unused tensor, test=develop

* add format, test=develop

* update, test=develop
ceci3 4 years ago committed by GitHub
parent 4bd7aa2566
commit 1d3b27cae8

File diff suppressed because it is too large.

@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
@@ -840,6 +841,45 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
  }
};

template <typename T>
class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *X = ctx.Input<Tensor>("X");
    const auto *Scale = ctx.Input<Tensor>("Scale");
    const auto *dY = ctx.Input<Tensor>("DY");
    const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
    const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const bool is_test = ctx.Attr<bool>("is_test");

    PADDLE_ENFORCE_EQ(
        is_test, false,
        platform::errors::InvalidArgument(
            "`is_test = True` CANNOT be used in train program. If "
            "you want to use global status in pre_train model, "
            "please set `use_global_stats = True`"));

    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);

    const auto *ddX = ctx.Input<Tensor>("DDX");
    const auto *ddScale = ctx.Input<Tensor>("DDScale");
    const auto *ddBias = ctx.Input<Tensor>("DDBias");
    auto *dX = ctx.Output<Tensor>("DX");
    auto *dScale = ctx.Output<Tensor>("DScale");
    auto *ddY = ctx.Output<Tensor>("DDY");

    NormDoubleGradFunctor<platform::CUDADeviceContext, T>(
        ctx, data_layout, X, Scale, dY, Saved_mean, Saved_variance, epsilon,
        use_global_stats, ddX, ddScale, ddBias, dX, dScale, ddY);
  }
};

}  // namespace operators
}  // namespace paddle

@@ -853,3 +893,7 @@ REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad_grad,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
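Note: with batch_norm_grad_grad registered on CUDA (float and double only; float16 is not covered), second-order gradients through batch_norm become usable. A minimal dygraph sketch of what this enables, modeled on the pattern in the unit tests below (shapes are illustrative and not part of this commit):

import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.layers.ones(shape=[2, 3, 2, 2], dtype='float32')
    x.stop_gradient = False
    y = paddle.fluid.layers.batch_norm(x)
    # create_graph=True keeps the first-order gradient differentiable, so a
    # second backward pass can run through the new batch_norm double-grad op.
    dx = fluid.dygraph.grad(
        outputs=[y], inputs=[x], create_graph=True)[0]
    loss = fluid.layers.reduce_mean(dx)
    loss.backward()  # second-order gradient w.r.t. x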

@@ -103,6 +103,42 @@ inline void TransToChannelFirst(const framework::ExecutionContext& context,
  }
}

template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
                                const Tensor* input,
                                Tensor* transformed_input) {
  int dim = input->dims().size() - 2;
  if (dim == 3) {
    transformed_input->Resize(input->dims());

    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[3];
    in_dims_vec[3] = input->dims()[4];
    in_dims_vec[4] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  } else if (dim == 2) {
    transformed_input->Resize(input->dims());

    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[3];
    in_dims_vec[3] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  } else if (dim == 1) {
    transformed_input->Resize(input->dims());

    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  }
}

template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
                               const Tensor* input, Tensor* transformed_input) {
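Note: ResizeToChannelLast only computes the channel-last shape and allocates the output buffer; the actual transpose is still done by TransToChannelLast below. It is presumably used by the NHWC code path of the new double-grad computation to prepare channel-last temporaries before writing into them.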
@@ -154,6 +190,16 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
      const framework::OpKernelType& expected_kernel_type) const override;
};

class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
@@ -168,6 +214,15 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
  void Apply(GradOpPtr<T> op) const override;
};

template <typename T>
class BatchNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override;
};

class BatchNormOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
@@ -190,5 +245,11 @@ class BatchNormGradKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override;
};

template <typename DeviceContext, typename T>
class BatchNormDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override;
};

}  // namespace operators
}  // namespace paddle

@@ -595,9 +595,13 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
      first_grad_arr +=
          inv_var_tile_data *
-         (dy_arr - dy_arr.colwise().sum() / sample_size -
+         (dy_arr -
+          dy_arr.colwise().sum().replicate(sample_size, 1) / sample_size -
           x_sub_mean_mul_invstd_arr *
-              (dy_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
+              (dy_arr * x_sub_mean_mul_invstd_arr)
+                      .colwise()
+                      .sum()
+                      .replicate(sample_size, 1) /
               sample_size);
      first_grad_arr = first_grad_arr * ddx_arr;
      for (int nc = 0; nc < NxC; ++nc) {
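The change appears to be a broadcasting fix: .colwise().sum() yields one sum per (n, c) instance, and the added .replicate(sample_size, 1) tiles those sums back over the sample_size rows so the subtraction is elementwise. As a sketch in per-instance notation (with m = sample_size, \sigma^{-1} the tiled inverse standard deviation, \hat{X} the normalized input, and \odot the elementwise product), the term added to first_grad_arr is

    \sigma^{-1} \odot \left( dY - \frac{1}{m}\sum_{i=1}^{m} dY_i - \hat{X} \odot \frac{1}{m}\sum_{i=1}^{m} dY_i\,\hat{X}_i \right),

which is then multiplied elementwise by ddX on the following line of the kernel.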

File diff suppressed because it is too large.

@@ -3167,7 +3167,7 @@ def instance_norm(input,
     param_shape = [channel_num]

-    if param_attr and bias_attr:
+    if param_attr != False and bias_attr != False:
         # create parameter
         scale = helper.create_parameter(
             attr=helper.param_attr,
@@ -3190,7 +3190,7 @@ def instance_norm(input,
     instance_norm_out = helper.create_variable_for_type_inference(dtype)

     inputs = {"X": input}
-    if param_attr and bias_attr:
+    if param_attr != False and bias_attr != False:
         inputs["Scale"] = scale
         inputs["Bias"] = bias

@@ -346,7 +346,7 @@ class TestRaiseNoDoubleGradOp(TestCase):
         with fluid.dygraph.guard():
             x = fluid.layers.ones(shape=[2, 3, 2, 2], dtype='float32')
             x.stop_gradient = False
-            y = paddle.fluid.layers.batch_norm(x)
+            y = paddle.fluid.layers.group_norm(x, groups=1)
             dx = fluid.dygraph.grad(
                 outputs=[y], inputs=[x], create_graph=True,
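Note: TestRaiseNoDoubleGradOp needs an operator without double-grad support, so it switches from batch_norm (which now has one) to group_norm, which still triggers the expected error at the double-grad step.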

@@ -68,5 +68,67 @@ class TestInstanceNormDoubleGradCheckWithoutParamBias(
            [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)


class TestBatchNormDoubleGradCheck(unittest.TestCase):
    def setUp(self):
        self.init_test()

    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = False
        self.shape = [2, 3, 4, 5]

    @prog_scope()
    def func(self, place):
        prog = fluid.Program()
        with fluid.program_guard(prog):
            np.random.seed()
            dtype = "float32"
            eps = 0.005
            atol = 1e-4
            x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
            z = fluid.layers.batch_norm(
                input=x,
                data_layout=self.data_layout,
                use_global_stats=self.use_global_stats)
            x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
            gradient_checker.double_grad_check(
                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestBatchNormDoubleGradCheckCase1(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NHWC'
        self.use_global_stats = False
        self.shape = [2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase2(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = True
        self.shape = [2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase3(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NHWC'
        self.use_global_stats = True
        self.shape = [2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = False
        self.shape = [2, 2, 3, 4, 5]


if __name__ == "__main__":
    unittest.main()
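Together these cases cover NCHW and NHWC layouts, use_global_stats on and off, and a 5-D (NCDHW) input, each checked against finite differences on CPU and, when CUDA is available, on GPU.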
