!387 auto mix precision

Merge pull request !387 from liubuyu/master
pull/387/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 7c06d292c8
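
For context: this PR moves kernel-info validation into TBE metadata parsing, lets PrecisionReduce break ties by datatype match count, registers non-default-format Cast kernels, and adds a parameter/trans-op fusion pass. On the user side, automatic mixed precision in this era was typically requested through Model's amp_level; a minimal sketch follows (the level names and helper choices here are assumptions for illustration, not part of this diff):

# Minimal sketch, assuming the 2020-era mindspore.train.Model API.
# amp_level="O2" runs most of the network in float16 while keeping
# normalization (and the loss) in float32; "O0" disables mixed precision.
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = nn.Dense(16, 10)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# With mixed precision on, the Ascend backend must pick a kernel per node
# whose datatypes best match the node -- the PrecisionReduce change below.
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")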

@@ -45,64 +45,6 @@ enum MatchCountPriority : int {
const size_t kMaxCount = 0xffffffff;
const int kUnSupportMixedDataTypeIndex = -1;
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
// the default format means all formats are supported
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(EXCEPTION) << "got the unknown format " << format;
}
if (format == kOpFormat_DEFAULT) {
return true;
}
// an empty shape means the tensor is a scalar
if (shape.empty()) {
return true;
}
if (shape.size() > kShapeSupportFormatMap.size()) {
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
return true;
}
return kShapeSupportFormatMap[shape.size() - 1].find(format) != kShapeSupportFormatMap[shape.size() - 1].end();
}
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
if (!IsShapeMatchFormat(shape, format)) {
return false;
}
for (auto shape_value : shape) {
if (shape_value == 0) {
MS_LOG(EXCEPTION) << "dimension size of the tensor shape should be a positive integer, but got " << shape_value;
}
}
return true;
};
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
}
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
return false;
}
}
return true;
}
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input data type
@@ -459,6 +401,29 @@ int PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
// raise precision
int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
kernel_support_datatype, kernel_match_datatype_idx);
if (selected_index != -1) {
// Tie-break among the surviving kernels: prefer the one whose supported
// datatypes match the node's datatypes in the most positions.
int max_match = 0;
auto iter = kernel_match_datatype_idx->begin();
while (iter != kernel_match_datatype_idx->end()) {
auto kernel_datatypes = kernel_support_datatype.find(iter->first);
if (kernel_datatypes == kernel_support_datatype.end()) {
MS_LOG(EXCEPTION) << "Cannot find kernel index " << iter->first << "'s datatype.";
}
if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) {
MS_LOG(EXCEPTION) << "Kernel datatype size is smaller than node datatype size!";
}
int match_count = 0;  // reset for every candidate kernel
for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) {
if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) {
++match_count;
}
}
if (match_count > max_match) {
max_match = match_count;  // remember the best score, not just the index
selected_index = SizeToInt(iter->first);
}
++iter;
}
}
if (selected_index == -1 && context_ptr->enable_reduce_precision()) {
selected_index =
RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
@@ -507,9 +472,6 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
kernel::KernelQuery(kernel_node, &kernel_info_list);
std::vector<int> most_match_counts = {-1, -1, -1, -1};
int selected_index = -1;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool auto_mixed_precision = context_ptr->auto_mixed_precision_flag();
std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx;
std::unordered_map<size_t, std::vector<TypeId>> kernel_support_datatype;
std::vector<int> node_mix_precision_datatype_index;
@@ -517,16 +479,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
auto kernel_build_info = *(kernel_info_list[info_index]);
if (!IsValidKernelInfo(kernel_node, kernel_build_info)) {
continue;
}
std::vector<int> support_indexes;
std::vector<TypeId> support_datatypes;
AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype,
&support_datatypes, &node_mix_precision_datatype_index);
kernel_match_datatype_idx[info_index] = support_indexes;
kernel_support_datatype[info_index] = support_datatypes;
-    if (!auto_mixed_precision && !MatchInferOutputDataType(kernel_node, kernel_build_info)) {
+    if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];

@@ -19,6 +19,7 @@
#include <unordered_map>
#include <memory>
#include <map>
#include <set>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
@@ -510,6 +511,64 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
return true;
}
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
// the default format means all formats are supported
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(EXCEPTION) << "Got the unknown format " << format;
}
if (format == kOpFormat_DEFAULT) {
return true;
}
// an empty shape means the tensor is a scalar
if (shape.empty()) {
return true;
}
if (shape.size() > kShapeSupportFormatMap.size()) {
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
return true;
}
return kShapeSupportFormatMap[shape.size() - 1].find(format) != kShapeSupportFormatMap[shape.size() - 1].end();
}
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
if (!IsShapeMatchFormat(shape, format)) {
return false;
}
for (auto shape_value : shape) {
if (shape_value == 0) {
MS_LOG(EXCEPTION) << "Dimension size of the tensor shape should be a positive integer, but got " << shape_value;
}
}
return true;
};
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
return false;
}
}
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
}
return true;
}
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
@@ -534,7 +593,7 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
if (context_ptr->execution_mode() == kPynativeMode) {
kernel_info_list->push_back(parse_info);
} else {
-      if (CheckSupported(kernel_node, parse_info)) {
+      if (IsValidKernelInfo(kernel_node, *(parse_info)) && CheckSupported(kernel_node, parse_info)) {
kernel_info_list->push_back(parse_info);
} else {
MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";

@@ -37,6 +37,7 @@
#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
#include "pre_activate/ascend/ir_fusion/transdata_split.h"
#include "pre_activate/ascend/ir_fission/topk_split.h"
@@ -243,6 +244,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto optimizer = std::make_shared<GraphOptimizer>();
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<BufferFusion>());
other_pm->AddPass(std::make_shared<GetitemTuple>());
other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());

@@ -0,0 +1,120 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "pre_activate/common/helper.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
std::vector<CNodePtr> *trans_road) {
if (node == nullptr) {
MS_LOG(ERROR) << "nullptr";
return nullptr;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto op_name = AnfAlgo::GetCNodeName(cnode);
auto manager = func_graph->manager();
if (manager == nullptr) {
return nullptr;
}
if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
auto users = manager->node_users()[node];
if (users.size() > 1 && !first_flag) {
return nullptr;
}
trans_road->push_back(cnode);
first_flag = false;
auto next_node = AnfAlgo::GetInputNode(cnode, 0);
if (next_node->isa<Parameter>() || next_node->isa<ValueNode>()) {
return next_node;
}
return ParamTransRoad(func_graph, next_node, first_flag, trans_road);
}
} else if (node->isa<Parameter>() || node->isa<ValueNode>()) {
return node;
}
return nullptr;
}
bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func graph is nullptr";
return false;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
return false;
}
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
bool changed = false;
for (auto node : node_list) {
if (node == nullptr || !node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() ||
node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) {
MS_LOG(DEBUG) << "Skip trans op";
continue;
}
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
std::vector<CNodePtr> trans_road;
bool first_flag = true;
auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);
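// trans_road is collected from the consumer back toward the Parameter, so
// trans_road[0] is the TransData next to the consumer and trans_road[2] is
// the one next to the Parameter.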
if (final_node != nullptr && trans_road.size() == 3 && AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName &&
AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() &&
AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) {
auto cur_transop = trans_road[0];
auto format = AnfAlgo::GetOutputFormat(cur_transop, 0);
auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0);
auto param_format = AnfAlgo::GetOutputFormat(final_node, 0);
auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
auto cast = trans_road[1];
auto cast_build_info = cast->kernel_info()->select_kernel_build_info();
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsFormat({format});
builder.SetInputsFormat({format});
builder.SetInputsDeviceType({param_dtype});
builder.SetOutputsDeviceType({dtype});
builder.SetKernelType(cast_build_info->kernel_type());
builder.SetFusionType(cast_build_info->fusion_type());
builder.SetProcessor(cast_build_info->processor());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
if (param_format == format && param_dtype != dtype) {
manager->Replace(trans_road[2], final_node);
manager->Replace(cur_transop, cast);
changed = true;  // only report a change when the road was actually replaced
}
}
}
}
return changed;
}
} // namespace opt
} // namespace mindspore
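
The pass above walks each input of a non-trans-op node back toward its Parameter and, when the road is exactly TransData -> Cast -> TransData and the parameter already carries the target format, folds the road into a single device-format Cast. A toy sketch of the rewrite (plain Python, not the MindSpore API):

# Toy model of the road ParameterTransOpFusion matches; op names only.
# The real pass also rewires the Cast's formats and dtypes through
# KernelBuildInfo::KernelBuildInfoBuilder before replacing nodes.
def fuse_param_trans_road(road):
    """road: op names collected from the consumer back toward the Parameter."""
    if road == ["TransData", "Cast", "TransData"]:
        return ["Cast"]  # one Cast, run directly in the device format
    return road

assert fuse_param_trans_road(["TransData", "Cast", "TransData"]) == ["Cast"]
assert fuse_param_trans_road(["Reshape", "Cast"]) == ["Reshape", "Cast"]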

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
namespace mindspore {
namespace opt {
class ParameterTransOpFusion : public Pass {
public:
explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {}
~ParameterTransOpFusion() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
size_t groups_ = 1;
};
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_

@@ -44,6 +44,12 @@ cast_op_info = TBERegOp("Cast") \
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.get_op_info()
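
The added dtype_format pairs register Cast for the NC1HWC0 (5HD), FracZ and FracNZ device formats in both float16-to-float32 and float32-to-float16 directions; presumably this is what lets the ParameterTransOpFusion pass above keep a single Cast running directly in the device format instead of a TransData-Cast-TransData road.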
