test=develop
release/1.1
Xin Pan 6 years ago
parent 38cf553108
commit 4cd44c00c5

@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
) )
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
@ -1717,8 +1718,10 @@ to transpile() call.")
lr_ops = [] lr_ops = []
block = self.origin_program.global_block() block = self.origin_program.global_block()
for op in block.ops: for op in block.ops:
if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) | int( role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
LR_SCHED_OP_ROLE_ATTR_VALUE) > 0: if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
int(OPT_OP_ROLE_ATTR_VALUE):
lr_ops.append(op) lr_ops.append(op)
log("append lr op: ", op.type) log("append lr op: ", op.type)
return lr_ops return lr_ops

Loading…
Cancel
Save