@@ -50,7 +50,8 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
   return new_node;
 }
 
-CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) {
+CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
+                      bool is_convert_const_to_attr = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
 
@@ -80,7 +81,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
   one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
 
   std::vector<AnfNodePtr> one_hot_inputs;
-  if (is_pynative) {
+  if (is_convert_const_to_attr) {
     one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
   } else {
     auto depth_node = NewValueNode(depth);
@@ -97,7 +98,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
   std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
   labels_shape.emplace_back(depth);
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get());
-  if (is_pynative) {
+  if (is_convert_const_to_attr) {
     AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node);
   }
   return one_hot_node;
@@ -252,7 +253,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea
 }
 
 CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
-                    bool is_pynative = false) {
+                    bool is_convert_const_to_attr = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(mul_node);
@@ -268,6 +269,9 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   std::vector<int64_t> multiple_value;
   std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value),
                  [](size_t label) { return static_cast<int64_t>(label); });
+  if (std::all_of(multiple_value.begin(), multiple_value.end(), [](int64_t value) { return value == 1; })) {
+    return nullptr;
+  }
   auto multiples = MakeValue(multiple_value);
   auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64);
   MS_EXCEPTION_IF_NULL(multiples_node);
@@ -279,7 +283,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
 
   std::vector<AnfNodePtr> tile_inputs;
-  if (is_pynative) {
+  if (is_convert_const_to_attr) {
     tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)};
   } else {
     auto kernel_graph = graph->cast<KernelGraphPtr>();
@@ -292,7 +296,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   tile_node->set_scope(mul_node->scope());
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},
                                       tile_node.get());
-  if (is_pynative) {
+  if (is_convert_const_to_attr) {
     AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
   }
   // feature map set
@@ -302,7 +306,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   return tile_node;
 }
 
-CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &tile_node) {
+CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(tile_node);
@@ -464,16 +468,24 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
   auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
-  auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  CNodePtr real_div_node;
+  if (tile_node == nullptr) {
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2));
+  } else {
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  }
   auto expand_dims_node = CreateExpandDims(graph, real_div_node);
-
-  mul_node->set_input(1, softmax_node_outputs[1]);
-  mul_node->set_input(2, expand_dims_node);
+  std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
+                                            softmax_node_outputs[1], expand_dims_node};
+  auto new_mul_node = graph->NewCNode(new_mul_inputs);
+  MS_EXCEPTION_IF_NULL(new_mul_node);
+  new_mul_node->set_scope(mul_node->scope());
+  new_mul_node->set_abstract(mul_node->abstract());
 
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
-  return mul_node;
+  return new_mul_node;
 }
 
 const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
@@ -563,19 +575,26 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
   }
 
   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
   softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
-  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true);
-  auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
+  CNodePtr real_div_node;
+  if (tile_node == nullptr) {
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2));
+  } else {
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  }
   auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
-
-  mul_node->set_input(1, softmax_node_outputs[1]);
-  mul_node->set_input(2, expand_dims_node);
-
-  return mul_node;
+  std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
+                                            softmax_node_outputs[1], expand_dims_node};
+  auto new_mul_node = graph->NewCNode(new_mul_inputs);
+  MS_EXCEPTION_IF_NULL(new_mul_node);
+  new_mul_node->set_scope(mul_node->scope());
+  new_mul_node->set_abstract(mul_node->abstract());
+  return new_mul_node;
 }
 }  // namespace opt
 }  // namespace mindspore