@@ -123,7 +123,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
   std::vector<size_t> labels_shape = AnfAlgo::GetOutputInferShape(one_hot_node, 0);
   std::vector<size_t> loss_shape;
-  if (labels_shape.size() > 0) {
+  if (!labels_shape.empty()) {
     loss_shape.emplace_back(labels_shape[0]);
   } else {
     MS_LOG(EXCEPTION) << "one_hot output's shape is empty.";
@@ -320,7 +320,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
   if (labels_shape.size() != 1) {
     MS_LOG(EXCEPTION) << "label's shape should be 1-D.";
   }
-  float y_value = static_cast<float>(labels_shape[0]);
+  auto y_value = static_cast<float>(labels_shape[0]);
   auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
   auto y_node = CreateValueNode(y, kNumberTypeFloat32);
   MS_EXCEPTION_IF_NULL(y_node);
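Context for the divisor above (a sketch under my reading, not the MindSpore API): the unified graph recovers the mean loss by dividing the summed per-sample loss by the batch size N taken from the 1-D label shape, materialized as a float32 scalar tensor.

#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the value fed to RealDiv: with labels of shape
// [N], the divisor is N as a float, e.g. N = 32 gives 32.0f.
float RealDivDenominator(const std::vector<std::size_t> &labels_shape) {
  if (labels_shape.size() != 1) {
    throw std::invalid_argument("label's shape should be 1-D.");
  }
  return static_cast<float>(labels_shape[0]);
}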
@@ -436,10 +436,11 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern(
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
   VarPtr x3 = std::make_shared<Var>();
-  VarPtr x4 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
   VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
-  return VectorRef({prim::kPrimMul, depend, x4});
+  VectorRef depend(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
+  return VectorRef({prim::kPrimMul, depend, x3});
 }
 
 const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
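Reading the new pattern back as the subgraph it matches (my sketch; x1, x2, x3 are the pattern variables):

  grad    = SparseSoftmaxCrossEntropyWithLogits(x1, x2)
  forward = SparseSoftmaxCrossEntropyWithLogits(x1, x2)
  depend  = Depend(grad, forward)
  root    = Mul(depend, x3)

The old pattern bound the forward call to a bare variable, so Process never captured it; matching both applications explicitly is what allows the rewrite below to replace the forward node as well.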
@@ -455,6 +456,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   }
 
   auto depend_node = GetDependNode(mul_node);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
   auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
   if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
     MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
@@ -467,6 +469,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
   auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
   CNodePtr real_div_node;
   if (tile_node == nullptr) {
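For intuition (my reading of the helpers, not authoritative): softmax_node_outputs[0] carries the per-sample loss and softmax_node_outputs[1] the per-sample gradient, so CreateReduceMean rebuilds the scalar mean loss for the forward side, while the Tile/RealDiv pair spreads the incoming scalar gradient over the batch divided by N for the backward side. A toy version of that backward scaling:

#include <cstddef>
#include <vector>

// Toy stand-in for the Tile + RealDiv combination: broadcast dout over the
// batch and divide by n, which is the gradient of a mean-reduced loss.
std::vector<float> ScaleGrad(float dout, std::size_t n) {
  return std::vector<float>(n, dout / static_cast<float>(n));
}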
@@ -484,16 +487,22 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
   manager->Replace(mul_node, new_mul_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
+                                    NewValueNode(MakeValue<bool>(true)), NewValueNode(MakeValue<bool>(true))};
+  auto new_depend = graph->NewCNode(inputs);
+  manager->Replace(sparse_softmax_node_grad, new_depend);
   return new_mul_node;
 }
 
 const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
-  VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
   VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  return VectorRef({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
+  return VectorRef(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
 }
 
 const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(const FuncGraphPtr &graph,
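Two things worth noting in the hunk above, as I read it. First, once the rewrite is done the grad SparseSoftmaxCrossEntropyWithLogits node is dead, and it is detached by swapping in a freshly built Depend over two boolean constants as a placeholder. Second, the V2 pattern at the end of the hunk differs from the V1 pattern only at the root: there is no trailing Mul, so the Depend itself is the anchor (same sketch notation as before):

  grad    = SparseSoftmaxCrossEntropyWithLogits(x1, x2)
  forward = SparseSoftmaxCrossEntropyWithLogits(x1, x2)
  root    = Depend(grad, forward)

This presumably covers graphs where the gradient is consumed without the cost-scaling Mul.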
@@ -504,6 +513,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
 
   auto depend_node = node->cast<CNodePtr>();
   auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
 
   CNodePtr softmax_node;
   auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
@@ -511,11 +521,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
   auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
   auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]);
 
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
   return mul_node;
 }
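Stepping back, the net effect of both passes (my summary of the excerpt, not an official description): a mean-reduced SparseSoftmaxCrossEntropyWithLogits is unified into OneHot + SoftmaxCrossEntropyWithLogits + ReduceMean, with the backward path served by the softmax op's second output. A self-contained numerical check that the two formulations agree, in plain C++ with no MindSpore dependency:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<std::vector<float>> logits = {{2.0f, 0.5f, -1.0f},
                                                  {0.0f, 1.0f, 3.0f}};
  const std::vector<int> labels = {0, 2};

  // Per-sample softmax cross entropy against one-hot labels reduces to the
  // negative log-softmax of the true class.
  float loss_sum = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    float max = logits[i][0];
    for (float v : logits[i]) max = std::fmax(max, v);
    float denom = 0.0f;
    for (float v : logits[i]) denom += std::exp(v - max);
    loss_sum += -(logits[i][labels[i]] - max - std::log(denom));
  }
  // Dividing the summed loss by N (see CreateRealDiv above) is exactly the
  // ReduceMean the pass inserts over the per-sample losses.
  std::printf("mean loss = %f\n", loss_sum / static_cast<float>(logits.size()));
  return 0;
}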