Ops(relu6/selu/soft_relu/softshrink/stanh/swish/thresholded_relu/hard_shrink/hard_sigmoid/hard_swish/hsigmoid/maxout) error message enhancement (#23718)

revert-23830-2.0-beta
zhupengyang 5 years ago committed by GitHub
parent 0b6f09e74f
commit 2787944c2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,30 +61,15 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "hsigmoid");
platform::errors::NotFound( OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid");
"Input(X) of HierarchicalSigmoidOp is not found.")); OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid");
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "hsigmoid");
ctx->HasInput("Label"), true, OP_INOUT_CHECK(ctx->HasOutput("PreOut"), "Output", "PreOut", "hsigmoid");
platform::errors::NotFound(
"Input(Label) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("PreOut"), true,
platform::errors::NotFound(
"Output(PreOut) of HierarchicalSigmoidOp is not found."));
auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch"); auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
if (with_prefetch) { if (with_prefetch) {
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasOutput("W_Out"), "Output", "W_Out", "hsigmoid");
ctx->HasOutput("W_Out"), true,
platform::errors::NotFound(
"Output(W_Out) of HierarchicalSigmoidOp is not found."));
} }
const int64_t batch_size = ctx->GetInputDim("X")[0]; const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1}); std::vector<int64_t> output_shape({batch_size, 1});
@ -213,30 +198,15 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid_grad");
ctx->HasInput("W"), true, OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid_grad");
platform::errors::NotFound( OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Input(W) of HierarchicalSigmoidGradOp is not found.")); "Out@Grad", "hsigmoid_grad");
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasInput("PreOut"), "Input", "PreOut", "hsigmoid_grad");
ctx->HasInput("Label"), true, OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("W")), "Output",
platform::errors::NotFound( "W@Grad", "hsigmoid_grad");
"Input(Label) of HierarchicalSigmoidGradOp is not found.")); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
PADDLE_ENFORCE_EQ( "X@Grad", "hsigmoid_grad");
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Input(Out@Grad) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("PreOut"), true,
platform::errors::NotFound(
"Input(Preout) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("W")), true,
platform::errors::NotFound(
"Output(W@Grad of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound(
"Output(X@Grad of HierarchicalSigmoidGradOp is not found."));
if (ctx->HasOutput(framework::GradVarName("Bias"))) { if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->SetOutputDim(framework::GradVarName("Bias"),

@ -203,8 +203,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, w_grad, static_cast<T>(0.0)); zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in); bit_code->MulGradWeight(pre_out_grad, w_grad, in);
} else { } else {
PADDLE_ENFORCE(path != nullptr, PADDLE_ENFORCE_NOT_NULL(path,
"Sparse mode should not be used without custom tree!"); platform::errors::NotFound(
"Custom tree must be set for sparse mode!"));
framework::Vector<int64_t> real_rows = PathToRows(*path); framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad = auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W")); ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));

@ -72,24 +72,26 @@ class MaxOutOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout");
"Input(X) of MaxoutOpshould not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "maxout");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups"); int groups = ctx->Attrs().Get<int>("groups");
int axis = ctx->Attrs().Get<int>("axis"); int axis = ctx->Attrs().Get<int>("axis");
// check groups > 1 // check groups > 1
PADDLE_ENFORCE_GT(groups, 1, PADDLE_ENFORCE_GT(groups, 1, platform::errors::InvalidArgument(
"Attr(groups) of Op(maxout) should be larger than 1."); "Attr(groups) of Op(maxout) should be "
"larger than 1. But received %d.",
groups));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0, in_x_dims[axis] % groups, 0,
"ValueError: The number of input channels for Op(maxout) " platform::errors::InvalidArgument(
"should be divisible by Attr(groups). But received: the " "The number of input channels for Op(maxout) "
"input's channels is [%d], the shape of input is [%s], " "should be divisible by Attr(groups). But received: the "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The " "input's channels is [%d], the shape of input is [%s], "
"error may come from wrong Attr(groups) or Attr(axis) setting.", "the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
in_x_dims[axis], in_x_dims, groups, axis); "error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis], in_x_dims, groups, axis));
std::vector<int64_t> output_shape( std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]}); {in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups; output_shape[axis] = in_x_dims[axis] / groups;
@ -101,10 +103,9 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout_grad");
"Input(X) of MaxOutOpGrad must not be null."); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "X@Grad", "maxout_grad");
"Output(Grad@X) of MaxOutOpGrad should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
}; };

@ -28,10 +28,8 @@ class SeluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "selu");
"Input(X) of SeluOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "selu");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SeluOp should not be null.");
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
@ -105,9 +103,9 @@ class SeluGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Input(Out@GRAD) should not be null"); "Out@GRAD", "selu_grad");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null"); OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "selu_grad");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Out")); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Out"));
} }

@ -923,6 +923,8 @@ def hsigmoid(input,
value=0.05), bias_attr=fluid.initializer.Constant(value=.0)) value=0.05), bias_attr=fluid.initializer.Constant(value=.0))
# out = [[0.62792355], [0.62792355], [0.62792355], [0.62792355]] # out = [[0.62792355], [0.62792355], [0.62792355], [0.62792355]]
""" """
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'hsigmoid')
check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid')
helper = LayerHelper('hierarchical_sigmoid', **locals()) helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()

@ -8280,6 +8280,8 @@ def selu(x, scale=None, alpha=None, name=None):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([[0. , 1.050701],[2.101402, 3.152103]], dtype=float32)] print(res) # [array([[0. , 1.050701],[2.101402, 3.152103]], dtype=float32)]
""" """
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'selu')
helper = LayerHelper('selu', **locals()) helper = LayerHelper('selu', **locals())
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
@ -8888,6 +8890,8 @@ def relu6(x, threshold=6.0, name=None):
# [[0. 0. ] # [[0. 0. ]
# [2.5 6. ]] # [2.5 6. ]]
""" """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu6')
helper = LayerHelper('relu6', **locals()) helper = LayerHelper('relu6', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
@ -8980,6 +8984,8 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
# [0.62705994, 0.23110689, 0.56902856]], dtype=float32)] # [0.62705994, 0.23110689, 0.56902856]], dtype=float32)]
""" """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'stanh')
helper = LayerHelper('stanh', **locals()) helper = LayerHelper('stanh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
@ -9014,6 +9020,9 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
data = fluid.layers.fill_constant(shape=[3, 2], value=0.5, dtype='float32') # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]] data = fluid.layers.fill_constant(shape=[3, 2], value=0.5, dtype='float32') # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
result = fluid.layers.hard_sigmoid(data) # [[0.6, 0.6], [0.6, 0.6], [0.6, 0.6]] result = fluid.layers.hard_sigmoid(data) # [[0.6, 0.6], [0.6, 0.6], [0.6, 0.6]]
""" """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_sigmoid')
helper = LayerHelper('hard_sigmoid', **locals()) helper = LayerHelper('hard_sigmoid', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
@ -9094,6 +9103,8 @@ def swish(x, beta=1.0, name=None):
# array([[-0.03916847, 0.8835007 , -0.25835553], # array([[-0.03916847, 0.8835007 , -0.25835553],
# [ 0.51126915, 0.82324016, 0.06915068]], dtype=float32) # [ 0.51126915, 0.82324016, 0.06915068]], dtype=float32)
""" """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')
helper = LayerHelper('swish', **locals()) helper = LayerHelper('swish', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
@ -9293,6 +9304,9 @@ def soft_relu(x, threshold=40.0, name=None):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([[0.6931472, 1.3132616], [2.126928 , 3.0485873]], dtype=float32)] print(res) # [array([[0.6931472, 1.3132616], [2.126928 , 3.0485873]], dtype=float32)]
""" """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'soft_relu')
helper = LayerHelper('soft_relu', **locals()) helper = LayerHelper('soft_relu', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
@ -11786,6 +11800,8 @@ def maxout(x, groups, name=None, axis=1):
dtype='float32') dtype='float32')
out = fluid.layers.maxout(input, groups=2) out = fluid.layers.maxout(input, groups=2)
""" """
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
helper = LayerHelper("maxout", **locals()) helper = LayerHelper("maxout", **locals())
if axis not in [1, -1, 3]: if axis not in [1, -1, 3]:
raise ValueError( raise ValueError(
@ -14005,6 +14021,9 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
out, = exe.run(feed={'x':x_data}, fetch_list=[y.name]) out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
print(out) # [[0.66666667, 1.66666667,3., 4.]] print(out) # [[0.66666667, 1.66666667,3., 4.]]
""" """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_swish')
helper = LayerHelper('hard_swish', **locals()) helper = LayerHelper('hard_swish', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(

@ -17,6 +17,7 @@ import os
from .layer_function_generator import generate_layer_fn, generate_activation_fn from .layer_function_generator import generate_layer_fn, generate_activation_fn
from .. import core from .. import core
from ..framework import convert_np_dtype_to_dtype_ from ..framework import convert_np_dtype_to_dtype_
from ..data_feeder import check_variable_and_dtype
__activations_noattr__ = [ __activations_noattr__ = [
'sigmoid', 'sigmoid',
@ -64,6 +65,9 @@ _softshrink_ = generate_layer_fn('softshrink')
def softshrink(x, alpha=None): def softshrink(x, alpha=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'softshrink')
locals_var = locals().copy() locals_var = locals().copy()
kwargs = dict() kwargs = dict()
for name, val in locals_var.items(): for name, val in locals_var.items():
@ -107,6 +111,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def hard_shrink(x, threshold=None): def hard_shrink(x, threshold=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_shrink')
locals_var = locals().copy() locals_var = locals().copy()
kwargs = dict() kwargs = dict()
for name, val in locals_var.items(): for name, val in locals_var.items():
@ -163,6 +170,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def thresholded_relu(x, threshold=None): def thresholded_relu(x, threshold=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'thresholded_relu')
locals_var = locals().copy() locals_var = locals().copy()
kwargs = dict() kwargs = dict()
for name, val in locals_var.items(): for name, val in locals_var.items():

@ -220,6 +220,19 @@ class TestHardShrink(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestHardShrinkOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hard_shrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hard_shrink, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.hard_shrink(x_fp16)
class TestSoftShrink(TestActivation): class TestSoftShrink(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "softshrink" self.op_type = "softshrink"
@ -241,6 +254,19 @@ class TestSoftShrink(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestSoftShrinkOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.softshrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.softshrink, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.softshrink(x_fp16)
class TestSqrt(TestActivation, TestParameter): class TestSqrt(TestActivation, TestParameter):
def setUp(self): def setUp(self):
self.op_type = "sqrt" self.op_type = "sqrt"
@ -586,6 +612,19 @@ class TestRelu6(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestRelu6OpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.relu6, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.relu6, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.relu6(x_fp16)
class TestHardSwish(TestActivation): class TestHardSwish(TestActivation):
def setUp(self): def setUp(self):
self.op_type = 'hard_swish' self.op_type = 'hard_swish'
@ -610,6 +649,19 @@ class TestHardSwish(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestHardSwishOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hard_swish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hard_swish, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.hard_swish(x_fp16)
class TestSoftRelu(TestActivation): class TestSoftRelu(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "soft_relu" self.op_type = "soft_relu"
@ -635,6 +687,19 @@ class TestSoftRelu(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.02) self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestSoftReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.soft_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.soft_relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.soft_relu(x_fp16)
class TestELU(TestActivation): class TestELU(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "elu" self.op_type = "elu"
@ -812,6 +877,19 @@ class TestSTanh(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestSTanhOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.stanh, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.stanh, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.stanh(x_fp16)
class TestSoftplus(TestActivation): class TestSoftplus(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "softplus" self.op_type = "softplus"
@ -870,6 +948,19 @@ class TestThresholdedRelu(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestThresholdedReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.thresholded_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.thresholded_relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.thresholded_relu(x_fp16)
class TestHardSigmoid(TestActivation): class TestHardSigmoid(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "hard_sigmoid" self.op_type = "hard_sigmoid"
@ -899,6 +990,19 @@ class TestHardSigmoid(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestHardSigmoidOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hard_sigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hard_sigmoid, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.hard_sigmoid(x_fp16)
class TestSwish(TestActivation): class TestSwish(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "swish" self.op_type = "swish"
@ -918,6 +1022,19 @@ class TestSwish(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.008) self.check_grad(['X'], 'Out', max_relative_error=0.008)
class TestSwishOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.swish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.swish, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.swish(x_fp16)
#------------------ Test Cudnn Activation---------------------- #------------------ Test Cudnn Activation----------------------
def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3): def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3):
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),

@ -18,6 +18,7 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import math import math
from op_test import OpTest, skip_check_grad_ci from op_test import OpTest, skip_check_grad_ci
@ -378,5 +379,27 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest):
self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label')) self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label'))
class TestHSigmoidOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
label = fluid.data('label', [4, 1], 'int64')
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hsigmoid, 1, label, 2)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[4, 3], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hsigmoid, x_int32, label,
2)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[4, 3], dtype='float32')
fluid.layers.hsigmoid(x_fp32, label, 2)
# The label type must be Variable.
self.assertRaises(TypeError, fluid.layers.hsigmoid, x_fp32, 1, 2)
# The label dtype must be int64.
label_int32 = fluid.data('label_int32', [4, 1], 'int32')
self.assertRaises(TypeError, fluid.layers.hsigmoid, x_fp32,
label_int32, 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
@ -96,5 +97,18 @@ class TestMaxOutOpAxisAPI(unittest.TestCase):
self.assertRaises(ValueError, _attr_axis) self.assertRaises(ValueError, _attr_axis)
class TestMaxOutOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.maxout, 1, 2)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.maxout, x_int32, 2)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
fluid.layers.maxout(x_fp32, 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -18,6 +18,8 @@ import unittest
import numpy as np import numpy as np
import six import six
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class SeluTest(OpTest): class SeluTest(OpTest):
@ -67,5 +69,18 @@ class SeluTest(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestSeluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.selu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.selu, x_int32)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
fluid.layers.selu(x_fp32)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

Loading…
Cancel
Save