|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
#include <pybind11/pybind11.h>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include "utils/log_adapter.h"
|
|
|
|
|
#include "utils/overload.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
@ -35,6 +36,9 @@ constexpr auto kPartialFlag = "partial_flag";
|
|
|
|
|
constexpr auto kReshapeType = "reshape_type";
|
|
|
|
|
constexpr auto kOpPattern = "op_pattern";
|
|
|
|
|
constexpr auto kDynamicFormat = "dynamic_format";
|
|
|
|
|
constexpr auto kFormatAgnostic = "formatAgnostic";
|
|
|
|
|
constexpr auto kBroadcast = "broadcast";
|
|
|
|
|
constexpr auto kReduce = "reduce";
|
|
|
|
|
constexpr auto kDtypeFormat = "dtype_format";
|
|
|
|
|
constexpr auto kAttr = "attr";
|
|
|
|
|
constexpr auto kIputs = "inputs";
|
|
|
|
@ -95,13 +99,19 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
|
|
|
|
|
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
|
|
|
|
|
{kFormatAgnostic, kBroadcastPattern},
|
|
|
|
|
{kReduce, kReducePattern},
|
|
|
|
|
{kDynamicFormat, kDynamicFormatPattern}};
|
|
|
|
|
op_info->set_async_flag(obj.at(kAsyncFlag));
|
|
|
|
|
op_info->set_binfile_name(obj.at(kBinfileName));
|
|
|
|
|
op_info->set_compute_cost(obj.at(kComputeCost));
|
|
|
|
|
op_info->set_kernel_name(obj.at(kKernelName));
|
|
|
|
|
op_info->set_partial_flag(obj.at(kPartialFlag));
|
|
|
|
|
if (obj.find(kOpPattern) != obj.end()) {
|
|
|
|
|
op_info->set_op_pattern(obj.at(kOpPattern));
|
|
|
|
|
if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
|
|
|
|
|
op_info->set_op_pattern(obj.at(kOpPattern));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (obj.find(kDynamicFormat) != obj.end()) {
|
|
|
|
|
op_info->set_dynamic_format(obj.at(kDynamicFormat));
|
|
|
|
|