!4527 [MS][LITE]optimize the interface of model and kernel registry

Merge pull request !4527 from zhaizhiqiang/master
pull/4527/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 857c0301a8

@ -25,12 +25,12 @@
namespace mindspore {
#define MS_API __attribute__((visibility("default")))
namespace lite {
/// \brief ModelImpl defined the implement class of Model in MindSpore Lite.
///
/// \note List public class and interface for reference.
class ModelImpl;
namespace lite {
/// \brief Primitive defined as prototype of operator.
///
/// \note List public class and interface for reference.
@ -67,11 +67,6 @@ class MS_API Model {
/// \return the pointer of graph defined in flatbuffers.
const schema::MetaGraph *GetMetaGraph() const;
/// \brief Get MindSpore Lite ModelImpl.
///
/// \return the pointer of MindSpore Lite ModelImpl.
ModelImpl *model_impl();
/// \brief Free MetaGraph in MindSpore Lite Model.
void FreeMetaGraph();

@ -8,10 +8,8 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_factory.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
)
@ -21,44 +19,11 @@ if (SUPPORT_GPU)
list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc)
endif()
if (SUPPORT_TRAIN)
set(ANF_SRC
${ANF_SRC}
# ${CCSRC_DIR}/common/trans.cc
# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc
# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc
# ${CCSRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc
# ${CCSRC_DIR}/session/lite/session_basic_extends.cc
# ${CCSRC_DIR}/session/anf_runtime_algorithm.cc
# ${CCSRC_DIR}/session/session_basic.cc
# ${CCSRC_DIR}/session/kernel_graph.cc
# ${CCSRC_DIR}/session/session_factory.cc
# ${CCSRC_DIR}/device/kernel_info.cc
# ${CCSRC_DIR}/device/kernel_runtime.cc
# ${CCSRC_DIR}/device/lite/kernel_runtime_extends.cc
)
set(PASS_SRC)
set(LITE_SRC
${LITE_SRC}
${ANF_SRC}
# ${PASS_SRC}
# ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary
${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary
)
else ()
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc
)
endif ()
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
)
if (SUPPORT_GPU)
set(LITE_SRC

@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/lite/src/kernel_factory.h"
#include "utils/log_adapter.h"
#include "src/populate_parameter.h"
#include "schema/model_generated.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelKey;
using mindspore::kernel::LiteKernel;
namespace mindspore::lite {
KernelFactory::KernelFactory() = default;
KernelFactory::~KernelFactory() = default;
KernelFactory *KernelFactory::GetInstance() {
static KernelFactory instance;
return &instance;
}
LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
return nullptr;
}
auto creator = KernelRegistry::GetInstance()->GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive);
return kernel;
}
return nullptr;
}
} // namespace mindspore::lite

@ -1,40 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_
#define MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_
#include <vector>
#include "mindspore/lite/src/lite_kernel.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/include/context.h"
#include "mindspore/lite/src/ir/tensor.h"
#include "schema/model_generated.h"
namespace mindspore::lite {
class KernelFactory {
public:
KernelFactory();
virtual ~KernelFactory();
static KernelFactory *GetInstance();
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_

@ -16,6 +16,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "ir/dtype/type_id.h"
#include "src/populate_parameter.h"
#ifdef ENABLE_ARM64
#include <asm/hwcap.h>
#include "common/utils.h"
@ -120,4 +121,23 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c
bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &newCreators) { return false; }
const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; }
kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors,
const lite::Primitive *primitive, const Context *ctx,
const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
return nullptr;
}
auto creator = GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive);
return kernel;
}
return nullptr;
}
} // namespace mindspore::lite

@ -17,9 +17,9 @@
#ifndef MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_
#define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "src/lite_kernel.h"
#include "schema/model_generated.h"
@ -39,6 +39,9 @@ class KernelRegistry {
void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type,
kernel::KernelCreator creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key);
protected:
kernel::KernelCreator *creator_arrays_ = nullptr;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_MODEL_IMPL_H_
#define MINDSPORE_LITE_SRC_MODEL_IMPL_H_
#include <map>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "src/ops/ops.h"
namespace mindspore {
namespace lite {
class ModelImpl {
public:
static ModelImpl *Import(const char *model_buf, size_t size);
ModelImpl() = default;
explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
meta_graph_ = schema::GetMetaGraph(model_buf);
}
virtual ~ModelImpl();
lite::Primitive *GetOp(const std::string &name) const;
const schema::MetaGraph *meta_graph() const;
void FreeMetaGraph();
int BuildOps();
protected:
lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim);
protected:
const char *model_buf_;
size_t buf_size_;
const schema::MetaGraph *meta_graph_ = nullptr;
std::map<std::string, lite::Primitive *> ops_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_MODEL_H_

@ -19,7 +19,7 @@
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/fp32/batch_to_space.h"
#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -16,7 +16,7 @@
#include "src/runtime/kernel/arm/base/caffeprelu_base.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -19,7 +19,7 @@
#include "src/runtime/kernel/arm/fp32/concat.h"
#include "src/runtime/kernel/arm/nnacl/fp32/concat.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include <float.h>
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;

@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/crop_int8.h"
#include "src/runtime/kernel/arm/fp32/crop.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -19,7 +19,7 @@
#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h"
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
#include "src/runtime/kernel/arm/fp32/fullconnection.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -16,7 +16,7 @@
#include "src/runtime/kernel/arm/base/matmul_base.h"
#include "src/runtime/kernel/arm/fp32/matmul.h"
#include "src/runtime/kernel/arm/int8/matmul_int8.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/fp32/pad.h"
#include "src/runtime/kernel/arm/int8/pad_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/pooling_int8.h"
#include "src/runtime/kernel/arm/fp32/pooling.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -17,7 +17,7 @@
#include <vector>
#include "src/runtime/kernel/arm/int8/prelu_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -18,7 +18,7 @@
#include <cmath>
#include "src/runtime/kernel/arm/base/prior_box.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"
#include "src/runtime/runtime_api.h"

@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/reshape_int8.h"
#include "src/runtime/kernel/arm/fp32/reshape.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

@ -20,7 +20,7 @@
#include "src/runtime/kernel/arm/fp32/softmax.h"
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;

@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/split_int8.h"
#include "src/runtime/kernel/arm/fp32/split.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save