|
|
|
@ -20,58 +20,33 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
void AddOutputs(const AnfNodePtr &node, int64_t rank_size) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto origin_abstract = node->abstract();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_abstract);
|
|
|
|
|
auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
|
|
|
|
auto &origin_abstracts = tuple_abstract->elements();
|
|
|
|
|
AbstractBasePtrList abstract_list;
|
|
|
|
|
std::vector<TypeId> outputs_device_type;
|
|
|
|
|
std::vector<std::string> outputs_device_format;
|
|
|
|
|
for (int64_t i = 0; i < rank_size; ++i) {
|
|
|
|
|
for (size_t j = 0; j < origin_abstracts.size(); ++j) {
|
|
|
|
|
abstract_list.push_back(origin_abstracts[j]);
|
|
|
|
|
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j));
|
|
|
|
|
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Update abstract
|
|
|
|
|
auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
|
|
|
node->set_abstract(new_abstracts);
|
|
|
|
|
// Update kernel build info
|
|
|
|
|
auto builder =
|
|
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
|
|
|
|
builder->SetOutputsDeviceType(outputs_device_type);
|
|
|
|
|
builder->SetOutputsFormat(outputs_device_format);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Creates one Concat node per input of `node` and groups them in a MakeTuple.
//
// `new_tuple_getitems` is read at index `input + rank * inputs_size`, i.e. the
// slices are laid out rank by rank. For every input, the `rank_size` slices
// that belong to it are gathered as the inputs of a single Concat along axis 0.
//
// @param func_graph          Graph that receives the new nodes; must not be null.
// @param node                Node whose input tensor count fixes the number of Concat nodes.
// @param new_tuple_getitems  Per-rank output slices, indexed as described above.
// @param rank_size           Number of slices concatenated for each input.
// @return The MakeTuple node whose inputs are the created Concat nodes.
//
// NOTE(review): `new_tuple_getitems` is assumed to hold at least
// `inputs_size * rank_size` entries; no bounds check is performed here.
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                            const std::vector<AnfNodePtr> &new_tuple_getitems,
                                                            int64_t rank_size) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Seed the input list with the MakeTuple primitive, the same way every
  // Concat input list below is seeded with the Concat primitive.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  // rank_size is an int64_t, so convert it once with the 64-bit helper.
  const size_t rank_num = LongToSize(rank_size);
  for (size_t i = 0; i < inputs_size; ++i) {
    // Collect the slice of input i from every rank.
    std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    for (size_t j = 0, idx = i; j < rank_num; ++j, idx += inputs_size) {
      concat_inputs.push_back(new_tuple_getitems[idx]);
    }
    auto concat = func_graph->NewCNode(concat_inputs);
    MS_EXCEPTION_IF_NULL(concat);
    MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
    // The Concat output keeps the slice's data type; its first dimension grows
    // by a factor of rank_size because the slices are joined along axis 0.
    // NOTE(review): assumes the slice shape has at least one dimension.
    auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)};
    std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0);
    shape[0] *= rank_num;
    std::vector<std::vector<size_t>> shapes = {shape};
    AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
    std::vector<int64_t> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
    kernel_select_->SelectKernel(concat);
    make_tuple_inputs.push_back(concat);
  }

  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
|
|
|
|
@ -94,8 +69,11 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
|
|
|
|
|
if (fusion <= 0) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::HasNodeAttr("fused", cnode)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);
|
|
|
|
|
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
|
|
|
|
|
AddOutputs(node, rank_size);
|
|
|
|
|
std::vector<AnfNodePtr> new_outputs;
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
|
|
|
|
|
return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
|
|
|
|
|