@@ -182,18 +182,34 @@ AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const
   return device_num_reciprocal_value;
 }
 
+AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) {
+  if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
+    AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
+    AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get());
+    AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
+    cast->set_scope(input->scope());
+    return cast;
+  }
+  return input;
+}
+
 AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                  const CNodePtr &sync_bn_cnode) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(allreduce_input);
   MS_EXCEPTION_IF_NULL(sync_bn_cnode);
 
+  // Cast the input to fp32; this reduces the number of cast nodes. The input of AllReduce comes from a
+  // BNTrainingReduce/BNTrainingUpdateGrad op, which only supports fp32 output, so when the inferred output is fp16
+  // the framework inserts casts: output_fp32->cast_fp16->allreduce&mul->cast_fp32. Adding this cast eliminates them.
+  // Should be removed once BNTrainingReduce/BNTrainingUpdateGrad support fp16 output.
+  AnfNodePtr input_node = InsertCast(graph, allreduce_input, kNumberTypeFloat32);
+
   // create AllReduce
-  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)),
-                                              allreduce_input};
+  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node};
   auto allreduce = graph->NewCNode(allreduce_inputs);
   MS_EXCEPTION_IF_NULL(allreduce);
-  allreduce->set_abstract(allreduce_input->abstract());
+  allreduce->set_abstract(input_node->abstract());
   allreduce->set_scope(allreduce_input->scope());
   AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce);
   AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce);
@@ -216,9 +232,12 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al
                                         device_num_reciprocal_vnode};
   auto mul = graph->NewCNode(mul_inputs);
   MS_EXCEPTION_IF_NULL(mul);
-  mul->set_abstract(allreduce_input->abstract());
+  mul->set_abstract(input_node->abstract());
   mul->set_scope(allreduce_input->scope());
-  return mul;
+
+  // Cast the output back to the original data type to reduce the number of cast nodes.
+  // Should be removed once BNTrainingReduce/BNTrainingUpdateGrad support fp16 output.
+  return InsertCast(graph, mul, AnfAlgo::GetOutputInferDataType(allreduce_input, 0));
 }
 
 const BaseRef BnSplit::DefinePattern() const {
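
A minimal, self-contained C++ sketch of what the AllReduce + Mul pair built above computes: AllReduce with kReduceOpSum sums the per-device contributions, and the Mul with the value node from CreateValueNodeOfDeviceNumReciprocal scales that sum by 1/device_num. This is an illustrative stand-in, not MindSpore code; the values and names are hypothetical.

#include <cstdio>

int main() {
  // Per-device partial results that would feed the AllReduce (hypothetical values).
  const float local_values[] = {4.0f, 6.0f, 5.0f, 9.0f};
  const int device_num = 4;

  // AllReduce with kReduceOpSum: every device ends up with the sum of all contributions.
  float reduced = 0.0f;
  for (float v : local_values) {
    reduced += v;
  }

  // Mul with the device-num reciprocal: the summed result is scaled by 1 / device_num,
  // averaging the per-device contributions.
  const float device_num_reciprocal = 1.0f / static_cast<float>(device_num);
  const float averaged = reduced * device_num_reciprocal;

  std::printf("sum = %f, averaged = %f\n", reduced, averaged);  // sum = 24, averaged = 6
  return 0;
}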