|
|
|
@ -40,11 +40,9 @@ PADDLE_ON_MODEL_CE = os.environ.get('PADDLE_ON_MODEL_CE', None) is not None
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'Program',
|
|
|
|
|
'Operator',
|
|
|
|
|
'default_startup_program',
|
|
|
|
|
'default_main_program',
|
|
|
|
|
'program_guard',
|
|
|
|
|
'get_var',
|
|
|
|
|
'name_scope',
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
@ -663,11 +661,11 @@ class Operator(object):
|
|
|
|
|
self._update_desc_attr(attr_name, attr_val)
|
|
|
|
|
|
|
|
|
|
self.desc.check_attrs()
|
|
|
|
|
if self.has_kernel(type):
|
|
|
|
|
if self._has_kernel(type):
|
|
|
|
|
self.desc.infer_var_type(self.block.desc)
|
|
|
|
|
self.desc.infer_shape(self.block.desc)
|
|
|
|
|
|
|
|
|
|
def has_kernel(self, op_type):
|
|
|
|
|
def _has_kernel(self, op_type):
|
|
|
|
|
return op_type not in self.OP_WITHOUT_KERNEL_SET
|
|
|
|
|
|
|
|
|
|
def to_string(self, throw_on_error):
|
|
|
|
@ -708,7 +706,7 @@ class Operator(object):
|
|
|
|
|
"""
|
|
|
|
|
return self.desc.input(name)
|
|
|
|
|
|
|
|
|
|
def rename_input(self, old_name, new_name):
|
|
|
|
|
def _rename_input(self, old_name, new_name):
|
|
|
|
|
"""
|
|
|
|
|
Rename the `old_name` to `new_name`.
|
|
|
|
|
|
|
|
|
@ -719,9 +717,9 @@ class Operator(object):
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
"""
|
|
|
|
|
self.desc.rename_input(old_name, new_name)
|
|
|
|
|
self.desc._rename_input(old_name, new_name)
|
|
|
|
|
|
|
|
|
|
def rename_output(self, old_name, new_name):
|
|
|
|
|
def _rename_output(self, old_name, new_name):
|
|
|
|
|
"""
|
|
|
|
|
Rename the `old_name` to `new_name`.
|
|
|
|
|
|
|
|
|
@ -732,7 +730,7 @@ class Operator(object):
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
"""
|
|
|
|
|
self.desc.rename_output(old_name, new_name)
|
|
|
|
|
self.desc._rename_output(old_name, new_name)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def input_names(self):
|
|
|
|
@ -796,7 +794,7 @@ class Operator(object):
|
|
|
|
|
"""
|
|
|
|
|
return self.desc.attr_type(name)
|
|
|
|
|
|
|
|
|
|
def set_attr(self, name, val):
|
|
|
|
|
def _set_attr(self, name, val):
|
|
|
|
|
"""
|
|
|
|
|
Set the value of attribute by attribute's name.
|
|
|
|
|
|
|
|
|
@ -829,7 +827,7 @@ class Operator(object):
|
|
|
|
|
isinstance(val, core.ProgramDesc):
|
|
|
|
|
self.desc.set_serialized_attr(name, val.serialize_to_string())
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_attr(name, val)
|
|
|
|
|
self.desc._set_attr(name, val)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def attr_names(self):
|
|
|
|
@ -848,7 +846,7 @@ class Operator(object):
|
|
|
|
|
"""
|
|
|
|
|
return self.desc.attr(name)
|
|
|
|
|
|
|
|
|
|
def block_attr_id(self, name):
|
|
|
|
|
def _block_attr_id(self, name):
|
|
|
|
|
"""
|
|
|
|
|
Get the block attribute's id by name.
|
|
|
|
|
|
|
|
|
@ -858,9 +856,9 @@ class Operator(object):
|
|
|
|
|
Returns:
|
|
|
|
|
int: the block index.
|
|
|
|
|
"""
|
|
|
|
|
return self.desc.block_attr_id(name)
|
|
|
|
|
return self.desc._block_attr_id(name)
|
|
|
|
|
|
|
|
|
|
def block_attr(self, name):
|
|
|
|
|
def _block_attr(self, name):
|
|
|
|
|
"""
|
|
|
|
|
Get the block attribute by name.
|
|
|
|
|
|
|
|
|
@ -871,11 +869,11 @@ class Operator(object):
|
|
|
|
|
block: the block attribute.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
id = self.block_attr_id(name)
|
|
|
|
|
id = self._block_attr_id(name)
|
|
|
|
|
assert (id >= 0 and id < len(self.block.program.blocks))
|
|
|
|
|
return self.block.program.blocks[id]
|
|
|
|
|
|
|
|
|
|
def blocks_attr(self, name):
|
|
|
|
|
def _blocks_attr(self, name):
|
|
|
|
|
"""
|
|
|
|
|
Get the blocks attribute by name.
|
|
|
|
|
|
|
|
|
@ -886,13 +884,13 @@ class Operator(object):
|
|
|
|
|
list: list of the blocks attribute.
|
|
|
|
|
"""
|
|
|
|
|
attrs = []
|
|
|
|
|
for i in self.blocks_attr_ids(name):
|
|
|
|
|
for i in self._blocks_attr_ids(name):
|
|
|
|
|
assert (i >= 0 and i < len(self.block.program.blocks))
|
|
|
|
|
attrs.append(self.block.program.blocks[i])
|
|
|
|
|
|
|
|
|
|
return attrs
|
|
|
|
|
|
|
|
|
|
def blocks_attr_ids(self, name):
|
|
|
|
|
def _blocks_attr_ids(self, name):
|
|
|
|
|
"""
|
|
|
|
|
Get the blocks attribute's ids by name.
|
|
|
|
|
|
|
|
|
@ -903,7 +901,7 @@ class Operator(object):
|
|
|
|
|
list: list of the blocks ids.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
return self.desc.blocks_attr_ids(name)
|
|
|
|
|
return self.desc._blocks_attr_ids(name)
|
|
|
|
|
|
|
|
|
|
def all_attrs(self):
|
|
|
|
|
"""
|
|
|
|
@ -917,11 +915,11 @@ class Operator(object):
|
|
|
|
|
for n in attr_names:
|
|
|
|
|
attr_type = self.desc.attr_type(n)
|
|
|
|
|
if attr_type == core.AttrType.BLOCK:
|
|
|
|
|
attr_map[n] = self.block_attr(n)
|
|
|
|
|
attr_map[n] = self._block_attr(n)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if attr_type == core.AttrType.BLOCKS:
|
|
|
|
|
attr_map[n] = self.blocks_attr(n)
|
|
|
|
|
attr_map[n] = self._blocks_attr(n)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
attr_map[n] = self.attr(n)
|
|
|
|
@ -1795,7 +1793,7 @@ class Program(object):
|
|
|
|
|
for j in six.moves.range(block.op_size()):
|
|
|
|
|
op = block.op(j)
|
|
|
|
|
if op.has_attr('is_test'):
|
|
|
|
|
op.set_attr('is_test', True)
|
|
|
|
|
op._set_attr('is_test', True)
|
|
|
|
|
res.blocks = [
|
|
|
|
|
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
|
|
|
|
|
]
|
|
|
|
@ -2169,7 +2167,7 @@ def program_guard(main_program, startup_program=None):
|
|
|
|
|
switch_startup_program(startup_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_var(name, program=None):
|
|
|
|
|
def _get_var(name, program=None):
|
|
|
|
|
"""
|
|
|
|
|
Get a variable by name from the global block of a program.
|
|
|
|
|
|
|
|
|
|