Support backward of backward for Relu and add a new gradient checker that compares the theoretical and numerical Jacobians. (#16862)

* Support backward of backward and add a new gradient checker
* Rename decorators.py to decorator_helper.py, since the Python environment on the Windows CI already has a decorators package.

1. Add ReluDoubleGradMaker when registering relu_grad.
2. Add a new gradient checker that compares the theoretical and numerical Jacobians; double gradients are checked with double_grad_check.
qingqing01 6 years ago committed by GitHub
parent 63d9fe3362
commit c1c2633a63
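The idea behind the new checker is to assemble the theoretical Jacobian from the op's registered gradient and the numerical Jacobian from central finite differences, then compare the two entry-wise. A minimal NumPy sketch of that comparison, for orientation only; the function names below are illustrative and are not the actual gradient_checker API:

import numpy as np

def numerical_jacobian(f, x, eps=1e-5):
    # Central finite-difference Jacobian of f at x, with inputs/outputs flattened.
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(f(x))
    jac = np.zeros((y.size, x.size))
    for j in range(x.size):
        xp, xm = x.copy(), x.copy()
        xp.flat[j] += eps
        xm.flat[j] -= eps
        jac[:, j] = (np.asarray(f(xp)) - np.asarray(f(xm))).ravel() / (2.0 * eps)
    return jac

def relu(x):
    return np.maximum(x, 0.0)

x0 = np.random.uniform(-1, 1, size=8)
x0[np.abs(x0) < 0.05] = 0.1  # keep samples away from the kink at 0
theoretical = np.diag((x0 > 0).astype(np.float64))  # analytic Jacobian of relu
assert np.allclose(numerical_jacobian(relu, x0), theoretical, atol=1e-4)

double_grad_check builds on the same idea: roughly, it constructs the first-order gradient program and then applies the same theoretical-versus-numerical comparison to that gradient computation.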

@@ -597,10 +597,57 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
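// Shape/LoD inference shared by double-grad activation ops: DOut and DDOut,
// when present, take the shape and LoD of Out.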
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput("DOut")) {
ctx->ShareDim("Out", "DOut");
ctx->ShareLoD("Out", "DOut");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
}
};
//
// ReluGrad: dx = dy if y > 0 else 0
// ReluGradGrad: ddy = ddx if y > 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("relu_grad_grad");
// input1: Out
op->SetInput("Out", Input("Out"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// Out@GRAD: dout
op->SetOutput("DOut", InputGrad("Out"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
@@ -632,3 +679,23 @@ namespace ops = paddle::operators;
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
REGISTER_OPERATOR(
relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(relu_grad_grad, ops::ActivationOpDoubleGrad);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);

@@ -32,3 +32,14 @@ namespace plat = paddle::platform;
ops::grad_functor<plat::float16>>);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);

@@ -1198,6 +1198,126 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
/*
* in arguments: x, out, ddx
* out arguments: ddout, dout, dx
*/
template <ActBwdOpFwdDeps kDepValue>
inline void ExtractActivationDoubleGradTensor(
const framework::ExecutionContext& ctx, const framework::Tensor** X,
const framework::Tensor** Out, const framework::Tensor** ddX,
framework::Tensor** dX, framework::Tensor** dOut,
framework::Tensor** ddOut) {
auto out_var = ctx.InputVar("Out");
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
auto do_var = ctx.OutputVar("DOut");
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
ctx.op().Input("Out"));
PADDLE_ENFORCE(ddx_var != nullptr,
"Cannot get input Variable %s, variable name = %s", "DDX",
ctx.op().Input("DDX"));
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
*Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
*ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var);
if (ddo_var) {
*ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
ddo_var);
}
if (do_var) {
*dOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
do_var);
}
} else {
*Out = ctx.Input<framework::Tensor>("Out");
*ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
*ddOut = ctx.Output<framework::Tensor>("DDOut");
}
if (do_var) {
*dOut = ctx.Output<framework::Tensor>("DOut");
}
}
PADDLE_ENFORCE(*ddX != nullptr,
"Cannot get input tensor %s, variable name = %s", "DDX",
ctx.op().Input("DDX"));
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input tensor X, variable name = %s",
ctx.op().Input("X"));
auto dx_var = ctx.OutputVar("DX");
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
*X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
if (dx_var) {
*dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
dx_var);
}
} else {
*X = ctx.Input<framework::Tensor>("X");
if (dx_var) {
*dX = ctx.Output<framework::Tensor>("DX");
}
}
} else {
VLOG(10) << " Inplace activation of Op : " << ctx.op().Type();
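// X is not used when the functor depends only on Out, so alias it to ddX and
// avoid handing the functor a null pointer.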
*X = *ddX;
}
}
template <typename DeviceContext, typename Functor>
class ActivationDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *Out, *ddX;
X = Out = ddX = nullptr;
framework::Tensor *ddOut, *dOut, *dX;
ddOut = dOut = dX = nullptr;
ExtractActivationDoubleGradTensor<Functor::FwdDeps()>(ctx, &X, &Out, &ddX,
&dX, &dOut, &ddOut);
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
if (dOut) dOut->mutable_data<T>(ctx.GetPlace());
if (dX) dX->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, Out, ddX, ddOut, dOut, dX);
}
};
template <typename T>
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* Out, const framework::Tensor* ddX,
framework::Tensor* ddOut, framework::Tensor* dOut,
framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
}
if (dOut) {
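// relu_grad is linear in its dy input, so the second-order gradient w.r.t. Out is zero.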
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
dout.device(*d) = dout.constant(static_cast<T>(0));
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
} // namespace operators
} // namespace paddle
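For reference, the formulas implemented by ReluGradGradFunctor follow from differentiating the first backward pass once more. relu_grad computes $dx = dy \odot \mathbf{1}[out > 0]$, which is linear in $dy$ and piecewise constant in $out$, so

$$ddout = \frac{\partial\, dx}{\partial\, dy}\, ddx = ddx \odot \mathbf{1}[out > 0], \qquad dout = \frac{\partial\, dx}{\partial\, out}\, ddx = 0 \ \text{(almost everywhere)},$$

which is why DDOut receives the masked ddx and DOut is filled with zeros above.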
@@ -1205,7 +1325,6 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, Exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, Relu, ReluFunctor, ReluGradFunctor); \
__macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \

@@ -611,7 +611,7 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
if inputs:
for op in op_path:
for name in op.desc.input_arg_names():
if name not in input_names:
if name not in input_names and block.vars[name].stop_gradient:
no_grad_set.add(name)
return op_path

@@ -29,7 +29,7 @@ list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com
list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test
list(REMOVE_ITEM TEST_OPS decorators) # decorators is a helper python file, not a test
list(REMOVE_ITEM TEST_OPS decorator_helper) # decorator_helper is a helper python file, not a test
if(APPLE)
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_desc_clone)

File diff suppressed because it is too large.

@@ -19,7 +19,7 @@ import random
import collections
import paddle.fluid as fluid
import unittest
from decorators import *
from decorator_helper import *
class Memory(object):

@@ -16,12 +16,12 @@ from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
import decorators
from decorator_helper import prog_scope
import unittest
class TestGetPlaces(unittest.TestCase):
@decorators.prog_scope()
@prog_scope()
def test_get_places(self):
places = get_places()
cpu = fluid.CPUPlace()

@@ -17,7 +17,7 @@ import unittest
import contextlib
import numpy as np
import decorators
from decorator_helper import prog_scope
import inspect
from six.moves import filter
@@ -1171,7 +1171,7 @@ class TestBook(LayerTest):
fluid.default_startup_program()):
get_places(device_count=1)
@decorators.prog_scope()
@prog_scope()
def make_nce(self):
window_size = 5
words = []

@@ -15,13 +15,13 @@
from __future__ import print_function
import unittest
import decorators
from decorator_helper import prog_scope
import paddle.fluid as fluid
import numpy
class TestMathOpPatches(unittest.TestCase):
@decorators.prog_scope()
@prog_scope()
def test_add_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = a + 10
@@ -41,7 +41,7 @@ class TestMathOpPatches(unittest.TestCase):
d_expected = ab_np + numpy.concatenate([a_np, a_np], axis=1)
self.assertTrue(numpy.allclose(d_expected, d_np))
@decorators.prog_scope()
@prog_scope()
def test_radd_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = 10 + a
@@ -53,7 +53,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(a_np + 10, b_np))
@decorators.prog_scope()
@prog_scope()
def test_sub_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = a - 10
@@ -65,7 +65,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(a_np - 10, b_np))
@decorators.prog_scope()
@prog_scope()
def test_radd_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = 10 - a
@@ -77,7 +77,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(10 - a_np, b_np))
@decorators.prog_scope()
@prog_scope()
def test_mul_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = a * 10
@@ -89,7 +89,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(a_np * 10, b_np))
@decorators.prog_scope()
@prog_scope()
def test_rmul_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = 10 * a
@@ -101,7 +101,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(10 * a_np, b_np))
@decorators.prog_scope()
@prog_scope()
def test_div_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = a / 10
@@ -113,7 +113,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(a_np / 10, b_np))
@decorators.prog_scope()
@prog_scope()
def test_rdiv_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = 10 / a
@@ -126,7 +126,7 @@
fetch_list=[b])
self.assertTrue(numpy.allclose(10 / a_np, b_np))
@decorators.prog_scope()
@prog_scope()
def test_div_two_tensor(self):
a = fluid.layers.data(name="a", shape=[1])
b = fluid.layers.data(name="b", shape=[1])
@@ -141,7 +141,7 @@
fetch_list=[c])
self.assertTrue(numpy.allclose(a_np / b_np, c_np))
@decorators.prog_scope()
@prog_scope()
def test_mul_two_tensor(self):
a = fluid.layers.data(name="a", shape=[1])
b = fluid.layers.data(name="b", shape=[1])
@@ -156,7 +156,7 @@
fetch_list=[c])
self.assertTrue(numpy.allclose(a_np * b_np, c_np))
@decorators.prog_scope()
@prog_scope()
def test_add_two_tensor(self):
a = fluid.layers.data(name="a", shape=[1])
b = fluid.layers.data(name="b", shape=[1])
@@ -171,7 +171,7 @@
fetch_list=[c])
self.assertTrue(numpy.allclose(a_np + b_np, c_np))
@decorators.prog_scope()
@prog_scope()
def test_sub_two_tensor(self):
a = fluid.layers.data(name="a", shape=[1])
b = fluid.layers.data(name="b", shape=[1])

@@ -0,0 +1,72 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
from decorator_helper import prog_scope
class TestMulGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
prog = fluid.Program()
with fluid.program_guard(prog):
x = layers.create_parameter(dtype="float64", shape=[2, 8], name='x')
y = layers.create_parameter(dtype="float64", shape=[8, 4], name='y')
z = layers.mul(x=x, y=y)
gradient_checker.grad_check([x, y], z, place=place)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestReluDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of the input variable should be specified explicitly and must not include -1.
shape = [2, 8]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.relu(x)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
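# values within eps of zero would straddle the ReLU kink during finite differencing, so nudge them away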
x_arr[np.abs(x_arr) < 0.005] = 0.02
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()

@@ -17,11 +17,11 @@ import unittest
import paddle.fluid as fluid
import numpy as np
import decorators
from decorator_helper import prog_scope
class TestRegistry(unittest.TestCase):
@decorators.prog_scope()
@prog_scope()
def test_registry_layer(self):
x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32')
output = fluid.layers.mean(x)
