diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h index 1ecab3dada..67fdbb4ff0 100644 --- a/mindspore/lite/include/model.h +++ b/mindspore/lite/include/model.h @@ -25,12 +25,12 @@ namespace mindspore { #define MS_API __attribute__((visibility("default"))) +namespace lite { /// \brief ModelImpl defined the implement class of Model in MindSpore Lite. /// /// \note List public class and interface for reference. class ModelImpl; -namespace lite { /// \brief Primitive defined as prototype of operator. /// /// \note List public class and interface for reference. @@ -67,11 +67,6 @@ class MS_API Model { /// \return the pointer of graph defined in flatbuffers. const schema::MetaGraph *GetMetaGraph() const; - /// \brief Get MindSpore Lite ModelImpl. - /// - /// \return the pointer of MindSpore Lite ModelImpl. - ModelImpl *model_impl(); - /// \brief Free MetaGraph in MindSpore Lite Model. void FreeMetaGraph(); diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 85dee64691..a1d5c6152f 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -8,10 +8,8 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc ${CMAKE_CURRENT_SOURCE_DIR}/context.cc ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel_factory.cc ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model.cc ${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc ) @@ -21,44 +19,11 @@ if (SUPPORT_GPU) list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc) endif() -if (SUPPORT_TRAIN) - set(ANF_SRC - ${ANF_SRC} -# ${CCSRC_DIR}/common/trans.cc -# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc -# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc -# ${CCSRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc -# ${CCSRC_DIR}/session/lite/session_basic_extends.cc -# ${CCSRC_DIR}/session/anf_runtime_algorithm.cc -# ${CCSRC_DIR}/session/session_basic.cc -# ${CCSRC_DIR}/session/kernel_graph.cc -# ${CCSRC_DIR}/session/session_factory.cc -# ${CCSRC_DIR}/device/kernel_info.cc -# ${CCSRC_DIR}/device/kernel_runtime.cc -# ${CCSRC_DIR}/device/lite/kernel_runtime_extends.cc - ) - set(PASS_SRC) - set(LITE_SRC - ${LITE_SRC} - ${ANF_SRC} - # ${PASS_SRC} - # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc - ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary - ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary - ) - -else () - set(LITE_SRC - ${LITE_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc - ) -endif () +set(LITE_SRC + ${LITE_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model.cc + ) if (SUPPORT_GPU) set(LITE_SRC diff --git a/mindspore/lite/src/kernel_factory.cc b/mindspore/lite/src/kernel_factory.cc deleted file mode 100644 index 257ba0fa0b..0000000000 --- a/mindspore/lite/src/kernel_factory.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindspore/lite/src/kernel_factory.h" -#include "utils/log_adapter.h" -#include "src/populate_parameter.h" -#include "schema/model_generated.h" - -using mindspore::kernel::KERNEL_ARCH; -using mindspore::kernel::KernelKey; -using mindspore::kernel::LiteKernel; - -namespace mindspore::lite { -KernelFactory::KernelFactory() = default; - -KernelFactory::~KernelFactory() = default; - -KernelFactory *KernelFactory::GetInstance() { - static KernelFactory instance; - return &instance; -} - -LiteKernel *KernelFactory::GetKernel(const std::vector &in_tensors, - const std::vector &out_tensors, const lite::Primitive *primitive, - const Context *ctx, const kernel::KernelKey &key) { - MS_EXCEPTION_IF_NULL(primitive); - MS_EXCEPTION_IF_NULL(ctx); - auto parameter = kernel::PopulateParameter(primitive); - if (parameter == nullptr) { - MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); - return nullptr; - } - auto creator = KernelRegistry::GetInstance()->GetCreator(key); - if (creator != nullptr) { - auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive); - return kernel; - } - return nullptr; -} -} // namespace mindspore::lite diff --git a/mindspore/lite/src/kernel_factory.h b/mindspore/lite/src/kernel_factory.h deleted file mode 100644 index 086bdf7b46..0000000000 --- a/mindspore/lite/src/kernel_factory.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ -#define MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ - -#include -#include "mindspore/lite/src/lite_kernel.h" -#include "mindspore/lite/src/kernel_registry.h" -#include "mindspore/lite/include/context.h" -#include "mindspore/lite/src/ir/tensor.h" -#include "schema/model_generated.h" - -namespace mindspore::lite { -class KernelFactory { - public: - KernelFactory(); - virtual ~KernelFactory(); - - static KernelFactory *GetInstance(); - kernel::LiteKernel *GetKernel(const std::vector &in_tensors, - const std::vector &out_tensors, const lite::Primitive *primitive, - const Context *ctx, const kernel::KernelKey &key); -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 71be828337..f4fb136231 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -16,6 +16,7 @@ #include "src/kernel_registry.h" #include "include/errorcode.h" #include "ir/dtype/type_id.h" +#include "src/populate_parameter.h" #ifdef ENABLE_ARM64 #include #include "common/utils.h" @@ -120,4 +121,23 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c bool KernelRegistry::Merge(const std::unordered_map &newCreators) { return false; } const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; } + +kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector &in_tensors, + const std::vector &out_tensors, + const lite::Primitive *primitive, const Context *ctx, + const kernel::KernelKey &key) { + MS_EXCEPTION_IF_NULL(primitive); + MS_EXCEPTION_IF_NULL(ctx); + auto parameter = kernel::PopulateParameter(primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); + return nullptr; + } + auto creator = GetCreator(key); + if (creator != nullptr) { + auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive); + return kernel; + } + return nullptr; +} } // namespace mindspore::lite diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h index 3352338203..2c1d608378 100644 --- a/mindspore/lite/src/kernel_registry.h +++ b/mindspore/lite/src/kernel_registry.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ #define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ -#include #include #include +#include #include "src/lite_kernel.h" #include "schema/model_generated.h" @@ -39,6 +39,9 @@ class KernelRegistry { void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type, kernel::KernelCreator creator); bool Merge(const std::unordered_map &newCreators); + kernel::LiteKernel *GetKernel(const std::vector &in_tensors, + const std::vector &out_tensors, const lite::Primitive *primitive, + const Context *ctx, const kernel::KernelKey &key); protected: kernel::KernelCreator *creator_arrays_ = nullptr; diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 6e97ddac6c..0fce835c43 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -14,16 +14,310 @@ * limitations under the License. */ -// #ifdef SUPPORT_TRAIN -// #include "src/train/model_impl.h" -// #else -#include "src/model_impl.h" -// #endif #include "include/model.h" #include "utils/log_adapter.h" +#include "src/ops/ops.h" namespace mindspore::lite { +class ModelImpl { + public: + static ModelImpl *Import(const char *model_buf, size_t size); + ModelImpl() = default; + explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) { + meta_graph_ = schema::GetMetaGraph(model_buf); + } + virtual ~ModelImpl(); + lite::Primitive *GetOp(const std::string &name) const; + const schema::MetaGraph *meta_graph() const; + void FreeMetaGraph(); + int BuildOps(); + + protected: + lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim); + + protected: + const char *model_buf_; + size_t buf_size_; + const schema::MetaGraph *meta_graph_ = nullptr; + std::map ops_; +}; + +ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) { + if (model_buf == nullptr) { + MS_LOG(ERROR) << "The model buf is nullptr"; + return nullptr; + } + flatbuffers::Verifier verify((const uint8_t *)model_buf, size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return nullptr; + } + auto *inner_model_buf = new (std::nothrow) char[size]; + if (inner_model_buf == nullptr) { + MS_LOG(ERROR) << "new model buf fail."; + return nullptr; + } + memcpy(inner_model_buf, model_buf, size); + auto model = new (std::nothrow) ModelImpl(inner_model_buf, size); + if (model == nullptr) { + MS_LOG(ERROR) << "Create modelImpl failed"; + return nullptr; + } + auto ret = model->BuildOps(); + if (0 != ret) { + MS_LOG(ERROR) << "BuildOps failed"; + return nullptr; + } + return model; +} + +lite::Primitive *ModelImpl::GetOp(const std::string &name) const { + auto iter = ops_.find(name); + if (iter == ops_.end()) { + return nullptr; + } else { + return iter->second; + } +} + +ModelImpl::~ModelImpl() { + delete[](this->model_buf_); + for (auto iter : ops_) { + delete (iter.second); + } + ops_.clear(); +} + +void ModelImpl::FreeMetaGraph() { + delete[](this->model_buf_); + model_buf_ = nullptr; +} + +const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; } + +lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) { + MS_EXCEPTION_IF_NULL(src_prim); + auto op_type = src_prim->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(src_prim)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(src_prim)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(src_prim)); + case schema::PrimitiveType_DeConv2D: + return new lite::DeConv2D(const_cast(src_prim)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(src_prim)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(src_prim)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(src_prim)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(src_prim)); + case schema::PrimitiveType_BatchNorm: + return new lite::BatchNorm(const_cast(src_prim)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(src_prim)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(src_prim)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(src_prim)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(src_prim)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(src_prim)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(src_prim)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(src_prim)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(src_prim)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(src_prim)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(src_prim)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(src_prim)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(src_prim)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(src_prim)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(src_prim)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(src_prim)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(src_prim)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(src_prim)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(src_prim)); + case schema::PrimitiveType_Slice: + return new lite::Slice(const_cast(src_prim)); + case schema::PrimitiveType_Squeeze: + return new lite::Squeeze(const_cast(src_prim)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(src_prim)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(src_prim)); + case schema::PrimitiveType_Flatten: + return new lite::Flatten(const_cast(src_prim)); + case schema::PrimitiveType_Mean: + return new lite::Mean(const_cast(src_prim)); + case schema::PrimitiveType_Stack: + return new lite::Stack(const_cast(src_prim)); + case schema::PrimitiveType_Crop: + return new lite::Crop(const_cast(src_prim)); + case schema::PrimitiveType_SquaredDifference: + return new lite::SquaredDifference(const_cast(src_prim)); + case schema::PrimitiveType_AddN: + return new lite::AddN(const_cast(src_prim)); + case schema::PrimitiveType_Abs: + return new lite::Abs(const_cast(src_prim)); + case schema::PrimitiveType_Sin: + return new lite::Sin(const_cast(src_prim)); + case schema::PrimitiveType_Cos: + return new lite::Cos(const_cast(src_prim)); + case schema::PrimitiveType_Log: + return new lite::Log(const_cast(src_prim)); + case schema::PrimitiveType_Sqrt: + return new lite::Sqrt(const_cast(src_prim)); + case schema::PrimitiveType_Rsqrt: + return new lite::Rsqrt(const_cast(src_prim)); + case schema::PrimitiveType_Square: + return new lite::Square(const_cast(src_prim)); + case schema::PrimitiveType_Exp: + return new lite::Exp(const_cast(src_prim)); + case schema::PrimitiveType_Gather: + return new lite::Gather(const_cast(src_prim)); + case schema::PrimitiveType_GatherNd: + return new lite::GatherNd(const_cast(src_prim)); + case schema::PrimitiveType_LocalResponseNormalization: + return new lite::LocalResponseNormalization(const_cast(src_prim)); + case schema::PrimitiveType_Maximum: + return new lite::Maximum(const_cast(src_prim)); + case schema::PrimitiveType_Minimum: + return new lite::Minimum(const_cast(src_prim)); + case schema::PrimitiveType_Pad: + return new lite::Pad(const_cast(src_prim)); + case schema::PrimitiveType_StridedSlice: + return new lite::StridedSlice(const_cast(src_prim)); + case schema::PrimitiveType_Prelu: + return new lite::Prelu(const_cast(src_prim)); + case schema::PrimitiveType_CaffePReLU: + return new lite::CaffePReLU(const_cast(src_prim)); + case schema::PrimitiveType_Round: + return new lite::Round(const_cast(src_prim)); + case schema::PrimitiveType_Reverse: + return new lite::Reverse(const_cast(src_prim)); + case schema::PrimitiveType_ReverseSequence: + return new lite::ReverseSequence(const_cast(src_prim)); + case schema::PrimitiveType_LogicalAnd: + return new lite::LogicalAnd(const_cast(src_prim)); + case schema::PrimitiveType_LogicalOr: + return new lite::LogicalOr(const_cast(src_prim)); + case schema::PrimitiveType_LogicalNot: + return new lite::LogicalNot(const_cast(src_prim)); + case schema::PrimitiveType_FloorDiv: + return new lite::FloorDiv(const_cast(src_prim)); + case schema::PrimitiveType_FloorMod: + return new lite::FloorMod(const_cast(src_prim)); + case schema::PrimitiveType_Equal: + return new lite::Equal(const_cast(src_prim)); + case schema::PrimitiveType_NotEqual: + return new lite::NotEqual(const_cast(src_prim)); + case schema::PrimitiveType_Less: + return new lite::Less(const_cast(src_prim)); + case schema::PrimitiveType_LessEqual: + return new lite::LessEqual(const_cast(src_prim)); + case schema::PrimitiveType_Greater: + return new lite::Greater(const_cast(src_prim)); + case schema::PrimitiveType_GreaterEqual: + return new lite::GreaterEqual(const_cast(src_prim)); + case schema::PrimitiveType_Floor: + return new lite::Floor(const_cast(src_prim)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(src_prim)); + case schema::PrimitiveType_Split: + return new lite::Split(const_cast(src_prim)); + case schema::PrimitiveType_OneHot: + return new lite::OneHot(const_cast(src_prim)); + case schema::PrimitiveType_SpaceToDepth: + return new lite::SpaceToDepth(const_cast(src_prim)); + case schema::PrimitiveType_Tile: + return new lite::Tile(const_cast(src_prim)); + case schema::PrimitiveType_Resize: + return new lite::Resize(const_cast(src_prim)); + case schema::PrimitiveType_Unstack: + return new lite::Unstack(const_cast(src_prim)); + case schema::PrimitiveType_Unique: + return new lite::Unique(const_cast(src_prim)); + case schema::PrimitiveType_TopK: + return new lite::TopK(const_cast(src_prim)); + case schema::PrimitiveType_MatMul: + return new lite::MatMul(const_cast(src_prim)); + case schema::PrimitiveType_QuantDTypeCast: + return new lite::QuantDTypeCast(const_cast(src_prim)); + case schema::PrimitiveType_EmbeddingLookup: + return new lite::EmbeddingLookup(const_cast(src_prim)); + case schema::PrimitiveType_Elu: + return new lite::Elu(const_cast(src_prim)); + case schema::PrimitiveType_DeDepthwiseConv2D: + return new lite::DeconvDepthwiseConv2D(const_cast(src_prim)); + case schema::PrimitiveType_Shape: + return new lite::Shape(const_cast(src_prim)); + default: + break; + } + return nullptr; +} + +int ModelImpl::BuildOps() { + if (this->meta_graph_ == nullptr) { + MS_LOG(ERROR) << "mete_graph is nullptr"; + return -1; + } + MS_EXCEPTION_IF_NULL(meta_graph_->nodes()); + for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) { + auto cNode = meta_graph_->nodes()->GetAs(i); + auto name = cNode->name()->str(); + auto srcPrim = cNode->primitive(); + + this->ops_[name] = CopyPrimitive(srcPrim); + // flatbuffers::FlatBufferBuilder fbb(1024); + // schema::Conv2DBuilder conv2DBuilder(fbb); + // conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode()); + // conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut()); + // conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn()); + // conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH()); + // conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW()); + // conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH()); + // conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW()); + // conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH()); + // conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW()); + // conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp()); + // conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown()); + // conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft()); + // conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight()); + // conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format()); + // conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group()); + // conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType()); + // schema::PrimitiveBuilder primBuilder(fbb); + // primBuilder.add_value_type(srcPrim->value_type()); + // primBuilder.add_value(conv2DBuilder.Finish()); + // + // fbb.Finish(conv2DBuilder.Finish()); + // auto buf = fbb.GetBufferPointer(); + // auto conv2D = flatbuffers::GetRoot(buf); + // fbb.Clear(); + // + // return const_cast(opDef); + } + return 0; +} + Model *Model::Import(const char *model_buf, size_t size) { auto model = new Model(); if (model_buf == nullptr) { @@ -55,8 +349,4 @@ const schema::MetaGraph *Model::GetMetaGraph() const { return model_impl_->meta_graph(); } -ModelImpl *Model::model_impl() { - MS_EXCEPTION_IF_NULL(model_impl_); - return this->model_impl_; -} } // namespace mindspore::lite diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc deleted file mode 100644 index 7854161fab..0000000000 --- a/mindspore/lite/src/model_impl.cc +++ /dev/null @@ -1,297 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "src/model_impl.h" -#include "utils/log_adapter.h" - -namespace mindspore::lite { -ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) { - if (model_buf == nullptr) { - MS_LOG(ERROR) << "The model buf is nullptr"; - return nullptr; - } - flatbuffers::Verifier verify((const uint8_t *)model_buf, size); - if (!schema::VerifyMetaGraphBuffer(verify)) { - MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; - return nullptr; - } - auto *inner_model_buf = new (std::nothrow) char[size]; - if (inner_model_buf == nullptr) { - MS_LOG(ERROR) << "new model buf fail."; - return nullptr; - } - memcpy(inner_model_buf, model_buf, size); - auto model = new (std::nothrow) ModelImpl(inner_model_buf, size); - if (model == nullptr) { - MS_LOG(ERROR) << "Create modelImpl failed"; - return nullptr; - } - auto ret = model->BuildOps(); - if (0 != ret) { - MS_LOG(ERROR) << "BuildOps failed"; - return nullptr; - } - return model; -} - -lite::Primitive *ModelImpl::GetOp(const std::string &name) const { - auto iter = ops_.find(name); - if (iter == ops_.end()) { - return nullptr; - } else { - return iter->second; - } -} - -ModelImpl::~ModelImpl() { - delete[](this->model_buf_); - for (auto iter : ops_) { - delete (iter.second); - } - ops_.clear(); -} - -void ModelImpl::FreeMetaGraph() { - delete[](this->model_buf_); - model_buf_ = nullptr; -} - -const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; } - -lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) { - MS_EXCEPTION_IF_NULL(src_prim); - auto op_type = src_prim->value_type(); - switch (op_type) { - case schema::PrimitiveType_SoftMax: - return new lite::SoftMax(const_cast(src_prim)); - case schema::PrimitiveType_Activation: - return new lite::Activation(const_cast(src_prim)); - case schema::PrimitiveType_Conv2D: - return new lite::Conv2D(const_cast(src_prim)); - case schema::PrimitiveType_DeConv2D: - return new lite::DeConv2D(const_cast(src_prim)); - case schema::PrimitiveType_Reduce: - return new lite::Reduce(const_cast(src_prim)); - case schema::PrimitiveType_Pooling: - return new lite::Pooling(const_cast(src_prim)); - case schema::PrimitiveType_DepthwiseConv2D: - return new lite::DepthwiseConv2D(const_cast(src_prim)); - case schema::PrimitiveType_FusedBatchNorm: - return new lite::FusedBatchNorm(const_cast(src_prim)); - case schema::PrimitiveType_BatchNorm: - return new lite::BatchNorm(const_cast(src_prim)); - case schema::PrimitiveType_FullConnection: - return new lite::FullConnection(const_cast(src_prim)); - case schema::PrimitiveType_Power: - return new lite::Power(const_cast(src_prim)); - case schema::PrimitiveType_Range: - return new lite::Range(const_cast(src_prim)); - case schema::PrimitiveType_Mul: - return new lite::Mul(const_cast(src_prim)); - case schema::PrimitiveType_Add: - return new lite::Add(const_cast(src_prim)); - case schema::PrimitiveType_Sub: - return new lite::Sub(const_cast(src_prim)); - case schema::PrimitiveType_Div: - return new lite::Div(const_cast(src_prim)); - case schema::PrimitiveType_BiasAdd: - return new lite::BiasAdd(const_cast(src_prim)); - case schema::PrimitiveType_ExpandDims: - return new lite::ExpandDims(const_cast(src_prim)); - case schema::PrimitiveType_ArgMax: - return new lite::ArgMax(const_cast(src_prim)); - case schema::PrimitiveType_ArgMin: - return new lite::ArgMin(const_cast(src_prim)); - case schema::PrimitiveType_Cast: - return new lite::Cast(const_cast(src_prim)); - case schema::PrimitiveType_Reshape: - return new lite::Reshape(const_cast(src_prim)); - case schema::PrimitiveType_Scale: - return new lite::Scale(const_cast(src_prim)); - case schema::PrimitiveType_Eltwise: - return new lite::Eltwise(const_cast(src_prim)); - case schema::PrimitiveType_Concat: - return new lite::Concat(const_cast(src_prim)); - case schema::PrimitiveType_Fill: - return new lite::Fill(const_cast(src_prim)); - case schema::PrimitiveType_Transpose: - return new lite::Transpose(const_cast(src_prim)); - case schema::PrimitiveType_Slice: - return new lite::Slice(const_cast(src_prim)); - case schema::PrimitiveType_Squeeze: - return new lite::Squeeze(const_cast(src_prim)); - case schema::PrimitiveType_Nchw2Nhwc: - return new lite::Nchw2Nhwc(const_cast(src_prim)); - case schema::PrimitiveType_Nhwc2Nchw: - return new lite::Nhwc2Nchw(const_cast(src_prim)); - case schema::PrimitiveType_Flatten: - return new lite::Flatten(const_cast(src_prim)); - case schema::PrimitiveType_Mean: - return new lite::Mean(const_cast(src_prim)); - case schema::PrimitiveType_Stack: - return new lite::Stack(const_cast(src_prim)); - case schema::PrimitiveType_Crop: - return new lite::Crop(const_cast(src_prim)); - case schema::PrimitiveType_SquaredDifference: - return new lite::SquaredDifference(const_cast(src_prim)); - case schema::PrimitiveType_AddN: - return new lite::AddN(const_cast(src_prim)); - case schema::PrimitiveType_Abs: - return new lite::Abs(const_cast(src_prim)); - case schema::PrimitiveType_Sin: - return new lite::Sin(const_cast(src_prim)); - case schema::PrimitiveType_Cos: - return new lite::Cos(const_cast(src_prim)); - case schema::PrimitiveType_Log: - return new lite::Log(const_cast(src_prim)); - case schema::PrimitiveType_Sqrt: - return new lite::Sqrt(const_cast(src_prim)); - case schema::PrimitiveType_Rsqrt: - return new lite::Rsqrt(const_cast(src_prim)); - case schema::PrimitiveType_Square: - return new lite::Square(const_cast(src_prim)); - case schema::PrimitiveType_Exp: - return new lite::Exp(const_cast(src_prim)); - case schema::PrimitiveType_Gather: - return new lite::Gather(const_cast(src_prim)); - case schema::PrimitiveType_GatherNd: - return new lite::GatherNd(const_cast(src_prim)); - case schema::PrimitiveType_LocalResponseNormalization: - return new lite::LocalResponseNormalization(const_cast(src_prim)); - case schema::PrimitiveType_Maximum: - return new lite::Maximum(const_cast(src_prim)); - case schema::PrimitiveType_Minimum: - return new lite::Minimum(const_cast(src_prim)); - case schema::PrimitiveType_Pad: - return new lite::Pad(const_cast(src_prim)); - case schema::PrimitiveType_StridedSlice: - return new lite::StridedSlice(const_cast(src_prim)); - case schema::PrimitiveType_Prelu: - return new lite::Prelu(const_cast(src_prim)); - case schema::PrimitiveType_CaffePReLU: - return new lite::CaffePReLU(const_cast(src_prim)); - case schema::PrimitiveType_Round: - return new lite::Round(const_cast(src_prim)); - case schema::PrimitiveType_Reverse: - return new lite::Reverse(const_cast(src_prim)); - case schema::PrimitiveType_ReverseSequence: - return new lite::ReverseSequence(const_cast(src_prim)); - case schema::PrimitiveType_LogicalAnd: - return new lite::LogicalAnd(const_cast(src_prim)); - case schema::PrimitiveType_LogicalOr: - return new lite::LogicalOr(const_cast(src_prim)); - case schema::PrimitiveType_LogicalNot: - return new lite::LogicalNot(const_cast(src_prim)); - case schema::PrimitiveType_FloorDiv: - return new lite::FloorDiv(const_cast(src_prim)); - case schema::PrimitiveType_FloorMod: - return new lite::FloorMod(const_cast(src_prim)); - case schema::PrimitiveType_Equal: - return new lite::Equal(const_cast(src_prim)); - case schema::PrimitiveType_NotEqual: - return new lite::NotEqual(const_cast(src_prim)); - case schema::PrimitiveType_Less: - return new lite::Less(const_cast(src_prim)); - case schema::PrimitiveType_LessEqual: - return new lite::LessEqual(const_cast(src_prim)); - case schema::PrimitiveType_Greater: - return new lite::Greater(const_cast(src_prim)); - case schema::PrimitiveType_GreaterEqual: - return new lite::GreaterEqual(const_cast(src_prim)); - case schema::PrimitiveType_Floor: - return new lite::Floor(const_cast(src_prim)); - case schema::PrimitiveType_Ceil: - return new lite::Ceil(const_cast(src_prim)); - case schema::PrimitiveType_Split: - return new lite::Split(const_cast(src_prim)); - case schema::PrimitiveType_OneHot: - return new lite::OneHot(const_cast(src_prim)); - case schema::PrimitiveType_SpaceToDepth: - return new lite::SpaceToDepth(const_cast(src_prim)); - case schema::PrimitiveType_Tile: - return new lite::Tile(const_cast(src_prim)); - case schema::PrimitiveType_Resize: - return new lite::Resize(const_cast(src_prim)); - case schema::PrimitiveType_Unstack: - return new lite::Unstack(const_cast(src_prim)); - case schema::PrimitiveType_Unique: - return new lite::Unique(const_cast(src_prim)); - case schema::PrimitiveType_TopK: - return new lite::TopK(const_cast(src_prim)); - case schema::PrimitiveType_MatMul: - return new lite::MatMul(const_cast(src_prim)); - case schema::PrimitiveType_QuantDTypeCast: - return new lite::QuantDTypeCast(const_cast(src_prim)); - case schema::PrimitiveType_EmbeddingLookup: - return new lite::EmbeddingLookup(const_cast(src_prim)); - case schema::PrimitiveType_Elu: - return new lite::Elu(const_cast(src_prim)); - case schema::PrimitiveType_DeDepthwiseConv2D: - return new lite::DeconvDepthwiseConv2D(const_cast(src_prim)); - case schema::PrimitiveType_Shape: - return new lite::Shape(const_cast(src_prim)); - default: - break; - } - return nullptr; -} - -int ModelImpl::BuildOps() { - if (this->meta_graph_ == nullptr) { - MS_LOG(ERROR) << "mete_graph is nullptr"; - return -1; - } - MS_EXCEPTION_IF_NULL(meta_graph_->nodes()); - for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) { - auto cNode = meta_graph_->nodes()->GetAs(i); - auto name = cNode->name()->str(); - auto srcPrim = cNode->primitive(); - - this->ops_[name] = CopyPrimitive(srcPrim); - // flatbuffers::FlatBufferBuilder fbb(1024); - // schema::Conv2DBuilder conv2DBuilder(fbb); - // conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode()); - // conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut()); - // conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn()); - // conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH()); - // conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW()); - // conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH()); - // conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW()); - // conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH()); - // conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW()); - // conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp()); - // conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown()); - // conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft()); - // conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight()); - // conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format()); - // conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group()); - // conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType()); - // schema::PrimitiveBuilder primBuilder(fbb); - // primBuilder.add_value_type(srcPrim->value_type()); - // primBuilder.add_value(conv2DBuilder.Finish()); - // - // fbb.Finish(conv2DBuilder.Finish()); - // auto buf = fbb.GetBufferPointer(); - // auto conv2D = flatbuffers::GetRoot(buf); - // fbb.Clear(); - // - // return const_cast(opDef); - } - return 0; -} -} // namespace mindspore::lite diff --git a/mindspore/lite/src/model_impl.h b/mindspore/lite/src/model_impl.h deleted file mode 100644 index 82a74e1bf6..0000000000 --- a/mindspore/lite/src/model_impl.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_MODEL_IMPL_H_ -#define MINDSPORE_LITE_SRC_MODEL_IMPL_H_ - -#include -#include -#include -#include "schema/model_generated.h" -#include "src/ops/ops.h" - -namespace mindspore { -namespace lite { -class ModelImpl { - public: - static ModelImpl *Import(const char *model_buf, size_t size); - ModelImpl() = default; - explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) { - meta_graph_ = schema::GetMetaGraph(model_buf); - } - virtual ~ModelImpl(); - lite::Primitive *GetOp(const std::string &name) const; - const schema::MetaGraph *meta_graph() const; - void FreeMetaGraph(); - int BuildOps(); - - protected: - lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim); - - protected: - const char *model_buf_; - size_t buf_size_; - const schema::MetaGraph *meta_graph_ = nullptr; - std::map ops_; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_INCLUDE_MODEL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc index 1f05c73e86..6fbf69ddbd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc @@ -19,7 +19,7 @@ #include "src/runtime/kernel/arm/int8/argminmax_int8.h" #include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc index acc9ab8ef2..56ca65067a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc @@ -18,7 +18,7 @@ #include "src/runtime/kernel/arm/fp32/batch_to_space.h" #include "src/runtime/kernel/arm/int8/batch_to_space_int8.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/caffeprelu_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/caffeprelu_base.cc index 99736192bc..dd4f9ed296 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/caffeprelu_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/caffeprelu_base.cc @@ -16,7 +16,7 @@ #include "src/runtime/kernel/arm/base/caffeprelu_base.h" #include #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc index e28d88b771..756492767e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc @@ -19,7 +19,7 @@ #include "src/runtime/kernel/arm/fp32/concat.h" #include "src/runtime/kernel/arm/nnacl/fp32/concat.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index dddfe4e934..f1fada6ad1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -17,7 +17,7 @@ #include "src/runtime/kernel/arm/base/convolution_base.h" #include #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc index be27b1f9cc..0fc913bcfa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc @@ -18,7 +18,7 @@ #include "src/runtime/kernel/arm/int8/crop_int8.h" #include "src/runtime/kernel/arm/fp32/crop.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc index a5b284d0c2..3289b11a95 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc @@ -19,7 +19,7 @@ #include "src/runtime/kernel/arm/int8/depth_to_space_int8.h" #include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc index ecdf77cad4..41c13d0e0a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc @@ -17,7 +17,7 @@ #include "src/runtime/kernel/arm/int8/fullconnection_int8.h" #include "src/runtime/kernel/arm/fp32/fullconnection.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc index 4d0eb57420..ee3ed37599 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc @@ -16,7 +16,7 @@ #include "src/runtime/kernel/arm/base/matmul_base.h" #include "src/runtime/kernel/arm/fp32/matmul.h" #include "src/runtime/kernel/arm/int8/matmul_int8.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc index 2e071706a7..4a4933c95e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc @@ -17,7 +17,7 @@ #include "src/runtime/kernel/arm/fp32/pad.h" #include "src/runtime/kernel/arm/int8/pad_int8.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc index c1decb1082..84b44af038 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc @@ -18,7 +18,7 @@ #include "src/runtime/kernel/arm/int8/pooling_int8.h" #include "src/runtime/kernel/arm/fp32/pooling.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc index f590883576..4b232f6f3b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc @@ -17,7 +17,7 @@ #include #include "src/runtime/kernel/arm/int8/prelu_int8.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc index a90a2ab07f..bd733a58b5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc @@ -18,7 +18,7 @@ #include #include "src/runtime/kernel/arm/base/prior_box.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc index 1d9950f8b9..6553aa1a6c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -18,7 +18,7 @@ #include "src/runtime/kernel/arm/int8/reshape_int8.h" #include "src/runtime/kernel/arm/fp32/reshape.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc index 8e210aa086..a35be0765c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc @@ -20,7 +20,7 @@ #include "src/runtime/kernel/arm/fp32/softmax.h" #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc index 8768226fd3..b7786daeca 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc @@ -18,7 +18,7 @@ #include "src/runtime/kernel/arm/int8/split_int8.h" #include "src/runtime/kernel/arm/fp32/split.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc index 76dc57002a..3d285787da 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc @@ -17,7 +17,7 @@ #include #include "src/runtime/kernel/arm/int8/squeeze_int8.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "include/context.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc index 55bfa4d583..4343fb4ad1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc @@ -18,7 +18,7 @@ #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" #include "src/runtime/kernel/arm/fp16/common_fp16.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index 6eaf63f607..12fedfc17c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -22,7 +22,7 @@ #include "src/runtime/kernel/arm/nnacl/fp32/conv.h" #include "src/runtime/kernel/arm/nnacl/common_func.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc index 7003610e1d..cdcbda7199 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc @@ -17,7 +17,7 @@ #include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h" #include "src/runtime/kernel/arm/nnacl/common_func.h" #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h index bfff0298d7..75460a09de 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h @@ -20,7 +20,7 @@ #include #include "src/lite_kernel.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index 752a8440e9..cc073ef539 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -17,7 +17,7 @@ #include #include #include "schema/model_generated.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index a264e43915..06afe9896d 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -18,7 +18,7 @@ #include #include #include "include/errorcode.h" -#include "src/kernel_factory.h" +#include "src/kernel_registry.h" #include "src/common/graph_util.h" #include "src/common/utils.h" #if SUPPORT_GPU @@ -191,7 +191,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()}; if (context_->device_ctx_.type == DT_GPU) { desc.arch = kernel::KERNEL_ARCH::kGPU; - auto *kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); + auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); if (nullptr != kernel) { kernel->set_desc(desc); return kernel; @@ -203,7 +203,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector if ((context_->float16_priority && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) { // check if support fp16 kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; - kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); + kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); if (kernel != nullptr) { MS_LOG(DEBUG) << "Get fp16 op success."; desc.data_type = kNumberTypeFloat16; @@ -215,7 +215,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector if (data_type == kNumberTypeFloat16) { desc.data_type = kNumberTypeFloat32; } - kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); + kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); if (kernel != nullptr) { kernel->set_desc(desc); return kernel; diff --git a/mindspore/lite/src/train/base_ref_utils.cc b/mindspore/lite/src/train/base_ref_utils.cc deleted file mode 100644 index 61a39d38cb..0000000000 --- a/mindspore/lite/src/train/base_ref_utils.cc +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/train/base_ref_utils.h" -#include -#include -// #include "utils/base_ref_utils.h" -#include "include/ms_tensor.h" -#include "src/ir/tensor.h" - -namespace mindspore { -std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref) { - std::vector> msTensors; - if (utils::isa(base_ref)) { - auto ref_list = utils::cast(base_ref); - for (size_t i = 0; i < ref_list.size(); ++i) { - if (utils::isa(ref_list[i])) { - auto tensor_ptr = utils::cast>(ref_list[i]); - MS_EXCEPTION_IF_NULL(tensor_ptr); - auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); - msTensors.emplace_back(std::shared_ptr(tensor)); - } else { - MS_LOG(EXCEPTION) << "The output is not a tensor!"; - } - } - } else if (utils::isa(base_ref)) { - auto tensor_ptr = utils::cast>(base_ref); - MS_EXCEPTION_IF_NULL(tensor_ptr); - auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); - msTensors.emplace_back(std::shared_ptr(tensor)); - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } - return msTensors; -} - -std::vector>> TransformVectorRefToMultiTensor( - const VectorRef &vector_ref) { - std::vector>> multiTensor; - for (size_t i = 0; i < vector_ref.size(); ++i) { - auto tensors = TransformBaseRefToMSTensor(vector_ref[i]); - multiTensor.emplace_back(tensors); - } - return multiTensor; -} -} // namespace mindspore diff --git a/mindspore/lite/src/train/base_ref_utils.h b/mindspore/lite/src/train/base_ref_utils.h deleted file mode 100644 index 2d4620ead6..0000000000 --- a/mindspore/lite/src/train/base_ref_utils.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "utils/base_ref.h" -#include "include/ms_tensor.h" - -#ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ -#define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ -namespace mindspore { -std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref); - -std::vector>> TransformVectorRefToMultiTensor( - const VectorRef &vector_ref); -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ diff --git a/mindspore/lite/src/train/import.hpp b/mindspore/lite/src/train/import.hpp deleted file mode 100644 index 12a8bf1244..0000000000 --- a/mindspore/lite/src/train/import.hpp +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "src/common/anf_importer/import_from_meta_graph.h" -namespace mindspore::lite::train { -std::shared_ptr Import(const char *model_buf, size_t size) { - MS_EXCEPTION_IF_NULL(model_buf); - flatbuffers::Verifier verify((const uint8_t *) model_buf, size); - if (!schema::VerifyMetaGraphBuffer(verify)) { - MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; - return nullptr; - } - // todo hangangqiang remove when copy primitive done - if (size <= 0) { - MS_LOG(ERROR) << "size is zero"; - return nullptr; - } - auto *inner_buf = new char[size]; - memcpy(inner_buf, model_buf, size); - auto meta_graph = schema::GetMetaGraph(inner_buf); - auto model = std::make_shared(meta_graph); - auto ret = model->BuildOps(); - if (0 != ret) { - MS_LOG(ERROR) << "BuildOps failed"; - return nullptr; - } - MS_EXCEPTION_IF_NULL(meta_graph); - auto importer = new AnfImporterFromMetaGraph(model); - auto ret2 = importer->Import(); - if (0 != ret2) { - MS_LOG(ERROR) << "Import anf_graph from meta_graph failed, ret2: " << ret2; - return nullptr; - } - return model; -} -} // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/lite_kernel_runtime.cc b/mindspore/lite/src/train/lite_kernel_runtime.cc deleted file mode 100644 index 656e8b3bb7..0000000000 --- a/mindspore/lite/src/train/lite_kernel_runtime.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/train/lite_kernel_runtime.h" -#include "backend/session/anf_runtime_algorithm.h" -namespace mindspore::lite { -std::vector LiteInferKernelRuntime::GetGraphInputs(const std::vector &execution_order) { - std::vector graph_inputs; - for (const auto &cnode : execution_order) { - bool is_graph_inputs = true; - for (const auto &input : cnode->inputs()) { - if (input->isa()) { - is_graph_inputs = false; - break; - } - } - if (is_graph_inputs) { - graph_inputs.emplace_back(cnode); - } - } - return graph_inputs; -} - -void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, - const std::vector &inputs, - std::vector *outputs) { - MS_EXCEPTION_IF_NULL(graph); - auto execution_order = graph->execution_order(); - auto graph_inputs = GetGraphInputs(execution_order); - int input_count = 0; - for (const auto &graph_input : graph_inputs) { - auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(graph_input)); - for (auto input_tensor : liteKernel->GetInputs()) { - if (schema::NodeType_ValueNode == input_tensor->TensorType() && input_tensor->Data() != nullptr) { - continue; - } - input_tensor->SetData(inputs[input_count]->Data()); - input_count++; - } - } - - auto return_node = graph->get_return(); - for (const auto &return_input : return_node->inputs()) { - if (return_input->isa()) { - auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(return_input)); - auto output_tensors = liteKernel->GetOutputs(); - for (auto output_tensor : output_tensors) { - // tensor::TensorPtr output_tensor_ptr(output_tensor); - outputs->push_back(output_tensor); - } - } - } -} - -bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector &inputs, - std::vector *outputs) { - MS_EXCEPTION_IF_NULL(graph); - BindInputOutput(graph, inputs, *outputs); - std::vector kernels; - auto nodes = graph->execution_order(); - for (const auto &node : nodes) { - auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(node)); - if (liteKernel == nullptr) { - continue; - } - kernels.emplace_back(liteKernel); - } - kernel::LiteKernelUtil::TopologicalSortKernels(kernels); - Executor executor; - auto ret = executor.Run(inputs, *outputs, kernels); - return 0 == ret; -} -} // namespace mindspore::lite diff --git a/mindspore/lite/src/train/lite_kernel_runtime.h b/mindspore/lite/src/train/lite_kernel_runtime.h deleted file mode 100644 index c5ae2d04d9..0000000000 --- a/mindspore/lite/src/train/lite_kernel_runtime.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ -#define MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ - -#include -#include -#include -#include -#include "src/runtime/allocator.h" -#include "src/executor.h" -// #include "runtime/device/kernel_runtime.h" -#include "runtime/device/device_address.h" -#include "src/lite_kernel.h" -#include "backend/session/kernel_graph.h" -namespace mindspore::lite { -class LiteInferKernelRuntime { - public: - LiteInferKernelRuntime() = default; - ~LiteInferKernelRuntime() = default; - - bool Run(session::KernelGraph *graph, const std::vector &inputs, - std::vector *outputs); - - void AssignKernelAddress(session::KernelGraph *graph) {} - - protected: - void BindInputOutput(const session::KernelGraph *graph, const std::vector &inputs, - std::vector *outputs); - - std::vector GetGraphInputs(const std::vector &execution_order); -}; - -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ diff --git a/mindspore/lite/src/train/model_impl.cc b/mindspore/lite/src/train/model_impl.cc deleted file mode 100644 index 84794dfaaa..0000000000 --- a/mindspore/lite/src/train/model_impl.cc +++ /dev/null @@ -1,145 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "src/train/model_impl.h" -#include "ir/func_graph.h" -#include "schema/model_generated.h" -#include "src/common/anf_importer/import_from_meta_graph.h" - -namespace mindspore::lite::train { - -std::shared_ptr ModelImpl::Import(const char *model_buf, size_t size) { - MS_EXCEPTION_IF_NULL(model_buf); - flatbuffers::Verifier verify((const uint8_t *)model_buf, size); - if (!schema::VerifyMetaGraphBuffer(verify)) { - MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; - return nullptr; - } - // todo hangangqiang remove when copy primitive done - auto *inner_buf = new char[size]; - memcpy(inner_buf, model_buf, size); - auto meta_graph = schema::GetMetaGraph(inner_buf); - auto func_graph_model = std::make_shared(meta_graph); - auto ret = func_graph_model->BuildOps(); - if (0 != ret) { - MS_LOG(ERROR) << "BuildOps failed"; - return nullptr; - } - AnfImporterFromMetaGraph anfImporter(func_graph_model); - anfImporter.Import(); - return func_graph_model; -} - -const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { - auto iter = ops.find(name); - if (iter == ops.end()) { - return nullptr; - } else { - return iter->second; - } -} - -void ModelImpl::FreeMetaGraph() { delete this->meta_graph; } - -const schema::MetaGraph *ModelImpl::GetMetaGraph() const { return this->meta_graph; } - -lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { - MS_EXCEPTION_IF_NULL(srcPrim); - auto op_type = srcPrim->value_type(); - switch (op_type) { - case schema::PrimitiveType_SoftMax: - return new lite::SoftMax(const_cast(srcPrim)); - case schema::PrimitiveType_Activation: - return new lite::Activation(const_cast(srcPrim)); - case schema::PrimitiveType_Conv2D: - return new lite::Conv2D(const_cast(srcPrim)); - case schema::PrimitiveType_Reduce: - return new lite::Reduce(const_cast(srcPrim)); - case schema::PrimitiveType_Pooling: - return new lite::Pooling(const_cast(srcPrim)); - case schema::PrimitiveType_DepthwiseConv2D: - return new lite::DepthwiseConv2D(const_cast(srcPrim)); - case schema::PrimitiveType_FusedBatchNorm: - return new lite::FusedBatchNorm(const_cast(srcPrim)); - case schema::PrimitiveType_CaffeBatchNorm: - return new lite::CaffeBatchNorm(const_cast(srcPrim)); - case schema::PrimitiveType_FullConnection: - return new lite::FullConnection(const_cast(srcPrim)); - case schema::PrimitiveType_Power: - return new lite::Power(const_cast(srcPrim)); - case schema::PrimitiveType_Range: - return new lite::Range(const_cast(srcPrim)); - case schema::PrimitiveType_Mul: - return new lite::Mul(const_cast(srcPrim)); - case schema::PrimitiveType_Add: - return new lite::Add(const_cast(srcPrim)); - case schema::PrimitiveType_Sub: - return new lite::Sub(const_cast(srcPrim)); - case schema::PrimitiveType_Div: - return new lite::Div(const_cast(srcPrim)); - case schema::PrimitiveType_BiasAdd: - return new lite::BiasAdd(const_cast(srcPrim)); - case schema::PrimitiveType_ExpandDims: - return new lite::ExpandDims(const_cast(srcPrim)); - case schema::PrimitiveType_ArgMax: - return new lite::ArgMax(const_cast(srcPrim)); - case schema::PrimitiveType_ArgMin: - return new lite::ArgMin(const_cast(srcPrim)); - case schema::PrimitiveType_Cast: - return new lite::Cast(const_cast(srcPrim)); - case schema::PrimitiveType_Reshape: - return new lite::Reshape(const_cast(srcPrim)); - case schema::PrimitiveType_Scale: - return new lite::Scale(const_cast(srcPrim)); - case schema::PrimitiveType_Eltwise: - return new lite::Eltwise(const_cast(srcPrim)); - case schema::PrimitiveType_Ceil: - return new lite::Ceil(const_cast(srcPrim)); - case schema::PrimitiveType_Concat: - return new lite::Concat(const_cast(srcPrim)); - case schema::PrimitiveType_Fill: - return new lite::Fill(const_cast(srcPrim)); - case schema::PrimitiveType_Transpose: - return new lite::Transpose(const_cast(srcPrim)); - case schema::PrimitiveType_Slice: - return new lite::Slice(const_cast(srcPrim)); - case schema::PrimitiveType_Nchw2Nhwc: - return new lite::Nchw2Nhwc(const_cast(srcPrim)); - case schema::PrimitiveType_Nhwc2Nchw: - return new lite::Nhwc2Nchw(const_cast(srcPrim)); - case schema::PrimitiveType_MatMul: - return new lite::MatMul(const_cast(srcPrim)); - default: - break; - } - return nullptr; -} - -int ModelImpl::BuildOps() { - if (this->meta_graph == nullptr) { - MS_LOG(ERROR) << "mete_graph is nullptr"; - return -1; - } - for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { - auto cNode = meta_graph->nodes()->GetAs(i); - auto name = cNode->name()->str(); - auto srcPrim = cNode->primitive(); - this->ops[name] = CopyPrimitive(srcPrim); - } - return 0; -} -} // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/model_impl.h b/mindspore/lite/src/train/model_impl.h deleted file mode 100644 index d35956a855..0000000000 --- a/mindspore/lite/src/train/model_impl.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ -#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ - -#include -#include -#include -#include -#include "schema/model_generated.h" -#include "src/ops/ops.h" -#include "ir/func_graph.h" - -namespace mindspore::lite { -namespace train { -class ModelImpl : public FuncGraph { - public: - static std::shared_ptr Import(const char *model_buf, size_t size); // { return NULL; }; - ModelImpl() = default; - explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} - ~ModelImpl() override = default; - const lite::Primitive *GetOp(const std::string &name) const; - const schema::MetaGraph *GetMetaGraph() const; - void FreeMetaGraph(); - int BuildOps(); - - void AddCNodeInputOutput(std::string name, const std::vector &input, const std::vector &output) { - std::vector *tuple = new std::vector[2]; - tuple[0] = input; - tuple[1] = output; - connectivity_[name] = tuple; - } - std::vector *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; } - void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; } - AnfNodePtr GetAnfNode(int id) { return tensors_[id]; } - - protected: - lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); - - protected: - const schema::MetaGraph *meta_graph = nullptr; - std::map tensors_; - std::map *> connectivity_; - std::map ops; -}; -} // namespace train -using ModelImpl = mindspore::lite::train::ModelImpl; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ diff --git a/mindspore/lite/src/train/train_anf_session.cc b/mindspore/lite/src/train/train_anf_session.cc deleted file mode 100644 index 9e7fb5b506..0000000000 --- a/mindspore/lite/src/train/train_anf_session.cc +++ /dev/null @@ -1,253 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "src/train/train_anf_session.h" -#include "include/context.h" -#include "mindspore/ccsrc/runtime/device/kernel_info.h" -#include "mindspore/lite/src/train/train_session.h" -#include "mindspore/lite/src/kernel_factory.h" -#include "mindspore/lite/src/param_value_lite.h" -#include "common/utils.h" -#include "mindspore/lite/src/ops/ops.h" -#include "ir/anf.h" -#include "mindspore/lite/src/ir/tensor.h" -#include "abstract/abstract_value.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "src/ir/primitive_value.h" -#include "src/train/model_impl.h" - -namespace mindspore { -namespace session { -static std::vector GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { - auto nodeAbstract = anfNodePtr->abstract(); - if (nodeAbstract != nullptr) { - auto shape = nodeAbstract->GetShapeTrack(); - if (!shape->isa()) { - MS_LOG(EXCEPTION) << "Not a Shape"; - return {}; - } - auto dims = dyn_cast(shape)->shape(); - return dims; - } else { - MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; - return {}; - } -} - -static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { - auto nodeAbstract = anfNodePtr->abstract(); - if (nodeAbstract != nullptr) { - return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode - } else { - MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; - return schema::Format_NHWC; - } -} - -static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { - auto nodeAbstract = anfNodePtr->abstract(); - if (nodeAbstract != nullptr) { - return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); - } else { - MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; - return TypeId::kTypeUnknown; - } -} - -void TrainANFSession::Init(lite::Context *context) { - MS_EXCEPTION_IF_NULL(context); - this->context_ = std::make_shared(context->thread_num_, context->allocator, context->device_ctx_); -} - -lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { - lite::tensor::Tensor *out_tensor = tensors_[anf_node]; - if (out_tensor == NULL) { - out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), - GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); - tensors_[anf_node] = out_tensor; - } - return out_tensor; -} - -int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { - auto return_node = kernel_graph->get_return(); - auto node_list = TopoSort(return_node); - auto model_imp = std::dynamic_pointer_cast(func_graph_); - for (auto &node : node_list) { - if (!node->isa()) { - continue; - } - KernelRelation kernel_relation; - auto cnode = node->cast(); - kernel_relation.node_full_name = cnode->fullname_with_scope(); - kernel_relation.cnode = cnode; - std::vector *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); - if (cnode_io_indices == NULL) { - MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); - } else { - for (int i = 0; i < cnode_io_indices[1].size(); i++) { - AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); - kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); - } - } - lite::tensor::Tensor *tensor_ptr = nullptr; - for (size_t index = 1; index < cnode->inputs().size(); ++index) { - if (cnode->input(index)->isa()) { - auto input_cnode = cnode->input(index)->cast(); - auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; - // todo not support multi-outputs kernel sudo as spilt - tensor_ptr = input_kernel_relation.output_tensor.front(); - } else if (cnode->input(index)->isa()) { - auto input_parameter = cnode->input(index)->cast(); - auto para = input_parameter->default_param(); - auto param_value = std::dynamic_pointer_cast(para); - // auto dims = param_value->tensor_shape(); - // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); - tensor_ptr = GetTensorForAnfNode(cnode->input(index)); - if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { - tensor_ptr->SetData(param_value->tensor_addr()); - } - } else if (cnode->input(index)->isa()) { - auto input_valuenode = cnode->input(index)->cast(); - // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), - // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); - tensor_ptr = GetTensorForAnfNode(input_valuenode); - // todo(yankai) - } else { - MS_ASSERT(false); - } - kernel_relation.input_tensor.push_back(tensor_ptr); - } - kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; - } - return 0; -} - -GraphId TrainANFSession::graph_sum_ = 0; - -KernelGraphPtr TrainANFSession::NewKernelGraph() { - auto graph = std::make_shared(); - graph->set_graph_id(graph_sum_); - graphs_[graph_sum_++] = graph; - return graph; -} - -std::shared_ptr TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto graph = NewKernelGraph(); - graph->set_return(func_graph->get_return()); - auto node_list = TopoSort(func_graph->get_return()); - std::vector cnode_order; - for (const auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cn_node = node->cast(); - cnode_order.push_back(cn_node); - } - } - graph->set_execution_order(cnode_order); - return graph; -} -GraphId TrainANFSession::CompileGraph(NotNull func_graph) { - auto graph = ConstructKernelGraph(func_graph); - func_graph_ = func_graph; - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Set kernel info"; - SetKernelInfo(graph.get()); - - (void)BuildKernelInputAndOutputFromFuncGraph(graph); - MS_LOG(INFO) << "Build kernel"; - auto ret = BuildKernel(graph.get()); - if (0 != ret) { - MS_LOG(EXCEPTION) << "BuildKernel failed"; - } - - // return the graph id to backend - auto graph_id = graph->graph_id(); - graphs_[graph_id] = graph; - MS_LOG(INFO) << "Compile graph " << graph_id << " success"; - return graph_id; -} - -void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, - std::vector *outputs) { - auto &kernel_graph = graphs_[graph_id]; - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_LOG(INFO) << "Bind input output address"; - // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run - // auto execution_order = kernel_graph->execution_order(); - // Todo : hangangqiang - // Reorder(&execution_order); - // kernel_graph->set_execution_order(execution_order); - MS_LOG(INFO) << "Run graph start"; - auto ret = runtime_.Run(kernel_graph.get(), (std::vector &)inputs, *outputs); - if (!ret) { - MS_LOG(EXCEPTION) << "Run graph failed"; - } - MS_LOG(INFO) << "Run graph end"; -} - -void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto &kernel_nodes = kernel_graph->execution_order(); - for (const auto &kernel_node : kernel_nodes) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto kernel_info = std::make_shared(); - kernel_node->set_kernel_info(kernel_info); - } -} - -int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { - std::string kernel_name = iter->first; - KernelRelation anf_register = iter->second; - MS_EXCEPTION_IF_NULL(anf_register.cnode); - if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { - continue; - } - auto value_node_prim = anf_register.cnode->input(0); - MS_EXCEPTION_IF_NULL(value_node_prim); - auto prim = GetValueNode>(value_node_prim); - MS_EXCEPTION_IF_NULL(prim); - auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); - MS_EXCEPTION_IF_NULL(node_primitive); - auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); - if (0 != ret) { - MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; - return ret; - } - kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; - - auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, - node_primitive, context_.get(), desc); - if (nullptr == kernel) { - MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; - return -1; - } - std::shared_ptr kernel_mod(kernel); - kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); - - // kernel->train(); - auto kernel_info = dynamic_cast(anf_register.cnode->kernel_info()); - MS_EXCEPTION_IF_NULL(kernel_info); - kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method - } - return 0; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/lite/src/train/train_anf_session.h b/mindspore/lite/src/train/train_anf_session.h deleted file mode 100644 index fc495fd49e..0000000000 --- a/mindspore/lite/src/train/train_anf_session.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ -#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ -#include -#include -#include -#include -#include -#include "include/context.h" -#include "backend/session/session_basic.h" -#include "backend/session/kernel_graph.h" -#include "mindspore/lite/src/train/lite_kernel_runtime.h" -// #include "backend/session/session_factory.h" -namespace mindspore { -namespace lite::tensor { -class Tensor; -} -namespace session { -struct KernelRelation { - std::string node_full_name; - std::vector input_tensor; - std::vector output_tensor; - CNodePtr cnode; -}; - -class TrainANFSession { - public: - explicit TrainANFSession(lite::Context *context) { Init(context); } - ~TrainANFSession() = default; - - GraphId CompileGraph(NotNull func_graph); - - void RunGraph(const GraphId &graph_id, const std::vector &inputs, - std::vector *outputs); - - // void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override - // {}; - protected: - void Init(lite::Context *context); - std::shared_ptr context_ = nullptr; - std::unordered_map> graphs_; - static GraphId graph_sum_; - KernelGraphPtr NewKernelGraph(); - - private: - // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - // GraphId CompileGraph(const char *model_buf, size_t size); - std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); - int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); - - lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node); - - void SetKernelInfo(const KernelGraph *kernel_graph); - int BuildKernel(const KernelGraph *kernel_graph); - lite::LiteInferKernelRuntime runtime_; - std::map kernel_relation_infos_; - FuncGraphPtr func_graph_ = NULL; - std::map tensors_; -}; -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc deleted file mode 100644 index d9d8ddc346..0000000000 --- a/mindspore/lite/src/train/train_session.cc +++ /dev/null @@ -1,267 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "include/context.h" -#include "mindspore/ccsrc/runtime/device/kernel_info.h" -#include "mindspore/lite/src/train/train_session.h" -#include "mindspore/lite/src/kernel_factory.h" -#include "mindspore/lite/src/param_value_lite.h" -#include "utils/ms_utils.h" -#include "mindspore/lite/src/ops/ops.h" -#include "ir/anf.h" -#include "mindspore/lite/src/ir/tensor.h" -#include "abstract/abstract_value.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "src/ir/primitive_value.h" -#include "src/train/model_impl.h" - -namespace mindspore { -namespace session { -static std::vector GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { - auto nodeAbstract = anfNodePtr->abstract(); - if (nodeAbstract != nullptr) { - auto shape = nodeAbstract->GetShapeTrack(); - if (!shape->isa()) { - MS_LOG(EXCEPTION) << "Not a Shape"; - return {}; - } - auto dims = dyn_cast(shape)->shape(); - return dims; - } else { - MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; - return {}; - } -} - -static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { - auto nodeAbstract = anfNodePtr->abstract(); - if (nodeAbstract != nullptr) { - return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode - } else { - MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; - return schema::Format_NHWC; - } -} - -static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { - auto nodeAbstract = anfNodePtr->abstract(); - if (nodeAbstract != nullptr) { - return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); - } else { - MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; - return TypeId::kTypeUnknown; - } -} - -void TrainSession::Init(lite::Context *context) { - MS_EXCEPTION_IF_NULL(context); - this->context_ = std::make_shared(context->thread_num_, context->allocator, context->device_ctx_); -} - -lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { - lite::tensor::Tensor *out_tensor = tensors_[anf_node]; - if (out_tensor == NULL) { - out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), - GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); - tensors_[anf_node] = out_tensor; - } - return out_tensor; -} - -int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { - auto return_node = kernel_graph->get_return(); - auto node_list = TopoSort(return_node); - auto model_imp = std::dynamic_pointer_cast(func_graph_); - for (auto &node : node_list) { - if (!node->isa()) { - continue; - } - KernelRelation kernel_relation; - auto cnode = node->cast(); - kernel_relation.node_full_name = cnode->fullname_with_scope(); - kernel_relation.cnode = cnode; - std::vector *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); - if (cnode_io_indices == NULL) { - MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); - } else { - for (int i = 0; i < cnode_io_indices[1].size(); i++) { - AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); - kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); - } - } - lite::tensor::Tensor *tensor_ptr = nullptr; - for (size_t index = 1; index < cnode->inputs().size(); ++index) { - if (cnode->input(index)->isa()) { - auto input_cnode = cnode->input(index)->cast(); - auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; - // todo not support multi-outputs kernel sudo as spilt - tensor_ptr = input_kernel_relation.output_tensor.front(); - } else if (cnode->input(index)->isa()) { - auto input_parameter = cnode->input(index)->cast(); - auto para = input_parameter->default_param(); - auto param_value = std::dynamic_pointer_cast(para); - // auto dims = param_value->tensor_shape(); - // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); - tensor_ptr = GetTensorForAnfNode(cnode->input(index)); - if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { - tensor_ptr->SetData(param_value->tensor_addr()); - } - } else if (cnode->input(index)->isa()) { - auto input_valuenode = cnode->input(index)->cast(); - // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), - // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); - tensor_ptr = GetTensorForAnfNode(input_valuenode); - // todo(yankai) - } else { - MS_ASSERT(false); - } - kernel_relation.input_tensor.push_back(tensor_ptr); - } - kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; - } - return 0; -} -#if 0 -GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - auto graph_id = graph_sum_; - auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); - MS_EXCEPTION_IF_NULL(graph); - - BuildKernel(graph.get()); - MS_LOG(INFO) << "Assign kernel address"; - runtime_.AssignKernelAddress(graph.get()); - return graph_id; -} - -GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } -#else -GraphId TrainSession::graph_sum_ = 0; - -KernelGraphPtr TrainSession::NewKernelGraph() { - auto graph = std::make_shared(); - graph->set_graph_id(graph_sum_); - graphs_[graph_sum_++] = graph; - return graph; -} - -#endif - -std::shared_ptr TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto graph = NewKernelGraph(); - graph->set_return(func_graph->get_return()); - auto node_list = TopoSort(func_graph->get_return()); - std::vector cnode_order; - for (const auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cn_node = node->cast(); - cnode_order.push_back(cn_node); - } - } - graph->set_execution_order(cnode_order); - return graph; -} -GraphId TrainSession::CompileGraph(NotNull func_graph) { - auto graph = ConstructKernelGraph(func_graph); - func_graph_ = func_graph; - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Set kernel info"; - SetKernelInfo(graph.get()); - - (void)BuildKernelInputAndOutputFromFuncGraph(graph); - MS_LOG(INFO) << "Build kernel"; - auto ret = BuildKernel(graph.get()); - if (0 != ret) { - MS_LOG(EXCEPTION) << "BuildKernel failed"; - } - - // return the graph id to backend - auto graph_id = graph->graph_id(); - graphs_[graph_id] = graph; - MS_LOG(INFO) << "Compile graph " << graph_id << " success"; - return graph_id; -} - -void TrainSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, - std::vector *outputs) { - auto &kernel_graph = graphs_[graph_id]; - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_LOG(INFO) << "Bind input output address"; - // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run - // auto execution_order = kernel_graph->execution_order(); - // Todo : hangangqiang - // Reorder(&execution_order); - // kernel_graph->set_execution_order(execution_order); - MS_LOG(INFO) << "Run graph start"; - auto ret = runtime_.Run(kernel_graph.get(), (std::vector &)inputs, outputs); - if (!ret) { - MS_LOG(EXCEPTION) << "Run graph failed"; - } - MS_LOG(INFO) << "Run graph end"; -} - -void TrainSession::SetKernelInfo(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto &kernel_nodes = kernel_graph->execution_order(); - for (const auto &kernel_node : kernel_nodes) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto kernel_info = std::make_shared(); - kernel_node->set_kernel_info(kernel_info); - } -} - -int TrainSession::BuildKernel(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { - std::string kernel_name = iter->first; - KernelRelation anf_register = iter->second; - MS_EXCEPTION_IF_NULL(anf_register.cnode); - if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { - continue; - } - auto value_node_prim = anf_register.cnode->input(0); - MS_EXCEPTION_IF_NULL(value_node_prim); - auto prim = GetValueNode>(value_node_prim); - MS_EXCEPTION_IF_NULL(prim); - auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); - MS_EXCEPTION_IF_NULL(node_primitive); - auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); - if (0 != ret) { - MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; - return ret; - } - kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; - - auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, - node_primitive, context_.get(), desc); - if (nullptr == kernel) { - MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; - return -1; - } - std::shared_ptr kernel_mod(kernel); - kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); - - // kernel->train(); - auto kernel_info = dynamic_cast(anf_register.cnode->kernel_info()); - MS_EXCEPTION_IF_NULL(kernel_info); - kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method - } - return 0; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h deleted file mode 100644 index ff712ffd83..0000000000 --- a/mindspore/lite/src/train/train_session.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ -#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ -#include -#include -#include -#include -#include -#include "include/context.h" -#include "backend/session/session_basic.h" -#include "backend/session/kernel_graph.h" -#include "mindspore/lite/src/train/lite_kernel_runtime.h" -// #include "backend/session/session_factory.h" -namespace mindspore { -namespace lite::tensor { -class Tensor; -} -namespace session { -struct KernelRelation { - std::string node_full_name; - std::vector input_tensor; - std::vector output_tensor; - CNodePtr cnode; -}; - -class TrainSession { - public: - explicit TrainSession(lite::Context * context) { Init(context); } - ~TrainSession() = default; - - GraphId CompileGraph(NotNull func_graph); - - void RunGraph(const GraphId &graph_id, const std::vector &inputs, - std::vector *outputs); - - // void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override - // {}; - protected: - void Init(lite::Context *context); - std::shared_ptr context_ = nullptr; - std::unordered_map> graphs_; - static GraphId graph_sum_; - KernelGraphPtr NewKernelGraph(); - - private: - // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - // GraphId CompileGraph(const char *model_buf, size_t size); - std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); - int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); - - lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node); - - void SetKernelInfo(const KernelGraph *kernel_graph); - int BuildKernel(const KernelGraph *kernel_graph); - lite::LiteInferKernelRuntime runtime_; - std::map kernel_relation_infos_; - FuncGraphPtr func_graph_ = NULL; - std::map tensors_; -}; -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index b6b961868d..a58524abb6 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -175,12 +175,10 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/ir/primitive_t_value.cc ${LITE_DIR}/src/context.cc ${LITE_DIR}/src/executor.cc - ${LITE_DIR}/src/kernel_factory.cc ${LITE_DIR}/src/kernel_registry.cc ${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_session.cc ${LITE_DIR}/src/model.cc - ${LITE_DIR}/src/model_impl.cc ${LITE_DIR}/src/populate_parameter.cc ${LITE_DIR}/src/scheduler.cc ${LITE_DIR}/src/common/graph_util.cc @@ -265,7 +263,7 @@ if (SUPPORT_TRAIN) # ${LITE_DIR}/src/ir/primitive_value.cc # ${LITE_DIR}/src/train/lite_kernel_runtime.cc # ${LITE_DIR}/src/train/train_session.cc - # ${LITE_DIR}/src/train/model_impl.cc + # ${LITE_DIR}/src/train/model.cc ${LITE_DIR}/src/lite_session.cc # temporary ) else() diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 0cd2036cec..8197881445 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -20,15 +20,15 @@ #include #include "schema/inner/model_generated.h" #include "tools/optimizer/common/gllo_utils.h" -#include "src/kernel_factory.h" #include "tools/anf_exporter/anf_exporter.h" +#include "src/kernel_registry.h" #include "src/scheduler.h" #include "include/context.h" #include "src/lite_session.h" #include "src/ir/primitive_t_value.h" #include "src/populate_parameter.h" -using mindspore::lite::KernelFactory; +using mindspore::lite::KernelRegistry; using mindspore::lite::tensor::Tensor; using mindspore::lite::PrimitiveTValue; namespace mindspore::opt {