|
|
|
/**
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#ifndef GE_MODEL_GE_MODEL_H_
|
|
|
|
#define GE_MODEL_GE_MODEL_H_
|
|
|
|
|
|
|
|
#include <securec.h>
|
|
|
|
#include <map>
|
|
|
|
#include <memory>
|
|
|
|
#include <string>
|
|
|
|
#include "common/tbe_kernel_store.h"
|
|
|
|
#include "common/cust_aicpu_kernel_store.h"
|
|
|
|
#include "framework/common/debug/log.h"
|
|
|
|
#include "framework/common/fmk_error_codes.h"
|
|
|
|
#include "graph/buffer.h"
|
|
|
|
#include "graph/graph.h"
|
|
|
|
#include "proto/task.pb.h"
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
const uint32_t INVALID_MODEL_ID = 0xFFFFFFFFUL;
|
|
|
|
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder {
|
|
|
|
public:
|
|
|
|
GeModel();
|
|
|
|
~GeModel() = default;
|
|
|
|
GeModel(const GeModel &other) = delete;
|
|
|
|
GeModel &operator=(const GeModel &other) = delete;
|
|
|
|
|
|
|
|
const Graph &GetGraph() const;
|
|
|
|
std::shared_ptr<domi::ModelTaskDef> GetModelTaskDefPtr() const;
|
|
|
|
const TBEKernelStore &GetTBEKernelStore() const;
|
|
|
|
const CustAICPUKernelStore &GetCustAICPUKernelStore() const;
|
|
|
|
Buffer GetWeight() const;
|
|
|
|
|
|
|
|
std::string GetName() const;
|
|
|
|
uint32_t GetVersion() const;
|
|
|
|
std::string GetPlatformVersion() const;
|
|
|
|
uint8_t GetPlatformType() const;
|
|
|
|
|
|
|
|
void SetGraph(const Graph &graph);
|
|
|
|
void SetModelTaskDef(const std::shared_ptr<domi::ModelTaskDef> &task);
|
|
|
|
void SetTBEKernelStore(const TBEKernelStore &tbe_kernal_store);
|
|
|
|
void SetCustAICPUKernelStore(const CustAICPUKernelStore &cust_aicpu_kernal_store);
|
|
|
|
void SetWeight(const Buffer &weights_buffer);
|
|
|
|
|
|
|
|
void SetName(const std::string &name);
|
|
|
|
void SetVersion(uint32_t version);
|
|
|
|
void SetPlatformVersion(const std::string &platform_version);
|
|
|
|
void SetPlatformType(uint8_t platform_type);
|
|
|
|
|
|
|
|
void SetAttr(const ProtoAttrMapHelper &attrs);
|
|
|
|
|
|
|
|
ProtoAttrMapHelper MutableAttrMap() override;
|
|
|
|
|
|
|
|
using AttrHolder::SetAttr;
|
|
|
|
using AttrHolder::GetAllAttrs;
|
|
|
|
using AttrHolder::GetAllAttrNames;
|
|
|
|
|
|
|
|
void SetModelId(uint32_t model_id) { model_id_ = model_id; }
|
|
|
|
uint32_t GetModelId() const { return model_id_; }
|
|
|
|
|
|
|
|
protected:
|
|
|
|
ConstProtoAttrMapHelper GetAttrMap() const override;
|
|
|
|
|
|
|
|
private:
|
|
|
|
void Init();
|
|
|
|
|
|
|
|
ProtoAttrMapHelper attrs_;
|
|
|
|
|
|
|
|
Graph graph_;
|
|
|
|
std::shared_ptr<domi::ModelTaskDef> task_;
|
|
|
|
TBEKernelStore tbe_kernal_store_;
|
|
|
|
CustAICPUKernelStore cust_aicpu_kernal_store_;
|
|
|
|
Buffer weights_buffer_;
|
|
|
|
|
|
|
|
std::string name_;
|
|
|
|
uint32_t version_ = {0};
|
|
|
|
std::string platform_version_;
|
|
|
|
uint8_t platform_type_ = {0};
|
|
|
|
uint32_t model_id_ = INVALID_MODEL_ID;
|
|
|
|
};
|
|
|
|
} // namespace ge
|
|
|
|
using GeModelPtr = std::shared_ptr<ge::GeModel>;
|
|
|
|
#endif // GE_MODEL_GE_MODEL_H_
|