You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
graphengine/ge/model/ge_model.cc

89 lines
3.2 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "model/ge_model.h"
#include <utility>
#include "common/debug/log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"
namespace ge {
// Writes the initial values of the model attributes and resets the version.
//
// Seven attributes are set on this model: six integer attributes
// (ATTR_MODEL_MEMORY_SIZE, ATTR_MODEL_P2P_MEMORY_SIZE, ATTR_MODEL_STREAM_NUM,
// ATTR_MODEL_EVENT_NUM, ATTR_MODEL_LABEL_NUM, ATTR_MODEL_WEIGHT_SIZE) are set
// to 0, and ATTR_MODEL_TARGET_TYPE is set to TARGET_TYPE_MINI. The results of
// the AttrUtils setters are explicitly discarded.
void GeModel::Init() {
  version_ = 0;

  // Integer attributes, all starting at zero.
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0));

  // String attribute naming the target type.
  static_cast<void>(AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI));
}
// Default constructor: calls InitDefault() on the attribute map helper first,
// then Init() to write the initial attribute values and reset the version.
GeModel::GeModel() {
  attrs_.InitDefault();
  this->Init();
}
// Returns a const reference to the graph stored in this model.
const Graph &GeModel::GetGraph() const {
  return graph_;
}
std::shared_ptr<domi::ModelTaskDef> GeModel::GetModelTaskDefPtr() const { return this->task_; }
// Returns a const reference to the TBE kernel store stored in this model.
const TBEKernelStore &GeModel::GetTBEKernelStore() const {
  return tbe_kernal_store_;
}
// Returns a const reference to the custom AICPU kernel store stored in this model.
const CustAICPUKernelStore &GeModel::GetCustAICPUKernelStore() const {
  return cust_aicpu_kernal_store_;
}
// Returns the weights buffer stored in this model, by value.
Buffer GeModel::GetWeight() const {
  return weights_buffer_;
}
std::string GeModel::GetName() const { return this->name_; }
uint32_t GeModel::GetVersion() const { return this->version_; }
std::string GeModel::GetPlatformVersion() const { return this->platform_version_; }
uint8_t GeModel::GetPlatformType() const { return this->platform_type_; }
void GeModel::SetGraph(const Graph &graph) { this->graph_ = graph; }
// Replaces the stored task-definition pointer with `task`.
void GeModel::SetModelTaskDef(const std::shared_ptr<domi::ModelTaskDef> &task) {
  task_ = task;
}
void GeModel::SetTBEKernelStore(const TBEKernelStore &tbe_kernal_store) {
this->tbe_kernal_store_ = tbe_kernal_store;
}
void GeModel::SetCustAICPUKernelStore(const CustAICPUKernelStore &cust_aicpu_kernal_store) {
this->cust_aicpu_kernal_store_ = cust_aicpu_kernal_store;
}
void GeModel::SetWeight(const Buffer &weights_buffer) { this->weights_buffer_ = weights_buffer; }
void GeModel::SetName(const std::string &name) { this->name_ = name; }
void GeModel::SetVersion(uint32_t version) { this->version_ = version; }
void GeModel::SetPlatformVersion(const std::string &platform_version) { this->platform_version_ = platform_version; }
void GeModel::SetPlatformType(uint8_t platform_type) { this->platform_type_ = platform_type; }
void GeModel::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; }
// Returns this model's attribute map helper, by value.
ProtoAttrMapHelper GeModel::MutableAttrMap() {
  return this->attrs_;
}
// Builds a ConstProtoAttrMapHelper from the proto owner and proto message
// exposed by this model's attribute map helper.
ConstProtoAttrMapHelper GeModel::GetAttrMap() const {
  const auto &proto_owner = attrs_.GetProtoOwner();
  const auto &proto_msg = attrs_.GetProtoMsg();
  return ConstProtoAttrMapHelper(proto_owner, proto_msg);
}
} // namespace ge