You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h

274 lines
10 KiB

/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
#define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include <map>
#ifdef DEBUG
#include <stack>
#endif
#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "ir/primitive.h"
#include "pipeline/static_analysis/analysis_context.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/parse/parse.h"
namespace mindspore {
namespace abstract {
// define attribute value map
using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
// the class to save evaluated result: abstract value and modified attribute
class EvalResult : public Base {
public:
EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
~EvalResult() override = default;
MS_DECLARE_PARENT(EvalResult, Base);
AbstractBasePtr abstract() { return abstract_; }
AttrValueMapPtr attribute() { return attribute_; }
private:
AbstractBasePtr abstract_;
AttrValueMapPtr attribute_;
};
using EvalResultPtr = std::shared_ptr<EvalResult>;
// Superclass for AnfNodeConfig and VirtualConfig.
class Config : public Base {
public:
Config() = default;
~Config() override = default;
MS_DECLARE_PARENT(Config, Base);
virtual EvalResultPtr GetEvaluatedValue() = 0;
};
// Config will be stored in AnalysisCache
using ConfigPtr = std::shared_ptr<Config>;
using ConfigPtrList = std::vector<ConfigPtr>;
// Config to a certain node in a certain context.
class AnfNodeConfig : public Config {
public:
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node) {
FuncGraphPtr fg;
if (IsValueNode<FuncGraph>(node)) {
auto v = node->cast<ValueNodePtr>();
fg = v->value()->cast<FuncGraphPtr>();
} else {
fg = node->func_graph();
}
context_ = nullptr;
if (context != nullptr) {
context_ = context->Filter(fg);
}
}
~AnfNodeConfig() override = default;
MS_DECLARE_PARENT(AnfNodeConfig, Config);
EvalResultPtr GetEvaluatedValue() override;
AnalysisContextPtr context() const { return context_; }
AnfNodePtr node() const { return node_; }
AnalysisEnginePtr engine() const { return engine_.lock(); }
// used by unordered_map;
bool operator==(const AnfNodeConfig &other) const {
// compare node with pointer, context with pointer except DummyContext as it's created by make_shared;
// context should not be nullptr;
if (context_->IsDummyContext() && other.context_->IsDummyContext()) {
return true;
}
return (node_ == other.node_) && (context_ == other.context_);
}
std::string ToString() const override {
std::ostringstream buffer;
buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString();
return buffer.str();
}
private:
// AnalysisEngine is global.
// As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use
// weak_ptr to break Config cycle.
std::weak_ptr<AnalysisEngine> engine_;
AnfNodePtr node_;
AnalysisContextPtr context_;
};
using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
struct AnfNodeConfigHasher {
std::size_t operator()(const AnfNodeConfigPtr conf) const;
};
struct AnfNodeConfigEqual {
bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const;
};
class VirtualConfig : public Config {
public:
explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {}
~VirtualConfig() override = default;
MS_DECLARE_PARENT(VirtualConfig, Config);
EvalResultPtr GetEvaluatedValue() override {
return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
}
private:
AbstractBasePtr abstract_;
};
// AnalysisCache
class AnalysisCache {
public:
AnalysisCache() = default;
~AnalysisCache() = default;
void Clear() { cache_.clear(); }
void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
private:
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
};
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
using AnfNodeConfigMap =
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
struct AnalysisResult {
EvalResultPtr inferred;
AnalysisContextPtr context;
};
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {}
~AnalysisEngine() = default;
// func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
// Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
// Infer the result of fn(args).
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear();
void ClearEvaluatorCache();
AnalysisCache &cache() { return cache_; }
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
}
// Overloaded function.
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }
// Set the analysis result for orig to the result for new.
// This sets an entry in anfnode_config_map from orig to new.
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
(void)anfnode_config_map_.emplace(orig_conf, new_conf);
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
<< ", to new_conf: " << new_conf->node()->DebugString();
return GetEvaluatedValue(new_conf);
}
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
AnalysisCache cache_;
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
private:
const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;
AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators.
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list);
EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
#ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_;
#endif
};
// Translate the value to an abstract value.
// Arguments:
// value: The value to convert.
// context: The context in which the value was found, used if the value is a Graph.
// conf: The Config to the valuenode we are converting, if there is one,
// so that we can generate a tracking_id.
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr,
const AnfNodeConfigPtr &conf = nullptr);
// Convert a value to an abstract value.
// Arguments:
// v: The value to convert.
// broaden: If True, concrete values will be made more abstract, so e.g.
// the value 1234 would become ANYTHING.
AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false);
template <typename T>
AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return FromValueInside(MakeValue(value), broaden);
}
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
} // namespace abstract
} // namespace mindspore
#endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_