|
|
@ -35,14 +35,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace ad {
|
|
|
|
namespace ad {
|
|
|
|
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr>;
|
|
|
|
struct PrimitiveTotalEqual {
|
|
|
|
|
|
|
|
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
|
|
|
|
|
|
|
if (t1->name() != t2->name()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto const &attrs1 = t1->attrs();
|
|
|
|
|
|
|
|
auto const &attrs2 = t2->attrs();
|
|
|
|
|
|
|
|
if (attrs1.size() != attrs2.size()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &attr : attrs1) {
|
|
|
|
|
|
|
|
if (!t2->HasAttr(attr.first)) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
|
|
|
class KPrim;
|
|
|
|
class KPrim;
|
|
|
|
extern KPrim g_k_prims;
|
|
|
|
extern KPrim g_k_prims;
|
|
|
|
class DFunctor;
|
|
|
|
class DFunctor;
|
|
|
|
using DFunctorPtr = std::shared_ptr<DFunctor>;
|
|
|
|
using DFunctorPtr = std::shared_ptr<DFunctor>;
|
|
|
|
|
|
|
|
|
|
|
|
// D Functor's rules to map closure object and morphisms.
|
|
|
|
// D Functor's rules to map closure object and morphisms.
|
|
|
|
class DFunctor {
|
|
|
|
class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
|
|
|
|
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
|
|
|
|
~DFunctor() = default;
|
|
|
|
~DFunctor() = default;
|
|
|
@ -54,7 +80,7 @@ class DFunctor {
|
|
|
|
// Construct user defined k object.
|
|
|
|
// Construct user defined k object.
|
|
|
|
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
|
|
|
|
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
|
|
|
|
// Register functor objects to form a global view.
|
|
|
|
// Register functor objects to form a global view.
|
|
|
|
void Init(const DFunctorPtr &functor, bool is_top = false);
|
|
|
|
void Init(bool is_top = false);
|
|
|
|
bool IsInScope(const AnfNodePtr &node);
|
|
|
|
bool IsInScope(const AnfNodePtr &node);
|
|
|
|
|
|
|
|
|
|
|
|
// Clear resources.
|
|
|
|
// Clear resources.
|
|
|
|