convert unsupported kernel in aicore to aicpu

pull/1079/head
WilliamLian 5 years ago
parent 7ab3f5c348
commit 691b0648e3

@ -85,7 +85,7 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType
} while (0)
template <typename T>
T Ceil(T n1, T n2) {
T DivCeil(T n1, T n2) {
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}
@ -371,15 +371,48 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t c0 = 4;
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize);
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize);
device_shape.push_back(first_dim);
device_shape.push_back(no);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = 1;
size_t C0 = 4;
device_shape.push_back(shape[0]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(C0);
return device_shape;
}
} // namespace
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{
{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
};
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
{kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape},
{kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}};
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
@ -506,13 +539,13 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = Ceil(c, c0);
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t hwc0 = hw * c0;
size_t nchw = n * chw;
size_t hf_cnt = Ceil(n, kCubeSize);
size_t hf_cnt = DivCeil(n, kCubeSize);
size_t vf_cnt = c1 * hw;
size_t fractal_ele_cnt = c0 * kCubeSize;
size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = Ceil(c, c0);
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t c1hwc0 = c1 * hw * c0;

@ -34,6 +34,7 @@ namespace ascend {
namespace {
const float kWegihtBaseScore = 1;
const float kFeatureMapBaseScore = 10;
constexpr auto kPriChoosenFormat = "pri_format";
enum MatchCountPriority : int {
MATCH_COUNT_PRIORITY_BEGIN = 0,
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
if (need_change_nd) {
priority_matched_format = kOpFormat_DEFAULT;
}
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
return priority_matched_format;
}
/**
@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
std::ostringstream buffer;
buffer << cnode->DebugString();
if (precision_reduce) {
buffer << " reduce precision, node datatype: ";
buffer << " reduce precision, node datatype: \n";
} else {
buffer << " raise precision, node datatype: ";
buffer << " raise precision, node datatype: \n";
}
PrintInputAndOutputInferType(buffer, cnode);
buffer << ", select kernel:" << selected_kernel_build_info->ToString();
@ -464,66 +466,57 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
}
} // namespace
std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo(
int *status, const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
KernelSelectStatus select_status = kNoMatched;
bool precision_reduce = false;
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
// Matched kernel info
// Filter kernel info matched with me infered type
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
if (!filtered_kernel_info_list.empty()) {
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
select_status = kStatusAllMatched;
} else {
// selected kernel info using raised precision or reduce precision
filtered_kernel_info_list =
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
if (selected_kernel_info == nullptr) {
return nullptr;
return select_status;
} else {
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
*status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
}
}
return selected_kernel_info;
// Set kernel info to the anfnode
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
return select_status;
}
int SelectKernelInfo(const CNodePtr &kernel_node) {
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
int status = kStatusAllMatched;
MS_EXCEPTION_IF_NULL(kernel_node);
kernel::KernelQuery(kernel_node, &kernel_info_list);
// filter kernel info matched with me infered type
auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
if (selected_kernel_info == nullptr) {
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
if (select_status == kNoMatched) {
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AicpuQuery(kernel_node, &kernel_info_list);
selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
kernel::AICpuQuery(kernel_node, &kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
}
if (selected_kernel_info == nullptr) {
// The kernel info not finded both in the aicpu kernel list & aicore kernel list
if (select_status == kNoMatched) {
std::ostringstream buffer;
PrintInputAndOutputInferType(buffer, kernel_node);
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type " << buffer.str();
}
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
return status;
}
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel::KernelQuery(kernel_node, &kernel_info_list);
auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
[&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *new_kernel_build_info;
});
return result != kernel_info_list.end();
return select_status;
}
} // namespace ascend
} // namespace device

@ -21,8 +21,13 @@
namespace mindspore {
namespace device {
namespace ascend {
int SelectKernelInfo(const CNodePtr &kernel_node);
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info);
enum KernelSelectStatus {
kNoMatched = -1,
kStatusAllMatched = 0,
kStatusReducePrecision = 1,
kStatusRaisePrecision = 2,
};
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node);
} // namespace ascend
} // namespace device
} // namespace mindspore

@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> input_format, output_format;
std::vector<TypeId> input_type, output_type;
for (const auto &data_type : data_type_list) {
for (const auto &format : k4DSupportFormat) {
for (const auto &format : kOpFormatList) {
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
input_format.clear();
input_format.push_back(format);

@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
});
kernel_info_list->clear();
if (!filtered_list.empty()) {
kernel_info_list->clear();
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
} else {
MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : ["
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node)
<< "] cannot match any kernelInfo !";
MS_LOG(WARNING) << "All kernel Info list does not match any kernel info ";
for (size_t index; index < kernel_info_list->size(); ++index) {
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
}
MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
}
}
} // namespace
@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list);
}
@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
kernel_info_list->clear();
AicpuMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AicpuMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *select_kernel_build_info;
});
}
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TbeMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *select_kernel_build_info;
});
}
} // namespace kernel
} // namespace mindspore

@ -26,7 +26,9 @@
namespace mindspore {
namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_

@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
}
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
// if format is default, it remarkes support all format
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(EXCEPTION) << "Got the unknown format " << format;

@ -54,6 +54,7 @@
#include "pre_activate/pass/optimize_dependence.h"
#include "pre_activate/pass/erase_visit_attr.h"
#include "pre_activate/ascend/format_type/insert_cast.h"
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include "pre_activate/pass/eliminate_redundant_op.h"
#include "pre_activate/pass/common_subexpression_elimination.h"
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
optimizer->AddPassManager(mixed_precision_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
}
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
return cast;
}

@ -30,10 +30,6 @@ class KernelSelect {
KernelSelect() = default;
virtual ~KernelSelect() = default;
virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); }
virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info);
}
};
using KernelSelectPtr = std::shared_ptr<KernelSelect>;
@ -41,8 +37,13 @@ class SupportedChecker {
public:
SupportedChecker() = default;
virtual ~SupportedChecker() = default;
virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::CheckSupported(anf_node, select_kernel_build_info);
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info);
}
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info);
}
};
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "kernel/kernel_query.h"
namespace mindspore {
namespace opt {
const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({X, Xs});
}
const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &,
const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
auto node_name = AnfAlgo::GetCNodeName(node);
if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) {
return nullptr;
}
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
return node;
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
builder->SetKernelType(AICPU_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
} else {
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
<< node->DebugString() << "]";
}
return node;
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
namespace mindspore {
namespace opt {
class ConvertUnSupportNodeToAICPU : public PatternProcessPass {
public:
explicit ConvertUnSupportNodeToAICPU(bool multigraph = true)
: PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph),
supported_checker_(std::make_shared<SupportedChecker>()) {}
~ConvertUnSupportNodeToAICPU() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
SupportedCheckerPtr supported_checker_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#include <string>
#include "pre_activate/common/optimizer.h"
@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass {
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
#include <string>
#include <utility>
@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass {
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_

@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
auto indices_const = CreateValueNode(new_cnode);
new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) {
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
return nullptr;
}

@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
if (kernel_select_->CheckKernelAccuracySupported(transdata_cnode, new_transdata_builder->Build())) {
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);

@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass {
explicit TransposeTransDataFusion(bool multigraph = true)
: PatternProcessPass("transpose_transdata_fusion", multigraph) {
input_varptr_ = std::make_shared<Var>();
kernel_select_ = std::make_shared<KernelSelect>();
supported_checker_ = std::make_shared<SupportedChecker>();
}
~TransposeTransDataFusion() override = default;
const BaseRef DefinePattern() const override;
@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass {
private:
VarPtr input_varptr_;
KernelSelectPtr kernel_select_;
private:
SupportedCheckerPtr supported_checker_;
};
} // namespace opt
} // namespace mindspore

@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
size_t reduce_precision_count = 0;
for (const auto &cnode : kernel_graph.execution_order()) {
auto status = device::ascend::SelectKernelInfo(cnode);
if (status == kStatusRaisePrecision) {
if (status == device::ascend::kStatusRaisePrecision) {
raise_precision_count++;
} else if (status == kStatusReducePrecision) {
} else if (status == device::ascend::kStatusReducePrecision) {
reduce_precision_count++;
}
MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();

@ -27,6 +27,8 @@
namespace mindspore {
namespace session {
namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(que);
@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
// create kernel_info from new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
[&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = inputs[index];
if (AnfAlgo::IsFeatureMapOutput(node)) {
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
} else {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode);
}
cnode->set_kernel_info(kernel_info);
AnfAlgo::SetGraphId(graph_id_, cnode.get());

@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D
// attr key name
constexpr auto kAttrInputNames = "input_names";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrVisited = "visited";
constexpr auto kAttrShape = "shape";
@ -196,10 +197,6 @@ constexpr auto kControlDependBehindIndex = 2;
// index define of depend
constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2;
// status of kernel select result
const int kStatusReducePrecision = -1;
const int kStatusRaisePrecision = 1;
const int kStatusAllMatched = 0;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
@ -213,18 +210,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_C1HWNCoC0};
const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
kOpFormat_NC1KHKWHWC0};
const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};
const std::set<std::string> k4DSupportFormat = k1DSupportFormat;
const std::vector<std::set<std::string>> kShapeSupportFormatMap = {k1DSupportFormat, k2DSupportFormat, k3DSupportFormat,
k4DSupportFormat};
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {
kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName,
kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName,

File diff suppressed because it is too large Load Diff

@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
}; // namespace opt

@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
};
class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
public:
MockInsertTransOpKernelSelectTrans4Dto5D() = default;
@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
}
};
class MockTransposeTransdataFusionKernelSelect : public KernelSelect {
public:
MockTransposeTransdataFusionKernelSelect() = default;
~MockTransposeTransdataFusionKernelSelect() override = default;
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
const kernel::KernelBuildInfoPtr &new_kernel_build_info) override {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_NCHW});
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kNumberTypeFloat16});
builder.SetOutputsDeviceType({kNumberTypeFloat16});
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetProcessor(kernel::Processor::AICORE);
kernel_info_list.push_back(builder.Build());
MS_LOG(INFO) << "transpose transdata fusion success";
MS_LOG(INFO) << "new transdata build info input format:" << new_kernel_build_info->GetInputFormat(0)
<< ",outputformat:" << new_kernel_build_info->GetOutputFormat(0)
<< ",kerneltype:" << new_kernel_build_info->kernel_type()
<< ",fusiontype:" << new_kernel_build_info->fusion_type()
<< ",process:" << new_kernel_build_info->processor();
auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
[&new_kernel_build_info](kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *new_kernel_build_info;
});
return result != kernel_info_list.end();
}
};
TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
/*
* def before(input0, input1):
@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
insert_trans_op_pass->kernel_select_ = std::make_shared<MockInsertTransOpKernelSelectTrans4Dto5D>();
pm->AddPass(insert_trans_op_pass);
auto transpose_transdata_pass = std::make_shared<opt::TransposeTransDataFusion>();
transpose_transdata_pass->kernel_select_ = std::make_shared<MockTransposeTransdataFusionKernelSelect>();
transpose_transdata_pass->supported_checker_ = std::make_shared<MockSupportedChecker>();
pm->AddPass(transpose_transdata_pass);
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);

Loading…
Cancel
Save