|
|
|
@ -21,9 +21,30 @@
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "ir/primitive.h"
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "pre_activate/common/helper.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sub_anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fusion_node);
|
|
|
|
|
auto sub = sub_anf->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sub);
|
|
|
|
|
if (sub->size() != kSubInputNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Sub's size is not equal with 3";
|
|
|
|
|
}
|
|
|
|
|
auto reduce_sum_anf = sub->input(2);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sum_anf);
|
|
|
|
|
auto reduce_sum = reduce_sum_anf->cast<CNodePtr>();
|
|
|
|
|
if (reduce_sum == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Sub's second input is not a cnode";
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
|
|
|
|
|
AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
|
|
|
|
|
return VectorRef(
|
|
|
|
|
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
|
|
|
|
@ -48,6 +69,7 @@ const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, co
|
|
|
|
|
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get());
|
|
|
|
|
confusion_softmax_grad->set_scope(node->scope());
|
|
|
|
|
SetAttrsForFusionNode(node, confusion_softmax_grad);
|
|
|
|
|
return confusion_softmax_grad;
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|