!11607 【GraphKernel】Raise akg ReduceSum precision

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
pull/11607/MERGE
mindspore-ci-bot committed via Gitee, 4 years ago
commit cd22f43019

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
#include <vector>
#include <string>
#include <algorithm>
#include <memory>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/kernel_info.h"
namespace mindspore {
namespace opt {
// Returns true iff `node` is a ReduceSum CNode whose first input's device
// data type is float16 — i.e. a candidate for precision raising.
bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
    return false;
  }
  return AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16;
}
// Builds a Cast CNode that converts `input` to `dst_type`, keeping the shape
// of `input` and using the given device `format`. The "dst_type" attribute is
// set so the kernel compiler knows the target type.
AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, std::string format) {
  auto graph = input->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtrList cast_inputs{NewValueNode(prim::kPrimCast->Clone()), input};
  auto cast = CreateCNode(cast_inputs, graph, {.format = format, .shape = GetShape(input), .type = dst_type});
  auto type_str = kernel::TypeId2String(dst_type->type_id());
  AnfAlgo::SetNodeAttr("dst_type", MakeValue(type_str), cast);
  return cast;
}
// Rewrites the existing ReduceSum `node` IN PLACE (despite the "Create" name):
// its data input is replaced by `input` (the fp32 cast) and both its abstract
// and kernel build info are switched to float32. Returns the same node.
AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) {
  auto reduce_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(reduce_cnode);
  reduce_cnode->set_input(1, input);
  // The output abstract becomes fp32; the shape is unchanged.
  reduce_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, GetShape(node)));
  // Rebuild the kernel selection info with fp32 input/output device types.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({AnfAlgo::GetInputFormat(node, 0)});
  builder.SetInputsDeviceType({kFloat32->type_id()});
  builder.SetOutputsFormat({AnfAlgo::GetOutputFormat(node, 0)});
  builder.SetOutputsDeviceType({kFloat32->type_id()});
  builder.SetProcessor(AnfAlgo::GetProcessor(node));
  builder.SetKernelType(KernelType::AKG_KERNEL);
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), reduce_cnode.get());
  return node;
}
// Redirects every user of `reduce_node` (now fp32) to `cast_node` (fp16).
// Special case: a user that is already a Cast producing fp32 is redundant
// after the precision raise, so it is replaced by the reduce node itself.
void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) {
  auto mng = reduce_node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(mng);
  // Iterate over a snapshot: `mng->Replace` below mutates the user map.
  auto users_snapshot = mng->node_users()[reduce_node];
  for (const auto &[user_node, input_index] : users_snapshot) {
    bool is_fp32_cast_user = IsPrimitiveCNode(user_node, prim::kPrimCast) &&
                             AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32;
    if (is_fp32_cast_user) {
      // The user's cast-to-fp32 is now a no-op; splice it out.
      mng->Replace(user_node, reduce_node);
      continue;
    }
    if (user_node->isa<CNode>()) {
      user_node->cast<CNodePtr>()->set_input(input_index, cast_node);
    }
  }
}
// Scans one (sub-)graph for fp16 ReduceSum nodes and raises each one to fp32:
//   input --Cast(fp32)--> ReduceSum(fp32) --Cast(fp16)--> original users.
// Returns true when at least one node was rewritten.
bool RaiseReductionPrecision::Process(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  auto todos = TopoSort(func_graph->get_return());
  bool changed = false;
  // Take nodes by const reference: AnfNodePtr is a shared_ptr and copying it
  // per iteration costs an atomic refcount bump for no benefit.
  for (const auto &node : todos) {
    if (IsFp16ReduceSum(node)) {
      auto cast1 = CreateCast(node->cast<CNodePtr>()->input(1), kFloat32, AnfAlgo::GetInputFormat(node, 0));
      auto new_reduce = CreateReduceSum(node, cast1);
      auto cast2 = CreateCast(new_reduce, kFloat16, AnfAlgo::GetOutputFormat(node, 0));
      ReplaceNode(node, cast2);
      changed = true;
    }
  }
  if (changed) {
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }
  return changed;
}
// Pass entry point: walks the main graph and applies Process() to every
// GraphKernel sub-graph. Returns true if any sub-graph was changed.
bool RaiseReductionPrecision::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  auto todos = TopoSort(func_graph->get_return());
  bool changed = false;
  for (const auto &node : todos) {
    if (!AnfAlgo::IsGraphKernel(node)) {
      continue;
    }
    auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
    // NOTE: returns false immediately when the sub-graph pointer is null.
    MS_ERROR_IF_NULL(sub_func_graph);
    if (Process(sub_func_graph)) {
      changed = true;
    }
  }
  if (changed) {
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }
  return changed;
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_RAISE_REDUCTION_PRECISION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_RAISE_REDUCTION_PRECISION_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class RaiseReductionPrecision : public Pass {
public:
RaiseReductionPrecision() : Pass("raise_reduction_precision") {}
~RaiseReductionPrecision() override = default;
bool Run(const FuncGraphPtr &func_graph);
private:
bool IsFp16ReduceSum(const AnfNodePtr &node);
bool Process(const FuncGraphPtr &func_graph);
AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, std::string format);
AnfNodePtr CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input);
void ReplaceNode(const AnfNodePtr &src_node, const AnfNodePtr &dst_node);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_RAISE_REDUCTION_PRECISION_H_

@ -47,6 +47,7 @@
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/value_graph_binder.h"
@ -182,6 +183,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));

Loading…
Cancel
Save