!13316 【GraphKernel】Normalize the Reduce nodes' axis
From: @dayschan Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @gaoxiong1pull/13316/MERGE
commit
8e1d582221
@ -0,0 +1,93 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/axis_normalizer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "ir/scalar.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
|
||||
|
||||
bool AxisNormalizer::IsReduce(const AnfNodePtr &node) {
|
||||
std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
|
||||
return std::any_of(node_with_axis.begin(), node_with_axis.end(),
|
||||
[&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
|
||||
}
|
||||
|
||||
bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
|
||||
bool changed = false;
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (!IsReduce(node)) {
|
||||
continue;
|
||||
}
|
||||
if (auto primitive = GetCNodePrimitive(node); primitive != nullptr && primitive->HasAttr(kAttrAxis)) {
|
||||
auto axis = primitive->GetAttr(kAttrAxis);
|
||||
size_t rank = AnfAlgo::GetInputDeviceShape(node, 0).size();
|
||||
if (rank == 0) {
|
||||
// scalar tensor
|
||||
rank = 1;
|
||||
}
|
||||
bool diff = false;
|
||||
ShapeVector axis_vec;
|
||||
if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
|
||||
int64_t v1 = GetValue<int64_t>(axis);
|
||||
int64_t v2 = NormAxis(v1, rank);
|
||||
axis_vec.push_back(v2);
|
||||
diff = diff || (v1 != v2);
|
||||
} else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) {
|
||||
auto vec = axis->isa<ValueList>() ? axis->cast<ValueListPtr>()->value() : axis->cast<ValueTuplePtr>()->value();
|
||||
if (vec.empty()) {
|
||||
diff = true;
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
axis_vec.push_back(i);
|
||||
}
|
||||
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
|
||||
for (auto v : vec) {
|
||||
int64_t v1 = GetValue<int64_t>(v);
|
||||
int64_t v2 = NormAxis(v1, rank);
|
||||
axis_vec.push_back(v2);
|
||||
diff = diff || (v1 != v2);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (diff) {
|
||||
changed = true;
|
||||
SetNodeAttrSafely(kAttrAxis, MakeValue(axis_vec), node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool changed = false;
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
changed = Process(sub_func_graph) || changed;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_AXIS_NORMALIZER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_AXIS_NORMALIZER_H_
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// change Reduce nodes' axis to non-negative value
|
||||
class AxisNormalizer : public Pass {
|
||||
public:
|
||||
AxisNormalizer() : Pass("axis_normalizer") {}
|
||||
~AxisNormalizer() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool Process(const FuncGraphPtr &func_graph);
|
||||
int64_t NormAxis(int64_t x, size_t rank);
|
||||
bool IsReduce(const AnfNodePtr &node);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_AXIS_NORMALIZER_H_
|
Loading…
Reference in new issue