|
|
|
@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
|
|
|
|
|
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
|
|
|
|
|
)
|
|
|
|
|
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
|
|
|
|
|
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
|
|
|
|
|
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
|
|
|
|
|
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
|
|
|
|
@ -1717,8 +1718,10 @@ to transpile() call.")
|
|
|
|
|
lr_ops = []
|
|
|
|
|
block = self.origin_program.global_block()
|
|
|
|
|
for op in block.ops:
|
|
|
|
|
if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) | int(
|
|
|
|
|
LR_SCHED_OP_ROLE_ATTR_VALUE) > 0:
|
|
|
|
|
role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
|
|
|
|
|
if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
|
|
|
|
|
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
|
|
|
|
|
int(OPT_OP_ROLE_ATTR_VALUE):
|
|
|
|
|
lr_ops.append(op)
|
|
|
|
|
log("append lr op: ", op.type)
|
|
|
|
|
return lr_ops
|
|
|
|
|