You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/inc/graph/utils/attr_utils.h

151 lines
8.8 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_
#define INC_GRAPH_UTILS_ATTR_UTILS_H_
#include <memory>
#include <string>
#include <vector>
#include "graph/detail/attributes_holder.h"
#include "graph/ge_attr_value.h"
#include "graph/types.h"
namespace ge {
class OpDesc;
using OpDescPtr = std::shared_ptr<OpDesc>;
using ConstOpDescPtr = std::shared_ptr<const OpDesc>;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils {
public:
class ConstAttrHolderAdapter;
class AttrHolderAdapter;
// Set
static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name);
static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int64_t> &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<uint32_t> &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int32_t> &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value);
static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value);
static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector<float> &value);
static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value);
static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector<bool> &value);
static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value);
static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector<string> &value);
static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value);
static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorDesc> &value);
static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value);
static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value);
static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorPtr> &value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<ConstGeTensorPtr> &value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name,
std::initializer_list<ConstGeTensorPtr> &&value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensor> &value);
static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value);
static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value);
static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value);
static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value);
static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value);
static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name,
const vector<GeAttrValue::NAMED_ATTRS> &value);
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value);
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value);
// Get
static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value);
static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value);
static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value);
static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int64_t> &value);
static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int32_t> &value);
static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<uint32_t> &value);
static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value);
static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector<float> &value);
static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value);
static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector<bool> &value);
static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value);
static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector<string> &value);
static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value);
static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<GeTensorDesc> &value);
static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value);
static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value);
static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value);
static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value);
static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value);
static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value);
static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value);
static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value);
static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value);
static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name,
vector<GeAttrValue::NAMED_ATTRS> &value);
static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value);
// Value will be moved
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer);
static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer);
// Value will be moved
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);
static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);
static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value);
static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<vector<int64_t>> &value);
static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector<ge::DataType> &value);
static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector<ge::DataType> &value);
static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value);
static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value);
static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc);
static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc);
static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj);
class AttrHolderAdapter {
public:
AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {}
~AttrHolderAdapter() {}
template <class T>
AttrHolderAdapter(const std::shared_ptr<T> &obj) : obj_(obj.get()) {}
AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {}
operator bool() const { return obj_ != nullptr; }
AttrHolder *operator->() { return obj_; }
AttrHolder *get() { return obj_; }
AttrHolder *obj_;
};
class ConstAttrHolderAdapter {
public:
ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {}
~ConstAttrHolderAdapter() {}
template <class T>
ConstAttrHolderAdapter(const std::shared_ptr<T> obj) : obj_(obj.get()) {}
ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {}
operator bool() const { return obj_ != nullptr; }
const AttrHolder *operator->() const { return obj_; }
const AttrHolder *get() const { return obj_; }
private:
const AttrHolder *obj_;
};
};
} // namespace ge
#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_