|
|
|
@ -130,7 +130,7 @@ def generate_layer_fn(op_type):
|
|
|
|
|
o_name = not_intermediate_outputs[0].name
|
|
|
|
|
intermediate_output_names = [output.name for output in intermediate_outputs]
|
|
|
|
|
|
|
|
|
|
def infer_and_check_dtype(op_proto, **kwargs):
|
|
|
|
|
def infer_and_check_dtype(op_proto, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
This function performs the sanity check for dtype and
|
|
|
|
|
instance type.
|
|
|
|
@ -141,6 +141,10 @@ def generate_layer_fn(op_type):
|
|
|
|
|
val = kwargs.pop(name, [])
|
|
|
|
|
if not isinstance(val, list) and not isinstance(val, tuple):
|
|
|
|
|
val = [val]
|
|
|
|
|
if len(val) == 0:
|
|
|
|
|
val = [args[0]]
|
|
|
|
|
args = args[1:]
|
|
|
|
|
|
|
|
|
|
for each in val:
|
|
|
|
|
if not isinstance(each, Variable):
|
|
|
|
|
raise ValueError("input of {0} must be variable".format(
|
|
|
|
@ -158,7 +162,7 @@ def generate_layer_fn(op_type):
|
|
|
|
|
def func(*args, **kwargs):
|
|
|
|
|
helper = LayerHelper(op_type, **kwargs)
|
|
|
|
|
|
|
|
|
|
dtype = infer_and_check_dtype(op_proto, **kwargs)
|
|
|
|
|
dtype = infer_and_check_dtype(op_proto, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
inputs = dict()
|
|
|
|
|
for ipt in op_proto.inputs:
|
|
|
|
|