Fix runtime error when there is a Depend or ControlDepend on BatchNorm

pull/5623/head
huanghui 5 years ago
parent 8e442ce7ca
commit b8e737f66a

@ -79,7 +79,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
: AnfAlgo::GetOutputInferShape(input_node, insert_index);
bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
: trans::IsNeedPadding(input_format, input_node_out_shape.size());
if (!need_padding) {
// don't need padding insert transdata only
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());

@ -121,7 +121,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
if (!NeedFusion(graph, node, &batchnorm)) {
return nullptr;
}
return CreateBNInfer(graph, batchnorm, node);
auto bn_infer = CreateBNInfer(graph, batchnorm, node);
TransferDepend(batchnorm, graph, bn_infer);
return bn_infer;
}
} // namespace opt
} // namespace mindspore

@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad
return true;
}
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) {
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto tuple_getitem = node->cast<CNodePtr>();
@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
return false;
}
AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(batchnormgrad_anf);
MS_EXCEPTION_IF_NULL(batchnormgrad);
*batchnormgrad = batchnormgrad_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(*batchnormgrad);
return CheckBatchNormGrad(graph, *batchnormgrad);
AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(batchnorm_grad_anf);
MS_EXCEPTION_IF_NULL(batchnorm_grad);
*batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(*batchnorm_grad);
return CheckBatchNormGrad(graph, *batchnorm_grad);
}
} // namespace
@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
CNodePtr batchnormgrad = nullptr;
if (!NeedFusion(graph, node, &batchnormgrad)) {
CNodePtr batchnorm_grad = nullptr;
if (!NeedFusion(graph, node, &batchnorm_grad)) {
return nullptr;
}
return CreateBNInferGrad(graph, batchnormgrad, node);
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
TransferDepend(batchnorm_grad, graph, bn_infer_grad);
return bn_infer_grad;
}
} // namespace opt
} // namespace mindspore

@ -872,5 +872,26 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
return new_value_node;
}
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// find BatchNorm's output which is a Depend or ControlDepend
for (const auto &node_index : manager->node_users()[old_node]) {
AnfNodePtr output = node_index.first;
size_t index = IntToSize(node_index.second);
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
auto control_depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(control_depend);
control_depend->set_input(index, new_node);
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
auto depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend);
depend->set_input(index, new_node);
}
}
}
} // namespace opt
} // namespace mindspore

@ -203,6 +203,9 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
// Create a new value node of func graph,not kernel graph
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
// Transfer depend or control_depend to the new node
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_

@ -27,7 +27,6 @@
namespace mindspore {
namespace opt {
namespace {
static std::vector<size_t> g_output_idx;
bool HasAtomic(const AnfNodePtr &input) {

@ -98,7 +98,7 @@ void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
}
}
bool CheckSegments(size_t segments, size_t communication_op_node_size, std::vector<size_t> *segment_index) {
bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
MS_EXCEPTION_IF_NULL(segment_index);
if (segments >= communication_op_node_size) {
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments

@ -24,7 +24,7 @@ namespace opt {
class ConstToAttrStridedSliceGradPass : public PatternProcessPass {
public:
explicit ConstToAttrStridedSliceGradPass(bool multigraph = true)
: PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {}
: PatternProcessPass("const_to_attr_strided_slice_grad", multigraph) {}
~ConstToAttrStridedSliceGradPass() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

Loading…
Cancel
Save