add new mode for operator info register

pull/116/head
zjun 5 years ago
parent 2d44dd1cb3
commit 651d9aae40

@ -61,6 +61,7 @@ class OpIOInfo {
std::string name() const { return name_; }
bool need_compile() const { return need_compile_; }
std::string param_type() const { return param_type_; }
std::string reshape_type() const { return reshape_type_; }
std::string shape() const { return shape_; }
std::vector<std::string> dtypes() const { return dtypes_; }
std::vector<std::string> formats() const { return formats_; }
@ -69,6 +70,7 @@ class OpIOInfo {
void set_name(const std::string& name) { name_ = name; }
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; }
void set_shape(const std::string& shape) { shape_ = shape; }
void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; }
void set_formats(const std::vector<std::string>& formats) { formats_ = formats; }
@ -78,6 +80,7 @@ class OpIOInfo {
std::string name_;
bool need_compile_ = false;
std::string param_type_;
std::string reshape_type_;
std::string shape_;
std::vector<std::string> dtypes_;
std::vector<std::string> formats_;
@ -96,6 +99,8 @@ class OpInfo {
int compute_cost() const { return compute_cost_; }
std::string kernel_name() const { return kernel_name_; }
bool partial_flag() const { return partial_flag_; }
bool dynamic_format() const { return dynamic_format_; }
std::string op_pattern() const { return op_pattern_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
@ -110,6 +115,8 @@ class OpInfo {
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; }
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
void set_op_pattern(const std::string& op_pattern) { op_pattern_ = op_pattern; }
void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); }
@ -129,6 +136,8 @@ class OpInfo {
int compute_cost_ = 0;
std::string kernel_name_;
bool partial_flag_ = false;
bool dynamic_format_ = false;
std::string op_pattern_;
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;

@ -26,18 +26,22 @@ namespace mindspore {
namespace kernel {
constexpr auto kImplyType = "imply_type";
constexpr auto kOpName = "op_name";
constexpr auto kTbe = "TBE";
constexpr auto kAkg = "akg";
constexpr auto kAutodiff = "AutoDiff";
constexpr auto kFusionType = "fusion_type";
constexpr auto kAsyncFlag = "async_flag";
constexpr auto kBinfileName = "binfile_name";
constexpr auto kComputeCost = "compute_cost";
constexpr auto kKernelName = "kernel_name";
constexpr auto kPartialFlag = "partial_flag";
constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kDynamicFormat = "dynamic_format";
constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kAttr = "attr";
constexpr auto kInputs = "inputs";
constexpr auto kOutputs = "outputs";
constexpr auto kName = "name";
constexpr auto kParamType = "param_type";
constexpr auto kDtype = "dtype";
@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
MS_EXCEPTION_IF_NULL(op_info);
op_info->set_op_name(obj.at(kOpName));
op_info->set_impl_path(impl_path);
op_info->set_imply_type(imply_type);
op_info->set_fusion_type(obj.at(kFusionType));
if (imply_type == kTBE) {
op_info->set_async_flag(obj.at(kAsyncFlag));
@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
}
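Both new keys are read defensively with obj.find, so existing TBE op info JSON without them keeps loading unchanged. A minimal sketch of a TBE entry exercising the new mode, written as Python data (the op name and field values are invented for illustration; only the key names come from this commit):

example_op_info = {
    "op_name": "ExampleOp",          # hypothetical op, not from this commit
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": False,
    "binfile_name": "example_op.so",
    "compute_cost": 10,
    "kernel_name": "example_op",
    "partial_flag": True,
    "op_pattern": "formatAgnostic",  # assumed value; parsed as a plain string
    "dynamic_format": True,          # optional flag added by this commit
    "attr": [],
    "inputs": [],
    "outputs": [],
}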
auto attrs = obj.at(kAttr);
for (const auto& attr : attrs) {
@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
return false;
}
}
nlohmann::json dtype_format;
if (obj.find(kDtypeFormat) != obj.end()) {
dtype_format = obj.at(kDtypeFormat);
}
auto inputs = obj.at(kInputs);
for (const auto& input : inputs) {
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false;
}
}
auto outputs = obj.at(kOutputs);
for (const auto& output : outputs) {
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false;
}
@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
return ret;
}
bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
size_t index) {
bool ret = true;
try {
std::vector<std::string> dtype;
std::vector<std::string> format;
for (const auto& it : dtype_format) {
dtype.emplace_back(it[index][0]);
format.emplace_back(it[index][1]);
}
op_io->set_dtypes(dtype);
op_io->set_formats(format);
} catch (const std::exception& e) {
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
ret = false;
}
return ret;
}
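DecodeDtypeFormat treats dtype_format as a table: each row is one supported dtype/format combination, and each entry in a row is a [dtype, format] pair for one I/O, ordered inputs first and outputs after; that ordering is why the caller below passes inputs_ptr().size() + outputs_ptr().size() as the index of the I/O currently being decoded. A small Python sketch of the same column extraction (the table data is invented for illustration):

dtype_format = [
    # row 0: all-float16 combination; row layout is [input0, output0]
    [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]],
    # row 1: all-float32 combination
    [["float", "DefaultFormat"], ["float", "DefaultFormat"]],
]

def decode_column(table, index):
    """Collect the dtype list and format list for the I/O at position `index`."""
    dtypes = [row[index][0] for row in table]
    formats = [row[index][1] for row in table]
    return dtypes, formats

assert decode_column(dtype_format, 0) == (["float16", "float"], ["DefaultFormat", "DefaultFormat"])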
bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) {
bool ret = true;
try {
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
MS_EXCEPTION_IF_NULL(op_io);
op_io->set_index(obj.at(kIndex));
op_io->set_name(obj.at(kName));
if (!dtype_format.empty()) {
if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) {
MS_LOG(ERROR) << "Decode dtype format failed";
return false;
}
} else {
op_io->set_dtypes(obj.at(kDtype));
op_io->set_formats(obj.at(kFormat));
}
if (op_io->dtypes().size() != op_io->formats().size()) {
MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes()
<< "is not equal to format size:" << op_io->formats();
@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
if (obj.find(kShape) != obj.end()) {
op_io->set_shape(obj.at(kShape));
}
if (obj.find(kReshapeType) != obj.end()) {
op_io->set_reshape_type(obj.at(kReshapeType));
}
}
if (io_type == kInput) {

@ -38,8 +38,10 @@ class OpLib {
static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path);
static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo>& op_info);
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
size_t index);
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);
static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info);
};

@ -30,7 +30,7 @@ Note:
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, TBERegOp, DataType
from .primitive import constexpr
from .._c_expression import signature_rw, signature_kind
@ -40,6 +40,6 @@ __primitive__ = [
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register",
"op_info_register", "TBERegOp", "DataType",
"constexpr"]
__all__.extend(__primitive__)
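With the re-export in place, kernel registration modules can presumably pull everything they need from mindspore.ops in a single import; a hypothetical usage:

from mindspore.ops import op_info_register, TBERegOp, DataType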

@ -14,208 +14,41 @@
# ============================================================================
"""AdamApplyOneWithDecay op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
adam_apply_one_with_decay_op_info = TBERegOp("AdamApplyOneWithDecay") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("adam_apply_one_with_decay.so") \
.compute_cost(10) \
.kernel_name("adam_apply_one_with_decay") \
.partial_flag(True) \
.input(0, "input0", False, "required", "all") \
.input(1, "input1", False, "required", "all") \
.input(2, "input2", False, "required", "all") \
.input(3, "input3", False, "required", "all") \
.input(4, "input4", False, "required", "all") \
.input(5, "mul0_x", False, "required", "all") \
.input(6, "mul1_x", False, "required", "all") \
.input(7, "mul2_x", False, "required", "all") \
.input(8, "mul3_x", False, "required", "all") \
.input(9, "mul4_x", False, "required", "all") \
.input(10, "add2_y", False, "required", "all") \
.output(0, "output0", False, "required", "all") \
.output(1, "output1", False, "required", "all") \
.output(2, "output2", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
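Each .dtype_format(...) call contributes one row of the table that DecodeDtypeFormat consumes: the fourteen arguments line up with the eleven inputs and three outputs in index order, so the two rows above declare the all-float16 and all-float32 combinations. A rough sketch of how such a chainable builder could accumulate rows (an illustrative stand-in, not the real TBERegOp implementation; the DataType pair values are assumptions):

class _RegOpSketch:
    """Illustrative stand-in for TBERegOp's dtype_format row collection."""
    def __init__(self, op_name):
        self.op_name = op_name
        self.rows = []

    def dtype_format(self, *pairs):
        self.rows.append(list(pairs))  # one (dtype, format) pair per I/O
        return self                    # chainable, like the real builder

F16_Default = ("float16", "DefaultFormat")  # assumed to mirror DataType.F16_Default
F32_Default = ("float", "DefaultFormat")    # assumed to mirror DataType.F32_Default

reg = _RegOpSketch("AdamApplyOneWithDecay")
reg.dtype_format(*([F16_Default] * 14)).dtype_format(*([F32_Default] * 14))
assert len(reg.rows) == 2 and all(len(row) == 14 for row in reg.rows)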
@op_info_register("""{
"op_name": "AdamApplyOneWithDecay",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "adam_apply_one_with_decay.so",
"compute_cost": 10,
"kernel_name": "adam_apply_one_with_decay",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input0",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input3",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input4",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul0_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 6,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul1_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 7,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul2_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 8,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul3_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 9,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul4_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 10,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "add2_y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output0",
"need_compile": true,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output1",
"need_compile": true,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output2",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(adam_apply_one_with_decay_op_info)
def _adam_apply_one_with_decay_tbe():
"""AdamApplyOneWithDecay TBE register"""
return
