refine sqrt api check (#20254)

test=develop
revert-20712-fix_depthwise_conv
Zhaolong Xing 5 years ago committed by GitHub
parent a69baf639f
commit 63d88b522f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,6 +22,7 @@ from six.moves import cStringIO
from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
from ..layer_helper import LayerHelper
from ..data_feeder import convert_dtype
__all__ = [
'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc',
@ -250,6 +251,18 @@ def generate_activation_fn(op_type):
def func(x, name=None):
helper = LayerHelper(op_type, **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in %s must be Variable, but received %s" %
(op_type, type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in %s only support float16 in GPU now." %
(op_type))
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in %s must be float16 (only support on GPU), float32 or float64, but received %s."
% (op_type, convert_dtype(x.dtype)))
output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
return output

@ -23,6 +23,22 @@ import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestSqrtOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of sqrt op must be Variable or numpy.ndarray.
in1 = 1
self.assertRaises(TypeError, fluid.layers.sqrt, in1)
# The input dtype of sqrt op must be float16, float32, float64.
in2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.sqrt, in2)
in3 = fluid.layers.data(
name='input3', shape=[12, 10], dtype="float16")
fluid.layers.sqrt(x=in3)
class TestActivation(OpTest):
def setUp(self):
self.op_type = "exp"

Loading…
Cancel
Save