support automatic generate ops.fbs

pull/7198/head
chenjianping 5 years ago
parent a75665ce49
commit 6a0e8e2968

@ -46,7 +46,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]",
w_shape[1], conv_prim->name()); w_shape[1], conv_prim->name());
auto out_channel = conv_prim->GetOutputChannel(); auto out_channel = conv_prim->GetOutputChannel();
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
@ -114,9 +114,9 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
return TypeIdToType(infer_type); return TypeIdToType(infer_type);
} }
} // namespace } // namespace
Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } Conv2D::Conv2D() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode, void Conv2D::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation,
int group) { int group) {
auto prim_name = this->name(); auto prim_name = this->name();
@ -139,49 +139,49 @@ void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode
this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name));
this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
} }
std::vector<int> Conv2d::GetKernelSize() const { std::vector<int> Conv2D::GetKernelSize() const {
auto value_ptr = GetAttr(kKernelSize); auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int>>(value_ptr); return GetValue<std::vector<int>>(value_ptr);
} }
std::vector<int> Conv2d::GetStride() const { std::vector<int> Conv2D::GetStride() const {
auto value_ptr = GetAttr(kStride); auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int>>(value_ptr); return GetValue<std::vector<int>>(value_ptr);
} }
std::vector<int> Conv2d::GetDilation() const { std::vector<int> Conv2D::GetDilation() const {
auto value_ptr = GetAttr(kDilation); auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int>>(value_ptr); return GetValue<std::vector<int>>(value_ptr);
} }
std::string Conv2d::GetPadMode() const { std::string Conv2D::GetPadMode() const {
auto value_ptr = this->GetAttr(kPadMode); auto value_ptr = this->GetAttr(kPadMode);
return GetValue<string>(value_ptr); return GetValue<string>(value_ptr);
} }
std::vector<int> Conv2d::GetPad() const { std::vector<int> Conv2D::GetPad() const {
auto value_ptr = this->GetAttr(kPad); auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int>>(value_ptr); return GetValue<std::vector<int>>(value_ptr);
} }
int Conv2d::GetMode() const { int Conv2D::GetMode() const {
auto value_ptr = this->GetAttr(kMode); auto value_ptr = this->GetAttr(kMode);
return GetValue<int>(value_ptr); return GetValue<int>(value_ptr);
} }
int Conv2d::GetGroup() const { int Conv2D::get_group() const {
auto value_ptr = this->GetAttr(kGroup); auto value_ptr = this->GetAttr(kGroup);
return GetValue<int>(value_ptr); return GetValue<int>(value_ptr);
} }
int Conv2d::GetOutputChannel() const { int Conv2D::GetOutputChannel() const {
auto value_ptr = this->GetAttr(kOutputChannel); auto value_ptr = this->GetAttr(kOutputChannel);
return GetValue<int>(value_ptr); return GetValue<int>(value_ptr);
} }
void Conv2d::SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } void Conv2D::SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); }
void Conv2d::SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } void Conv2D::SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
void Conv2d::SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } void Conv2D::SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
void Conv2d::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } void Conv2D::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
void Conv2d::SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } void Conv2D::SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } void Conv2D::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } void Conv2D::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } void Conv2D::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } void Conv2D::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {

@ -25,11 +25,11 @@
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
namespace mindspore { namespace mindspore {
class Conv2d : public PrimitiveC { class Conv2D : public PrimitiveC {
public: public:
Conv2d(); Conv2D();
~Conv2d() = default; ~Conv2D() = default;
MS_DECLARE_PARENT(Conv2d, PrimitiveC); MS_DECLARE_PARENT(Conv2D, PrimitiveC);
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
@ -39,7 +39,7 @@ class Conv2d : public PrimitiveC {
std::string GetPadMode() const; std::string GetPadMode() const;
std::vector<int> GetPad() const; std::vector<int> GetPad() const;
int GetMode() const; int GetMode() const;
int GetGroup() const; int get_group() const;
int GetOutputChannel() const; int GetOutputChannel() const;
void SetKernelSize(const std::vector<int> &kernel_size); void SetKernelSize(const std::vector<int> &kernel_size);
void SetStride(const std::vector<int> &stride); void SetStride(const std::vector<int> &stride);
@ -53,6 +53,6 @@ class Conv2d : public PrimitiveC {
}; };
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args); const std::vector<AbstractBasePtr> &input_args);
using PrimConv2dPtr = std::shared_ptr<Conv2d>; using PrimConv2dPtr = std::shared_ptr<Conv2D>;
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_ #endif // MINDSPORE_CORE_C_OPS_CONV2D_H_

@ -228,6 +228,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nnacl)
if (NOT WIN32) if (NOT WIN32)
if (ENABLE_TOOLS) if (ENABLE_TOOLS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)
if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/schema_gen)
endif ()
endif() endif()
if (BUILD_TESTCASES) if (BUILD_TESTCASES)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/test) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/test)

@ -0,0 +1,23 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/schema_def.h"
#ifdef PRIMITIVE_WRITEABLE
#include "c_ops/conv2d.h"
#endif
OP_SCHEMA_DEF(Conv2D)
OP_ATTR(group, int)
OP_SCHEMA_DEF_END(Conv2D)

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_SCHEMA_DEF_H_
#define MINDSPORE_LITE_SRC_OPS_SCHEMA_DEF_H_
#include <string>
#include "src/ops/schema_register.h"
#ifdef PRIMITIVE_WRITEABLE
#include "c_ops/conv2d.h"
#include "schema/inner/model_generated.h"
#endif
#ifdef GEN_SCHEMA_DEF
#define OP_SCHEMA_DEF(OP) \
namespace mindspore::lite::ops { \
std::string Gen##OP##Def() { \
std::string op_def = "table "; \
op_def.append(#OP); \
op_def.append(" {\n");
#elif PRIMITIVE_WRITEABLE
#define OP_SCHEMA_DEF(OP) \
namespace mindspore::lite::ops { \
mindspore::schema::OP##T *PrimitiveOp2SchemaOp(const mindspore::OP *op) { \
mindspore::schema::OP##T *result_op = new (std::nothrow) mindspore::schema::OP##T();
#else
#define OP_SCHEMA_DEF(OP)
#endif
#ifdef GEN_SCHEMA_DEF
#define OP_ATTR(key, type) op_def.append(#key).append(": ").append(#type).append(";\n");
#elif PRIMITIVE_WRITEABLE
#define OP_ATTR(key, type) result_op->key = op->get_##key();
#else
#define OP_ATTR(key, type)
#endif
#ifdef GEN_SCHEMA_DEF
#define OP_ATTR_WITH_VALUE(key, type, value) \
op_def.append(#key).append(": ").append(#type).append(" = ").append(#value).append(";\n");
#elif PRIMITIVE_WRITEABLE
#define OP_ATTR_WITH_VALUE(key, type, value) result_op->key = op->get_##key();
#else
#define OP_ATTR_WITH_VALUE(key, type, value)
#endif
#ifdef GEN_SCHEMA_DEF
#define OP_SCHEMA_DEF_END(OP) \
op_def.append("}\n\n"); \
return op_def; \
} \
SchemaOpRegister g_schema_op_##OP(Gen##OP##Def); \
} // namespace mindspore::lite::ops
#elif PRIMITIVE_WRITEABLE
#define OP_SCHEMA_DEF_END(OP) \
return result_op; \
} \
} // namespace mindspore::lite::ops
#else
#define OP_SCHEMA_DEF_END(OP)
#endif
#endif // MINDSPORE_LITE_SRC_OPS_SCHEMA_DEF_H_

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_SCHEMA_REGISTER_H_
#define MINDSPORE_LITE_SRC_OPS_SCHEMA_REGISTER_H_
#include <string>
#include <vector>
#include <functional>
namespace mindspore::lite::ops {
using GetSchemaDef = std::function<std::string()>;
class SchemaRegisterImpl {
public:
SchemaRegisterImpl() = default;
static SchemaRegisterImpl *Instance() {
static SchemaRegisterImpl instance;
return &instance;
}
void OpPush(GetSchemaDef func) { op_def_funcs_.push_back(func); }
void TypePush(GetSchemaDef func) { type_def_funcs_.push_back(func); }
const std::vector<GetSchemaDef> &GetAllOpDefCreateFuncs() const { return op_def_funcs_; }
const std::vector<GetSchemaDef> &GetAllTypeDefCreateFuncs() const { return type_def_funcs_; }
private:
std::vector<GetSchemaDef> op_def_funcs_;
std::vector<GetSchemaDef> type_def_funcs_;
};
class SchemaOpRegister {
public:
explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); }
};
} // namespace mindspore::lite::ops
#endif // MINDSPORE_LITE_SRC_OPS_SCHEMA_REGISTER_H_

@ -0,0 +1,17 @@
# add shared link library
add_compile_definitions(GEN_SCHEMA_DEF)
set(COMMON_SRC
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
)
add_executable(schema_gen
${CMAKE_CURRENT_SOURCE_DIR}/main.cc
${CMAKE_CURRENT_SOURCE_DIR}/schema_gen.cc
${CMAKE_CURRENT_SOURCE_DIR}/schema_type_def.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/ops_def.cc
${COMMON_SRC})
target_link_libraries(schema_gen mindspore-lite pthread)
install(TARGETS schema_gen
RUNTIME DESTINATION ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/schema_gen COMPONENT ${RUN_X86_COMPONENT_NAME})

@ -0,0 +1,19 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/schema_gen/schema_gen.h"
int main(int argc, const char **argv) { return mindspore::lite::RunSchemaGen(argc, argv); }

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/schema_gen/schema_gen.h"
#include <sys/stat.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include "include/errorcode.h"
#include "src/ops/schema_register.h"
#include "src/common/log_adapter.h"
namespace mindspore::lite {
using mindspore::lite::ops::SchemaRegisterImpl;
int SchemaGen::Init() {
if (this->flags_ == nullptr) {
return RET_ERROR;
}
MS_LOG(INFO) << "Export Path = " << flags_->export_path_;
SchemaRegisterImpl *instance = SchemaRegisterImpl::Instance();
if (instance == nullptr) {
MS_LOG(ERROR) << "get instance fail!";
return RET_ERROR;
}
std::string path = flags_->export_path_ + "/ops.fbs";
if (access((path).c_str(), F_OK) == 0) {
chmod((path).c_str(), S_IWUSR);
}
std::ofstream output(path, std::ofstream::binary);
if (!output.is_open()) {
MS_LOG(ERROR) << "Can not open file: " << path;
return RET_ERROR;
}
std::string ns = "namespace mindspore.schema;\n\n";
output.write(ns.c_str(), ns.length());
for (auto &&func : instance->GetAllTypeDefCreateFuncs()) {
std::string &&str = func();
output.write(str.c_str(), str.length());
}
for (auto &&func : instance->GetAllOpDefCreateFuncs()) {
std::string &&str = func();
output.write(str.c_str(), str.length());
}
output.close();
chmod(path.c_str(), S_IRUSR);
return RET_OK;
}
int RunSchemaGen(int argc, const char **argv) {
SchemaGenFlags flags;
Option<std::string> err = flags.ParseFlags(argc, argv);
if (err.IsSome()) {
std::cerr << err.Get() << std::endl;
std::cerr << flags.Usage() << std::endl;
return RET_ERROR;
}
if (flags.help) {
std::cerr << flags.Usage() << std::endl;
return 0;
}
SchemaGen gen(&flags);
int ret = gen.Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "schema gen fail!ret: " << ret;
}
return ret;
}
} // namespace mindspore::lite

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_GEN_H_
#define MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_GEN_H_
#include <string>
#include "tools/common/flag_parser.h"
namespace mindspore::lite {
class SchemaGenFlags : public virtual FlagParser {
public:
SchemaGenFlags() { AddFlag(&SchemaGenFlags::export_path_, "exportPath", "schema define export path", "."); }
public:
std::string export_path_ = ".";
};
class SchemaGen {
public:
explicit SchemaGen(SchemaGenFlags *flags) : flags_(flags) {}
int Init();
private:
SchemaGenFlags *flags_;
};
int RunSchemaGen(int argc, const char **argv);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_GEN_H_

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/schema_gen/schema_type_def.h"
SCHEMA_ENUM_DEF(ResizeMethod, byte)
SCHEMA_ENUM_ATTR_WITH_VALUE(UNKNOW, -1)
SCHEMA_ENUM_ATTR_WITH_VALUE(BILINEAR, 0)
SCHEMA_ENUM_ATTR_WITH_VALUE(NEAREST_NEIGHBOR, 1)
OP_SCHEMA_DEF_END(ResizeMethod)
SCHEMA_ENUM_DEF(Format, int)
SCHEMA_ENUM_ATTR_WITH_VALUE(NCHW, 0)
SCHEMA_ENUM_ATTR(NHWC)
SCHEMA_ENUM_ATTR(NHWC4)
SCHEMA_ENUM_ATTR(HWKC)
SCHEMA_ENUM_ATTR(HWCK)
SCHEMA_ENUM_ATTR(KCHW)
SCHEMA_ENUM_ATTR(CKHW)
SCHEMA_ENUM_ATTR(KHWC)
SCHEMA_ENUM_ATTR(CHWK)
SCHEMA_ENUM_ATTR(HW)
SCHEMA_ENUM_ATTR(HW4)
SCHEMA_ENUM_ATTR(NC)
SCHEMA_ENUM_ATTR(NC4)
SCHEMA_ENUM_ATTR_WITH_VALUE(NC4HW4, 100)
SCHEMA_ENUM_ATTR(NUM_OF_FORMAT)
OP_SCHEMA_DEF_END(Format)
SCHEMA_ENUM_DEF(ActivationType, byte)
SCHEMA_ENUM_ATTR_WITH_VALUE(NO_ACTIVATION, 0)
SCHEMA_ENUM_ATTR_WITH_VALUE(RELU, 1)
SCHEMA_ENUM_ATTR_WITH_VALUE(SIGMOID, 2)
SCHEMA_ENUM_ATTR_WITH_VALUE(RELU6, 3)
SCHEMA_ENUM_ATTR_WITH_VALUE(ELU, 4)
SCHEMA_ENUM_ATTR_WITH_VALUE(LEAKY_RELU, 5)
SCHEMA_ENUM_ATTR_WITH_VALUE(ABS, 6)
SCHEMA_ENUM_ATTR_WITH_VALUE(RELU1, 7)
SCHEMA_ENUM_ATTR_WITH_VALUE(SOFTSIGN, 8)
SCHEMA_ENUM_ATTR_WITH_VALUE(SOFTPLUS, 9)
SCHEMA_ENUM_ATTR_WITH_VALUE(TANH, 10)
SCHEMA_ENUM_ATTR_WITH_VALUE(SELU, 11)
SCHEMA_ENUM_ATTR_WITH_VALUE(HSWISH, 12)
SCHEMA_ENUM_ATTR_WITH_VALUE(HSIGMOID, 13)
SCHEMA_ENUM_ATTR_WITH_VALUE(THRESHOLDRELU, 14)
SCHEMA_ENUM_ATTR_WITH_VALUE(LINEAR, 15)
SCHEMA_ENUM_ATTR_WITH_VALUE(HARD_TANH, 16)
SCHEMA_ENUM_ATTR_WITH_VALUE(SIGN, 17)
SCHEMA_ENUM_ATTR_WITH_VALUE(UNKNOW, 18)
OP_SCHEMA_DEF_END(ActivationType)

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_TYPE_DEF_H_
#define MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_TYPE_DEF_H_
#include <string>
#include "tools/schema_gen/schema_type_register.h"
#define SCHEMA_ENUM_DEF(T, B) \
namespace mindspore::lite::ops { \
std::string GenEnumDef##T() { \
std::string def = "enum "; \
def.append(#T); \
def.append(" : "); \
def.append(#B); \
def.append(" {\n");
#define SCHEMA_ENUM_ATTR_WITH_VALUE(key, value) def.append(#key).append(" = ").append(#value).append(",\n");
#define SCHEMA_ENUM_ATTR(key) def.append(#key).append(",\n");
#define OP_SCHEMA_DEF_END(T) \
def.append("}\n\n"); \
return def; \
} \
SchemaTypeRegister g_schema_enum_##T(GenEnumDef##T); \
} // namespace mindspore::lite::ops
#endif // MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_TYPE_DEF_H_

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_TYPE_REGISTER_H_
#define MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_TYPE_REGISTER_H_
#include "src/ops/schema_register.h"
namespace mindspore::lite::ops {
class SchemaTypeRegister {
public:
explicit SchemaTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->TypePush(func); }
};
} // namespace mindspore::lite::ops
#endif // MINDSPORE_LITE_TOOLS_SCHEMA_GEN_SCHEMA_TYPE_REGISTER_H_
Loading…
Cancel
Save