@@ -18,11 +18,17 @@ from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
from . import ascend_parser
from paddle.distributed import fleet
import hccl.manage.api as hccl
from collections import namedtuple

HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
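# e.g. a two-card run might record HcomGroupConfig(name="hcom_group_0", nranks=2, rank_ids=[0, 1])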

class AscendIRParser(object):
    def __init__(self):
        self.graph_idx = 0
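        # maps a nccl_id var name to the ordered endpoint list collected in parse_op()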
        self.hcom_endpoints = {}
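        # HcomGroupConfig entries queued by parse_op(); minimize() later passes them to hccl.create_group()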
        self.groups_to_create = []

    def _construct_input_map(self, input_varlist):
        ret_map = {}
@@ -38,8 +44,37 @@ class AscendIRParser(object):
            ret_map[var.name] = ge_input
        return ge_in_operator, ret_map

    def _endpoint_to_world_rank_id(self, endpoint):
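        # maps an "ip:port" endpoint to its index in the global worker list; e.g. with
        # illustrative endpoints ["192.168.1.1:6071", "192.168.1.2:6071"], the second
        # address maps to world rank 1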
        world_endpoints = fleet.worker_endpoints()
        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s)" % (endpoint, world_endpoints)
        return world_endpoints.index(endpoint)

    def parse_op(self, op):
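        # communication-setup ops are intercepted to record HCCL group info;
        # every other registered op is delegated to its dedicated parser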
        if op.type == 'c_gen_nccl_id':
            endpoint = op.attr("endpoint")
            other_endpoints = op.attr("other_endpoints")
            rank = op.attr("rank")

            nccl_id = op.output_arg_names[0]

            # c_gen_nccl_id operator splits endpoints into local endpoint and other_endpoints
            # we should combine these together to produce world_rank_ids
            self.hcom_endpoints[nccl_id] = other_endpoints[:]
            self.hcom_endpoints[nccl_id].insert(rank, endpoint)
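            # e.g. with illustrative values rank == 1, endpoint == "192.168.1.2:6071",
            # other_endpoints == ["192.168.1.1:6071"], this rebuilds the ordered list
            # ["192.168.1.1:6071", "192.168.1.2:6071"]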

            print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id]))
        elif op.type == 'c_comm_init':
            nccl_id = op.input_arg_names[0]
            nranks = op.attr("nranks")
            assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count"
            rank = op.attr("rank")
            ring_id = op.attr("ring_id")

            group_name = "hcom_group_" + str(ring_id)
            global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]]
            self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids))
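            # only recorded here; hccl.create_group() runs later in AscendOptimizer.minimize()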
            print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids))
        elif op.type in ascend_parser.registerd_op:
            print("Op[%s] has been registered, begin to parse it" % (op.type))
            op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type])
            op_parser.apply(op)
@@ -137,6 +172,8 @@ class AscendOptimizer(Optimizer):
                 parameter_list=None,
                 no_grad_set=None,
                 auto_dp=False):
        minimized = None
        if self.inner_opt:
            minimized = self.inner_opt.minimize(loss, startup_program=startup_program)
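        # the wrapped optimizer builds the program (e.g. backward ops) that the parser below consumes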

        self.ascend_instance = core.AscendInstance()
@@ -172,6 +209,10 @@ class AscendOptimizer(Optimizer):
        startup_graph, main_graph = self.parser.parse_program(
            startup_program, main_block.program, input_varlist, self.fetch_list)

        for cfg in self.parser.groups_to_create:
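            # each cfg was recorded by AscendIRParser.parse_op() from a c_comm_init op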
            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
            print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids))

        self.ascend_instance.add_ascend_subgraph(0, startup_graph)
        self.ascend_instance.add_ascend_subgraph(1, main_graph)