|
|
|
@ -20,7 +20,7 @@ import string
|
|
|
|
|
|
|
|
|
|
from six.moves import cStringIO
|
|
|
|
|
from ..proto import framework_pb2
|
|
|
|
|
from ..framework import OpProtoHolder, Variable
|
|
|
|
|
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
|
|
|
|
|
from ..layer_helper import LayerHelper
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
@ -178,6 +178,15 @@ def generate_layer_fn(op_type):
|
|
|
|
|
"operator {0} must input same dtype. {1} vs {2}".format(
|
|
|
|
|
op_type, dtype, each.dtype))
|
|
|
|
|
|
|
|
|
|
if dtype is None:
|
|
|
|
|
arg_dtype = kwargs.get("dtype")
|
|
|
|
|
if arg_dtype:
|
|
|
|
|
if not isinstance(arg_dtype, core.VarDesc.VarType):
|
|
|
|
|
dtype = convert_np_dtype_to_dtype_(arg_dtype)
|
|
|
|
|
else:
|
|
|
|
|
dtype = arg_dtype
|
|
|
|
|
else:
|
|
|
|
|
dtype = core.VarDesc.VarType.FP32
|
|
|
|
|
return dtype
|
|
|
|
|
|
|
|
|
|
def func(*args, **kwargs):
|
|
|
|
|