parent
183144e135
commit
04763b8b76
@ -0,0 +1,71 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ir/primitive_base.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
bool Primitive::operator==(const Value &other) const {
|
||||||
|
if (other.isa<Primitive>()) {
|
||||||
|
auto other_prim = static_cast<const Primitive &>(other);
|
||||||
|
return *this == other_prim;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Primitive::operator==(const Primitive &other) const {
|
||||||
|
if (name() != other.name()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (attrs_.size() != other.attrs_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
|
||||||
|
if (item.second == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto iter = other.attrs_.find(item.first);
|
||||||
|
if (iter == other.attrs_.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *item.second == *iter->second;
|
||||||
|
});
|
||||||
|
return all;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Primitive::GetAttrsText() const {
|
||||||
|
if (attrs_.empty()) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "[";
|
||||||
|
bool is_first = true;
|
||||||
|
for (auto &attr : attrs_) {
|
||||||
|
if (is_first) {
|
||||||
|
is_first = false;
|
||||||
|
} else {
|
||||||
|
oss << ", ";
|
||||||
|
}
|
||||||
|
oss << attr.first << "=" << attr.second->DumpText();
|
||||||
|
}
|
||||||
|
oss << "]";
|
||||||
|
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,128 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
||||||
|
#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
#include "ir/dtype/type.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
// Supported meta type
|
||||||
|
enum PrimType {
|
||||||
|
kPrimTypeUnknown = 0,
|
||||||
|
kPrimTypeBegin = kTypeUnknown,
|
||||||
|
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||||
|
kPrimTypePyInferShape, // Primitive operator defined by custom
|
||||||
|
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
||||||
|
kPrimTypeUserCustom
|
||||||
|
};
|
||||||
|
|
||||||
|
class Primitive : public Named {
|
||||||
|
public:
|
||||||
|
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
|
||||||
|
: Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {}
|
||||||
|
|
||||||
|
Primitive(const Primitive &prim)
|
||||||
|
: Named(prim),
|
||||||
|
attrs_(prim.attrs_),
|
||||||
|
instance_name_(prim.instance_name_),
|
||||||
|
is_base_(prim.is_base_),
|
||||||
|
has_signature_(prim.has_signature_),
|
||||||
|
prim_type_(prim.prim_type_) {}
|
||||||
|
|
||||||
|
MS_DECLARE_PARENT(Primitive, Named);
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
||||||
|
std::string ToString() const override { return name(); }
|
||||||
|
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||||
|
attrs_[name] = attr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||||
|
for (auto &attr : attrs) {
|
||||||
|
attrs_[attr.first] = attr.second;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
|
||||||
|
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
|
||||||
|
|
||||||
|
ValuePtr GetAttr(const std::string &attrName) const {
|
||||||
|
auto iter = attrs_.find(attrName);
|
||||||
|
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||||
|
|
||||||
|
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
||||||
|
bool HasAttr() const { return !attrs_.empty(); }
|
||||||
|
bool HasAttr(const std::string &attrName) const {
|
||||||
|
auto iter = attrs_.find(attrName);
|
||||||
|
return !(iter == attrs_.cend());
|
||||||
|
}
|
||||||
|
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||||
|
void set_instance_name(const std::string s) { instance_name_ = s; }
|
||||||
|
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
||||||
|
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
||||||
|
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
||||||
|
|
||||||
|
PrimType prim_type() const { return prim_type_; }
|
||||||
|
std::string instance_name() const { return instance_name_; }
|
||||||
|
std::string GetAttrsText() const;
|
||||||
|
bool operator==(const Value &other) const override;
|
||||||
|
bool operator==(const Primitive &other) const;
|
||||||
|
~Primitive() override = default;
|
||||||
|
|
||||||
|
void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
|
||||||
|
bool has_signature() const { return has_signature_; }
|
||||||
|
bool is_base() const { return is_base_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string instance_name_;
|
||||||
|
bool is_base_;
|
||||||
|
bool has_signature_;
|
||||||
|
PrimType prim_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
||||||
|
os << *p;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PrimitiveEqual {
|
||||||
|
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(t1);
|
||||||
|
MS_EXCEPTION_IF_NULL(t2);
|
||||||
|
return t1->name() == t2->name();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PrimitiveHasher {
|
||||||
|
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
@ -0,0 +1,25 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ir/primitive_base.h"
|
||||||
|
#include "pipeline/static_analysis/abstract_function.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
|
||||||
|
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
|
||||||
|
return prim_func;
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,49 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "utils/primitive_utils.h"
|
||||||
|
#include "pipeline/parse/python_adapter.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "common/utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
py::function GetBpropFunctionByObj(py::object obj) {
|
||||||
|
static const std::string get_bprop_fn = "get_bprop_fn";
|
||||||
|
static const std::string ad_module = "mindspore.ops._grad";
|
||||||
|
py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj);
|
||||||
|
return fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::function GetBpropFunction(std::string name) {
|
||||||
|
auto fn = GetBpropFunctionByObj(py::str(name));
|
||||||
|
if (fn.is_none()) {
|
||||||
|
MS_LOG(WARNING) << "Can't find bprop function for " << name;
|
||||||
|
}
|
||||||
|
return fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::function GetComputeFunction(std::string name) {
|
||||||
|
static const std::string module = "mindspore._extends.builtin_operations";
|
||||||
|
py::module mod = py::module::import(common::SafeCStr(module));
|
||||||
|
if (!py::hasattr(mod, common::SafeCStr(name))) {
|
||||||
|
PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name));
|
||||||
|
// If raise AttributeError, user can't understand. This case need raise NotImplementedError.
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
py::object fn = mod.attr(common::SafeCStr(name));
|
||||||
|
return fn;
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,33 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
|
||||||
|
#define MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
py::function GetBpropFunctionByObj(py::object obj);
|
||||||
|
|
||||||
|
py::function GetBpropFunction(std::string name);
|
||||||
|
|
||||||
|
py::function GetComputeFunction(std::string name);
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
|
Loading…
Reference in new issue