fix sparse_softmax_cross_entropy_with_logits

pull/11593/head
jjfeing 5 years ago
parent 54b8d53780
commit 2adff83c99

@@ -123,7 +123,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
   std::vector<size_t> labels_shape = AnfAlgo::GetOutputInferShape(one_hot_node, 0);
   std::vector<size_t> loss_shape;
-  if (labels_shape.size() > 0) {
+  if (!labels_shape.empty()) {
     loss_shape.emplace_back(labels_shape[0]);
   } else {
     MS_LOG(EXCEPTION) << "one_hot output's shape is empty.";
@@ -320,7 +320,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
   if (labels_shape.size() != 1) {
     MS_LOG(EXCEPTION) << "label's shape should be 1-D.";
   }
-  float y_value = static_cast<float>(labels_shape[0]);
+  auto y_value = static_cast<float>(labels_shape[0]);
   auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
   auto y_node = CreateValueNode(y, kNumberTypeFloat32);
   MS_EXCEPTION_IF_NULL(y_node);
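Annotation (not part of the commit): CreateRealDiv builds a RealDiv node whose divisor y_value is the batch size labels_shape[0]; dividing the incoming gradient by N is what turns a summed per-sample gradient into the mean-reduced gradient. A minimal standalone sketch of that arithmetic, in plain C++ rather than the MindSpore graph API:

#include <cstddef>
#include <vector>

// Dividing each element of the incoming gradient by the batch size N
// converts the gradient of a summed loss into the gradient of the mean
// loss, which is the role the RealDiv node plays with
// y_value = labels_shape[0].
std::vector<float> MeanReduceGrad(const std::vector<float> &grad, std::size_t n) {
  std::vector<float> out(grad.size());
  for (std::size_t i = 0; i < grad.size(); ++i) {
    out[i] = grad[i] / static_cast<float>(n);
  }
  return out;
}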
@@ -436,10 +436,11 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern(
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
   VarPtr x3 = std::make_shared<Var>();
-  VarPtr x4 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
   VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
-  return VectorRef({prim::kPrimMul, depend, x4});
+  VectorRef depend(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
+  return VectorRef({prim::kPrimMul, depend, x3});
 }
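Annotation (not part of the commit): the old pattern tied the Depend to an arbitrary input x3 and matched only one SparseSoftmaxCrossEntropyWithLogits node; the new pattern requires the Depend to connect the grad instance of the op to its forward instance, so a single match captures both nodes and Process can rewrite them together. Schematically:

// old:  Mul(Depend(SparseSoftmaxCrossEntropyWithLogits(x1, x2), x3), x4)
// new:  Mul(Depend(SparseSoftmaxCrossEntropyWithLogits_grad(x1, x2),
//                  SparseSoftmaxCrossEntropyWithLogits(x1, x2)),
//           x3)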
 const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
@@ -455,6 +456,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   }
   auto depend_node = GetDependNode(mul_node);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
   auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
   if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
     MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
@@ -467,6 +469,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
   auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
   CNodePtr real_div_node;
   if (tile_node == nullptr) {
@@ -484,16 +487,22 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
   manager->Replace(mul_node, new_mul_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
+                                    NewValueNode(MakeValue<bool>(true)), NewValueNode(MakeValue<bool>(true))};
+  auto new_depend = graph->NewCNode(inputs);
+  manager->Replace(sparse_softmax_node_grad, new_depend);
   return new_mul_node;
 }

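Annotation (not part of the commit): as far as these hunks show, the rewritten Process now replaces the forward node with ReduceMean over the per-sample losses, replaces the original Mul with the freshly built new_mul_node (whose construction falls outside the excerpted context), and finally appears to swap the matched grad node for a placeholder Depend built from two Bool value nodes, detaching it from the graph. The per-sample backprop output involved, softmax_node_outputs[1], is the standard sparse softmax cross-entropy gradient; a standalone numeric sketch in plain C++, illustrative only:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-sample gradient of sparse softmax cross-entropy w.r.t. the logits:
// (softmax(logits) - one_hot(label)) * dout, computed with the usual
// max-subtraction trick for numerical stability.
std::vector<float> SparseSoftmaxCeGrad(const std::vector<float> &logits, std::size_t label, float dout) {
  float max_logit = *std::max_element(logits.begin(), logits.end());
  float denom = 0.0f;
  for (float v : logits) denom += std::exp(v - max_logit);
  std::vector<float> grad(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    float p = std::exp(logits[i] - max_logit) / denom;  // softmax probability
    grad[i] = (p - (i == label ? 1.0f : 0.0f)) * dout;
  }
  return grad;
}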
 const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
   VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
   VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  return VectorRef({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
+  return VectorRef(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
 }

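Annotation (not part of the commit): the V2 pattern gets the same rearrangement as V1, minus the trailing Mul; it covers graphs where the Depend output is consumed directly instead of being scaled by a further multiply. Note that x3 stays declared but no longer appears in the returned pattern. Schematically:

// old:  Depend(SparseSoftmaxCrossEntropyWithLogits(x1, x2), x3)
// new:  Depend(SparseSoftmaxCrossEntropyWithLogits_grad(x1, x2),
//              SparseSoftmaxCrossEntropyWithLogits(x1, x2))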
 const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(const FuncGraphPtr &graph,
@@ -504,6 +513,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
   auto depend_node = node->cast<CNodePtr>();
   auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
   CNodePtr softmax_node;
   auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
@@ -511,11 +521,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
   auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
   return mul_node;
 }
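Annotation (not part of the commit): as in the V1 pass, the forward node is replaced by ReduceMean over the per-sample SoftmaxCrossEntropyWithLogits losses, which reproduces the mean-reduced scalar loss the fused op computed. A standalone numeric sketch of that equivalence, in plain C++ and illustrative only:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Mean sparse softmax cross-entropy over a batch: ReduceMean of the
// per-sample losses -log(softmax(logits_i)[label_i]), matching what the
// ReduceMean-over-softmax-CE rewrite computes.
float MeanSparseSoftmaxCeLoss(const std::vector<std::vector<float>> &logits,
                              const std::vector<std::size_t> &labels) {
  float total = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    const auto &row = logits[i];
    float max_logit = *std::max_element(row.begin(), row.end());
    float denom = 0.0f;
    for (float v : row) denom += std::exp(v - max_logit);
    total += std::log(denom) - (row[labels[i]] - max_logit);  // -log softmax(row)[label]
  }
  return total / static_cast<float>(logits.size());
}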

@@ -585,9 +585,9 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
-    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
     unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
     unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
   } else {
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
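Annotation (not part of the commit): in the graph-mode pass list, SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR moves from before the two grad passes to after them. The grad patterns above match the original fused node pair through the Depend, so they presumably must run before the forward-only pass rewrites those nodes out from under them. The resulting order:

// graph-mode UnifyMindIR pass order after this change:
//   DropoutGradUnifyMindIR
//   DropoutUnifyMindIR
//   GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR
//   GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2
//   SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR  // now runs after the grad passes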
