|
|
|
@ -22,6 +22,7 @@ from six.moves import cStringIO
|
|
|
|
|
from ..proto import framework_pb2
|
|
|
|
|
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
|
|
|
|
|
from ..layer_helper import LayerHelper
|
|
|
|
|
from ..data_feeder import convert_dtype
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc',
|
|
|
|
@ -250,6 +251,18 @@ def generate_activation_fn(op_type):
|
|
|
|
|
|
|
|
|
|
def func(x, name=None):
|
|
|
|
|
helper = LayerHelper(op_type, **locals())
|
|
|
|
|
if not isinstance(x, Variable):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The type of 'x' in %s must be Variable, but received %s" %
|
|
|
|
|
(op_type, type(x)))
|
|
|
|
|
if convert_dtype(x.dtype) in ['float16']:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The data type of 'x' in %s only support float16 in GPU now." %
|
|
|
|
|
(op_type))
|
|
|
|
|
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The data type of 'x' in %s must be float16 (only support on GPU), float32 or float64, but received %s."
|
|
|
|
|
% (op_type, convert_dtype(x.dtype)))
|
|
|
|
|
output = helper.create_variable_for_type_inference(dtype=x.dtype)
|
|
|
|
|
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
|
|
|
|
|
return output
|
|
|
|
|