|
|
|
@ -34,7 +34,14 @@ registerd_op = {
|
|
|
|
|
"relu_grad": "ReluGradParser",
|
|
|
|
|
"softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser",
|
|
|
|
|
"truncated_gaussian_random": "TruncatedNormalParser",
|
|
|
|
|
"sgd": "SGDParser"
|
|
|
|
|
"sgd": "SGDParser",
|
|
|
|
|
"c_allgather": "AllGatherParser",
|
|
|
|
|
"c_allreduce_sum": "AllReduceSumParser",
|
|
|
|
|
"c_allreduce_max": "AllReduceMaxParser",
|
|
|
|
|
"c_broadcast": "BroadcastParser",
|
|
|
|
|
"c_reduce_scatter": "ReduceScatterParser",
|
|
|
|
|
"c_send": "SendParser",
|
|
|
|
|
"c_receive": "ReceiveParser"
|
|
|
|
|
}
|
|
|
|
|
global_cnt = -1
|
|
|
|
|
global_input_cnt = -1
|
|
|
|
@ -522,6 +529,135 @@ class TruncatedNormalParser(AscendParserBase):
|
|
|
|
|
)
|
|
|
|
|
return [truncated_normal], [[0]] #[assign]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllGatherParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(AllGatherParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "c_allgather"
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
rank_size = self.op.attr("rank_size")
|
|
|
|
|
group = self.op.attr("group")
|
|
|
|
|
|
|
|
|
|
allgather = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"allgather" + self._accumulated_op_id(), "HcomAllGather").set_input(
|
|
|
|
|
"x", x).set_attr_int32(
|
|
|
|
|
"rank_size", rank_size).set_attr_string("group", group)
|
|
|
|
|
return [allgather], [[0]]
|
|
|
|
|
|
|
|
|
|
class AllReduceParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop, reduction):
|
|
|
|
|
super(AllReduceParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "c_allreduce_" + reduction
|
|
|
|
|
self.reduction = reduction
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
reduction = self.reduction
|
|
|
|
|
group = "hccl_world_group" #self.op.attr("group")
|
|
|
|
|
fusion = None #self.op.attr("fusion")
|
|
|
|
|
fusion_id = None #self.op.attr("fusion_id")
|
|
|
|
|
|
|
|
|
|
allreduce = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"allreduce" + self._accumulated_op_id(), "HcomAllReduce").set_input(
|
|
|
|
|
"x", x).set_attr_string(
|
|
|
|
|
"reduction", reduction).set_attr_string("group", group)
|
|
|
|
|
if fusion is not None:
|
|
|
|
|
allreduce.set_attr_int32("fusion", fusion)
|
|
|
|
|
|
|
|
|
|
if fusion_id is not None:
|
|
|
|
|
allreduce.set_attr_int32("fusion_id", fusion_id)
|
|
|
|
|
return [allreduce], [[0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllReduceSumParser(AllReduceParser):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(AllReduceSumParser, self).__init__(graph, var2geop, 'sum')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllReduceMaxParser(AllReduceParser):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(AllReduceMaxParser, self).__init__(graph, var2geop, 'max')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BroadcastParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(BroadcastParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "c_broadcast"
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
root_rank = self.op.attr("root_rank")
|
|
|
|
|
group = self.op.attr("group")
|
|
|
|
|
|
|
|
|
|
broadcast = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"broadcast" + self._accumulated_op_id(), "HcomBroadcast").set_input(
|
|
|
|
|
"x", x).set_attr_int32(
|
|
|
|
|
"root_rank", root_rank).set_attr_string("group", group)
|
|
|
|
|
return [broadcast], [[0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReduceScatterParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(ReduceScatterParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "c_reduce_scatter"
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
reduction = self.op.attr("reduction")
|
|
|
|
|
group = self.op.attr("group")
|
|
|
|
|
rank_size = self.op.attr("rank_size")
|
|
|
|
|
|
|
|
|
|
reduce_scatter = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"reducescatter" + self._accumulated_op_id(), "HcomReduceScatter").set_input(
|
|
|
|
|
"x", x).set_attr_string(
|
|
|
|
|
"reduction", reduction).set_attr_string(
|
|
|
|
|
"group", group).set_attr_int32("rank_size", rank_size)
|
|
|
|
|
return [reduce_scatter], [[0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SendParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(SendParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "c_send"
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
sr_tag = self.op.attr("sr_tag")
|
|
|
|
|
dest_rank = self.op.attr("dest_rank")
|
|
|
|
|
group = self.op.attr("group")
|
|
|
|
|
|
|
|
|
|
send = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"send" + self._accumulated_op_id(), "HcomSend").set_input(
|
|
|
|
|
"x", x).set_attr_int32(
|
|
|
|
|
"sr_tag", sr_tag).set_attr_int32(
|
|
|
|
|
"dest_rank", dest_rank).set_attr_string("group", group)
|
|
|
|
|
return [send], [[0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReceiveParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(ReceiveParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "c_receive"
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
sr_tag = self.op.attr("sr_tag")
|
|
|
|
|
src_rank = self.op.attr("src_rank")
|
|
|
|
|
group = self.op.attr("group")
|
|
|
|
|
shape = self.op.attr("shape")
|
|
|
|
|
dtype = self.op.attr("dtype")
|
|
|
|
|
|
|
|
|
|
receive = core.GEOperatorFactory.create_operator(
|
|
|
|
|
"receive" + self._accumulated_op_id(), "HcomReceive").set_input(
|
|
|
|
|
"x", x).set_attr_int32(
|
|
|
|
|
"sr_tag", sr_tag).set_attr_int32(
|
|
|
|
|
"src_rank", src_rank).set_attr_string(
|
|
|
|
|
"group", group).set_attr_vec_int32(
|
|
|
|
|
"shape", shape).set_attr_int32("dtype", dtype)
|
|
|
|
|
return [receive], [[0]]
|
|
|
|
|
|
|
|
|
|
class ScaleParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(ScaleParser, self).__init__(graph, var2geop)
|
|
|
|
|