|
|
|
@ -410,10 +410,14 @@ class Operator(object):
|
|
|
|
|
|
|
|
|
|
if op_maker.kOpRoleAttrName() not in self.attrs:
|
|
|
|
|
self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
|
|
|
|
|
if len(self.block.program.op_role_var
|
|
|
|
|
) != 0 and op_maker.kOpRoleVarAttrName() not in self.attrs:
|
|
|
|
|
self.attrs[op_maker.kOpRoleVarAttrName(
|
|
|
|
|
)] = self.block.program.op_role_var
|
|
|
|
|
|
|
|
|
|
role_var_name = op_maker.kOpRoleVarAttrName()
|
|
|
|
|
if len(self.block.program.
|
|
|
|
|
op_role_var) != 0 and role_var_name not in self.attrs:
|
|
|
|
|
self.attrs[role_var_name] = self.block.program.op_role_var
|
|
|
|
|
|
|
|
|
|
if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0:
|
|
|
|
|
del self.attrs[role_var_name]
|
|
|
|
|
|
|
|
|
|
if len(self.desc.type()) != 0:
|
|
|
|
|
return
|
|
|
|
@ -497,7 +501,6 @@ class Operator(object):
|
|
|
|
|
attr_name, self.attrs[attr_name].serialize_to_string())
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_attr(attr_name, self.attrs[attr_name])
|
|
|
|
|
|
|
|
|
|
self.desc.check_attrs()
|
|
|
|
|
no_kernel_op_set = {
|
|
|
|
|
'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
|
|
|
|
@ -1020,7 +1023,7 @@ class Program(object):
|
|
|
|
|
self.current_block_idx = 0
|
|
|
|
|
self._seed = 0
|
|
|
|
|
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
self._op_role_var = ""
|
|
|
|
|
self._op_role_var = []
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def op_role(self):
|
|
|
|
@ -1036,15 +1039,15 @@ class Program(object):
|
|
|
|
|
|
|
|
|
|
@op_role_var.setter
|
|
|
|
|
def set_op_role_var(self, var_name):
|
|
|
|
|
self._op_role_var = var_name
|
|
|
|
|
self._op_role_var = [var_name]
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def optimized_guard(self, var):
|
|
|
|
|
OpRole = core.op_proto_and_checker_maker.OpRole
|
|
|
|
|
self._current_role = OpRole.Optimize
|
|
|
|
|
self._op_role_var = var.name if isinstance(var, Variable) else var
|
|
|
|
|
self._op_role_var = [var.name if isinstance(var, Variable) else var]
|
|
|
|
|
yield
|
|
|
|
|
self._op_role_var = ""
|
|
|
|
|
self._op_role_var = []
|
|
|
|
|
self._current_role = OpRole.Forward
|
|
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
|