You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc

363 lines
11 KiB

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/static_analysis/abstract_function.h"
#include <vector>
#include "pipeline/static_analysis/analysis_context.h"
#include "pipeline/static_analysis/static_analysis.h"
namespace mindspore {
namespace abstract {
class Evaluator;
class AnalysisEngine;
// Builds the abstract function that stands for `func_list`.
// A list holding exactly one atom yields that atom itself; every other size
// (including an empty list) is wrapped in a new AbstractFuncUnion.
AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
  if (func_list.size() != 1) {
    return std::make_shared<AbstractFuncUnion>(func_list);
  }
  return func_list.front();
}
// Joins this atom with `other`, which must be either an atom or a union.
//  - atom/atom: if the two compare equal via operator==, this atom is returned;
//    otherwise a new union holding both atoms is built.
//  - atom/union: if the union already contains this atom (as decided by
//    AbstractFuncUnion::IsSuperSet), the union is returned unchanged; otherwise
//    a new union is built from this atom followed by the union's atoms.
// Fails via MS_EXCEPTION_IF_NULL if `other` is null or is neither an atom nor a
// union.
AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto this_func = shared_from_base<AbstractFuncAtom>();
  if (other->isa<AbstractFuncAtom>()) {
    if (*this_func == *other) {
      return this_func;
    }
    return std::make_shared<AbstractFuncUnion>(this_func, other);
  }
  // Anything that is not an atom is expected to be a union; report a failed
  // cast instead of dereferencing a null pointer.
  auto other_union = dyn_cast<AbstractFuncUnion>(other);
  MS_EXCEPTION_IF_NULL(other_union);
  if (other_union->IsSuperSet(this_func)) {
    return other;
  }
  return std::make_shared<AbstractFuncUnion>(this_func, other);
}
void AbstractFuncAtom::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
visit_func(const_cast<AbstractFuncAtom *>(this)->shared_from_base<AbstractFuncAtom>());
}
// Default atom equality is object identity: true only when `other` is this very object.
bool AbstractFuncAtom::operator==(const AbstractFunction &other) const {
  const AbstractFunction *other_ptr = &other;
  return other_ptr == this;
}
// Builds a union holding a copy of `func_list` as given (no flattening or deduplication).
AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) : func_list_(func_list) {}
// Builds a union by flattening two abstract functions: every atom reported by
// first->Visit comes first, followed by every atom reported by second->Visit.
// Duplicates are kept; no deduplication is performed here.
AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) {
  auto append_atom = [this](const AbstractFuncAtomPtr &atom) { func_list_.push_back(atom); };
  first->Visit(append_atom);
  second->Visit(append_atom);
}
// Renders the union as "AbstractFuncUnion({[0]: <atom>, [1]: <atom>, })".
// Each entry, including the last, is followed by ", ".
// Fails via MS_EXCEPTION_IF_NULL on a null entry.
std::string AbstractFuncUnion::ToString() const {
  std::ostringstream oss;
  oss << "AbstractFuncUnion({";
  for (std::size_t idx = 0; idx < func_list_.size(); ++idx) {
    const auto &func = func_list_[idx];
    MS_EXCEPTION_IF_NULL(func);
    oss << "[" << idx << "]: " << func->ToString() << ", ";
  }
  oss << "})";
  return oss.str();
}
// Returns true when every atom of `other` (an atom or a union, expanded via
// Visit) is already present in this union's atom list.
// Membership is tested with std::find on the shared pointers, i.e. by object
// identity, not by AbstractFunction::operator==.
// NOTE(review): AbstractFuncAtom::Join dedups atom/atom joins with operator==,
// so value-equal but distinct atoms are treated as different here. Confirm
// this asymmetry is intended.
// Fails via MS_EXCEPTION_IF_NULL if `other` is null.
bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  bool all_found = true;
  // Visit takes a void callback, so misses are recorded in a flag; once one
  // atom is missing the answer is known and further lookups are skipped.
  auto check_member = [this, &all_found](const AbstractFuncAtomPtr &func) {
    if (!all_found) {
      return;
    }
    if (std::find(func_list_.begin(), func_list_.end(), func) == func_list_.end()) {
      all_found = false;
    }
  };
  other->Visit(check_member);
  return all_found;
}
// Joins this union with `other`, which must be either an atom or a union.
//  - union/atom: if this union already contains the atom (IsSuperSet), this
//    union is returned; otherwise a new union of this union's atoms followed
//    by the atom is built.
//  - union/union: if `other` already contains all of this union's atoms,
//    `other` is returned. If this union already contains all of `other`'s
//    atoms, this union is returned. Otherwise a new union concatenating both
//    atom lists is built.
// Fails via MS_EXCEPTION_IF_NULL if `other` is null or is neither an atom nor a
// union.
AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto this_func = shared_from_base<AbstractFunction>();
  if (other->isa<AbstractFuncAtom>()) {
    if (IsSuperSet(other)) {
      return this_func;
    }
    return std::make_shared<AbstractFuncUnion>(this_func, other);
  }
  auto other_union = dyn_cast<AbstractFuncUnion>(other);
  MS_EXCEPTION_IF_NULL(other_union);
  if (other_union->IsSuperSet(this_func)) {
    return other;
  }
  // Mirror the atom case: joining with a subset must not grow a new union that
  // merely repeats atoms this union already holds.
  if (IsSuperSet(other)) {
    return this_func;
  }
  return std::make_shared<AbstractFuncUnion>(this_func, other);
}
// Calls `visit_func` once for each atom of this union, in list order.
void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
  // Bind by const reference. Copying each shared_ptr would cost an atomic
  // refcount increment/decrement per element, and the callback only needs a
  // const reference.
  for (const AbstractFuncAtomPtr &poss : func_list_) {
    visit_func(poss);
  }
}
// A union equals another union whose atom list holds the same shared pointers
// in the same order (vector equality also covers the length check).
// A union never equals a non-union.
bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
  if (other.isa<AbstractFuncUnion>()) {
    const auto &rhs = static_cast<const AbstractFuncUnion &>(other);
    return func_list_ == rhs.func_list_;
  }
  return false;
}
std::size_t AbstractFuncUnion::hash() const {
std::size_t hash_sum = 0;
for (auto f : func_list_) {
hash_sum = hash_combine(hash_sum, f->hash());
}
return hash_sum;
}
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<PrimitiveAbstractClosure>. Fails via MS_EXCEPTION_IF_NULL if
// `engine` is null.
EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<PrimitiveAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
// Equal to another PrimitiveAbstractClosure that holds the same primitive
// pointer and reports the same tracking id.
bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<PrimitiveAbstractClosure>()) {
    return false;
  }
  const auto &rhs = static_cast<const PrimitiveAbstractClosure &>(other);
  return prim_ == rhs.prim_ && tracking_id() == rhs.tracking_id();
}
// Combines the type id with the primitive's own hash (the tracking id is not mixed in).
std::size_t PrimitiveAbstractClosure::hash() const {
  const auto prim_hash = prim_->hash();
  return hash_combine(tid(), prim_hash);
}
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<FuncGraphAbstractClosure>. Fails via MS_EXCEPTION_IF_NULL if
// `engine` is null.
EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<FuncGraphAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
// Equal to another FuncGraphAbstractClosure referencing the same func graph
// and the same context (both compared as pointers).
bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<FuncGraphAbstractClosure>()) {
    return false;
  }
  const auto &rhs = static_cast<const FuncGraphAbstractClosure &>(other);
  return func_graph_ == rhs.func_graph_ && context_ == rhs.context_;
}
// Hash = combine(combine(type id, func graph hash), context hash).
std::size_t FuncGraphAbstractClosure::hash() const {
  const auto graph_part = hash_combine(tid(), func_graph_->hash());
  return hash_combine(graph_part, context_->hash());
}
// Renders as "FuncGraphAbstractClosure: FuncGraph: <graph>; Context: <context>".
std::string FuncGraphAbstractClosure::ToString() const {
  std::ostringstream text;
  text << "FuncGraphAbstractClosure: FuncGraph: " << func_graph_->ToString();
  text << "; Context: " << context_->ToString();
  return text.str();
}
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<MetaFuncGraphAbstractClosure>. Fails via MS_EXCEPTION_IF_NULL
// if `engine` is null.
EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<MetaFuncGraphAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
// Equal to another MetaFuncGraphAbstractClosure holding the same meta func
// graph pointer. The short-circuit ensures the cast only runs on a matching type.
bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
  return other.isa<MetaFuncGraphAbstractClosure>() &&
         meta_func_graph_ == static_cast<const MetaFuncGraphAbstractClosure &>(other).meta_func_graph_;
}
// Combines the type id with the meta func graph's hash.
std::size_t MetaFuncGraphAbstractClosure::hash() const { return hash_combine(tid(), meta_func_graph_->hash()); }
// Renders as "MetaFuncGraphAbstractClosure: <meta func graph name>".
std::string MetaFuncGraphAbstractClosure::ToString() const {
  std::string text = "MetaFuncGraphAbstractClosure: ";
  text += meta_func_graph_->name();
  return text;
}
// Equal to another PartialAbstractClosure with the same function pointer and
// an argument-spec list holding the same pointers in the same order
// (vector equality also covers the length check).
bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<PartialAbstractClosure>()) {
    return false;
  }
  const auto &rhs = static_cast<const PartialAbstractClosure &>(other);
  return fn_ == rhs.fn_ && args_spec_list_ == rhs.args_spec_list_;
}
// Hash = combine(combine(type id, fn hash), hash of the argument-spec list).
std::size_t PartialAbstractClosure::hash() const {
  const auto fn_part = hash_combine(tid(), fn_->hash());
  return hash_combine(fn_part, AbstractBasePtrListHash(args_spec_list_));
}
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<PartialAbstractClosure>. Fails via MS_EXCEPTION_IF_NULL if
// `engine` is null.
EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<PartialAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
std::string PartialAbstractClosure::ToString() const {
std::ostringstream buffer;
buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
for (auto arg : args_spec_list_) {
buffer << arg->ToString() << ", ";
}
buffer << "))";
return buffer.str();
}
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<JTransformedAbstractClosure>. Fails via MS_EXCEPTION_IF_NULL
// if `engine` is null.
EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<JTransformedAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
// Equal to another JTransformedAbstractClosure wrapping the same function
// pointer. The short-circuit ensures the cast only runs on a matching type.
bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
  return other.isa<JTransformedAbstractClosure>() &&
         fn_ == static_cast<const JTransformedAbstractClosure &>(other).fn_;
}
// Combines the type id with the wrapped function's hash.
std::size_t JTransformedAbstractClosure::hash() const { return hash_combine(tid(), fn_->hash()); }
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<VirtualAbstractClosure>. Fails via MS_EXCEPTION_IF_NULL if
// `engine` is null.
EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<VirtualAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
// Equal to another VirtualAbstractClosure with the same output pointer and an
// argument-spec list holding the same pointers in the same order.
bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<VirtualAbstractClosure>()) {
    return false;
  }
  const auto &rhs = static_cast<const VirtualAbstractClosure &>(other);
  return output_ == rhs.output_ && args_spec_list_ == rhs.args_spec_list_;
}
std::size_t VirtualAbstractClosure::hash() const {
auto hash_value = hash_combine(tid(), output_->hash());
hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
return hash_value;
}
// Renders as "VirtualAbstractClosure(args: {[0]: <arg>, ...}, output: <output>)".
// Each argument, including the last, is followed by ", ".
// Fails via MS_EXCEPTION_IF_NULL on a null argument spec.
std::string VirtualAbstractClosure::ToString() const {
  std::ostringstream out;
  out << "VirtualAbstractClosure(args: {";
  std::size_t pos = 0;
  for (const auto &arg : args_spec_list_) {
    MS_EXCEPTION_IF_NULL(arg);
    out << "[" << pos << "]: " << arg->ToString() << ", ";
    ++pos;
  }
  out << "}, output: " << output_->ToString() << ")";
  return out.str();
}
// Asks `engine` for the evaluator handling this closure, passing this object
// as a shared_ptr<TypedPrimitiveAbstractClosure>. Fails via
// MS_EXCEPTION_IF_NULL if `engine` is null.
EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  MS_EXCEPTION_IF_NULL(engine);
  auto self = shared_from_base<TypedPrimitiveAbstractClosure>();
  return engine->_GetEvaluatorFor(self);
}
// Equal to another TypedPrimitiveAbstractClosure with the same output pointer,
// the same primitive pointer, and an argument-spec list holding the same
// pointers in the same order.
bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<TypedPrimitiveAbstractClosure>()) {
    return false;
  }
  const auto &rhs = static_cast<const TypedPrimitiveAbstractClosure &>(other);
  return output_ == rhs.output_ && prim_ == rhs.prim_ && args_spec_list_ == rhs.args_spec_list_;
}
// Hash = combine(combine(type id, primitive hash), hash of the argument-spec
// list). output_ is not mixed in.
std::size_t TypedPrimitiveAbstractClosure::hash() const {
  const auto prim_part = hash_combine(tid(), prim_->hash());
  return hash_combine(prim_part, AbstractBasePtrListHash(args_spec_list_));
}
// Renders as
// "TypedPrimitiveAbstractClosure: primitive: <name>(args: {[0]: <arg>, ...}, output: <output>)".
// Each argument, including the last, is followed by ", ".
// Fails via MS_EXCEPTION_IF_NULL on a null argument spec.
std::string TypedPrimitiveAbstractClosure::ToString() const {
  std::ostringstream out;
  out << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {";
  std::size_t pos = 0;
  for (const auto &arg : args_spec_list_) {
    MS_EXCEPTION_IF_NULL(arg);
    out << "[" << pos << "]: " << arg->ToString() << ", ";
    ++pos;
  }
  out << "}, output: " << output_->ToString() << ")";
  return out.str();
}
// Every DummyAbstractClosure equals every other DummyAbstractClosure and
// nothing else.
bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
  return other.isa<DummyAbstractClosure>();
}
} // namespace abstract
} // namespace mindspore