!13316 【GraphKernel】Normalize the Reduce nodes' axis

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @gaoxiong1
pull/13316/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8e1d582221

@ -27,8 +27,9 @@ class ReduceMean(Expander):
axis = self.attrs['axis']
keep_dims = self.attrs['keep_dims']
# cal reduce_mean, when axis is None, reduce all axes.
if not axis:
if not isinstance(axis, (tuple, list)):
axis = (axis,)
elif not axis:
axis = list(range(len(x.shape)))
reduce_size = 1.0
for idx in axis:

@ -0,0 +1,93 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/axis_normalizer.h"
#include <algorithm>
#include <vector>
#include "ir/scalar.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
bool AxisNormalizer::IsReduce(const AnfNodePtr &node) {
std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
return std::any_of(node_with_axis.begin(), node_with_axis.end(),
[&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
}
bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (!IsReduce(node)) {
continue;
}
if (auto primitive = GetCNodePrimitive(node); primitive != nullptr && primitive->HasAttr(kAttrAxis)) {
auto axis = primitive->GetAttr(kAttrAxis);
size_t rank = AnfAlgo::GetInputDeviceShape(node, 0).size();
if (rank == 0) {
// scalar tensor
rank = 1;
}
bool diff = false;
ShapeVector axis_vec;
if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
int64_t v1 = GetValue<int64_t>(axis);
int64_t v2 = NormAxis(v1, rank);
axis_vec.push_back(v2);
diff = diff || (v1 != v2);
} else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) {
auto vec = axis->isa<ValueList>() ? axis->cast<ValueListPtr>()->value() : axis->cast<ValueTuplePtr>()->value();
if (vec.empty()) {
diff = true;
for (size_t i = 0; i < rank; i++) {
axis_vec.push_back(i);
}
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
for (auto v : vec) {
int64_t v1 = GetValue<int64_t>(v);
int64_t v2 = NormAxis(v1, rank);
axis_vec.push_back(v2);
diff = diff || (v1 != v2);
}
}
}
if (diff) {
changed = true;
SetNodeAttrSafely(kAttrAxis, MakeValue(axis_vec), node);
}
}
}
return changed;
}
bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
changed = Process(sub_func_graph) || changed;
}
}
return changed;
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_AXIS_NORMALIZER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_AXIS_NORMALIZER_H_
#include "ir/func_graph.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
// change Reduce nodes' axis to non-negative value
class AxisNormalizer : public Pass {
public:
AxisNormalizer() : Pass("axis_normalizer") {}
~AxisNormalizer() = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const FuncGraphPtr &func_graph);
int64_t NormAxis(int64_t x, size_t rank);
bool IsReduce(const AnfNodePtr &node);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_AXIS_NORMALIZER_H_

@ -41,6 +41,7 @@
#include "backend/optimizer/graph_kernel/split_assign.h"
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/axis_normalizer.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
@ -78,6 +79,9 @@ PassManagerPtr GraphKernelOptimizer::Cluster() {
PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() {
auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1");
// normalize the Reduce axis
pm->AddPass(std::make_shared<AxisNormalizer>());
// Replace Assign with InplaceAssign, and replace original output with overridden parameters
pm->AddPass(std::make_shared<OptimizeAssign>());
pm->AddPass(std::make_shared<EliminateRedundantOutput>());

Loading…
Cancel
Save