|
|
|
@ -23,6 +23,7 @@
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
|
|
|
|
|
#include "frontend/optimizer/irpass.h"
|
|
|
|
|
#include "frontend/optimizer/optimizer.h"
|
|
|
|
@ -42,13 +43,13 @@ class SpecializeTransform {
|
|
|
|
|
~SpecializeTransform() = default;
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
|
|
|
|
|
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
|
|
|
|
|
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> tensor_value_args) {
|
|
|
|
|
if (cache_.count(func_graph) == 0) {
|
|
|
|
|
cache_[func_graph] = {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &cache = cache_[func_graph];
|
|
|
|
|
auto key = std::make_pair(graph_args, prim_args);
|
|
|
|
|
auto key = std::make_tuple(graph_args, prim_args, tensor_value_args);
|
|
|
|
|
if (cache.count(key) == 0) {
|
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
@ -70,8 +71,8 @@ class SpecializeTransform {
|
|
|
|
|
(void)mng->Replace(params[i], arg);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (value_args[i] != nullptr) {
|
|
|
|
|
auto &const_tensor = *value_args[i];
|
|
|
|
|
if (tensor_value_args[i] != nullptr) {
|
|
|
|
|
auto &const_tensor = *tensor_value_args[i];
|
|
|
|
|
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
|
|
|
|
|
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
|
|
|
|
|
(void)mng->Replace(params[i], arg);
|
|
|
|
@ -87,8 +88,10 @@ class SpecializeTransform {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unordered_map<FuncGraphPtr,
|
|
|
|
|
std::map<std::pair<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>>, FuncGraphPtr>>
|
|
|
|
|
std::unordered_map<
|
|
|
|
|
FuncGraphPtr,
|
|
|
|
|
std::map<std::tuple<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>, std::vector<tensor::TensorPtr>>,
|
|
|
|
|
FuncGraphPtr>>
|
|
|
|
|
cache_;
|
|
|
|
|
};
|
|
|
|
|
} // namespace internal
|
|
|
|
@ -116,7 +119,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|
|
|
|
|
|
|
|
|
std::vector<FuncGraphPtr> graph_args;
|
|
|
|
|
std::vector<PrimitivePtr> prim_args;
|
|
|
|
|
std::vector<tensor::TensorPtr> value_node_args;
|
|
|
|
|
std::vector<tensor::TensorPtr> tensor_value_args;
|
|
|
|
|
std::vector<AnfNodePtr> new_xs;
|
|
|
|
|
bool hasVNode = false;
|
|
|
|
|
for (size_t i = 1; i < inputs.size(); i++) {
|
|
|
|
@ -124,24 +127,24 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|
|
|
|
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
|
|
|
|
|
graph_args.push_back(fg_vnode);
|
|
|
|
|
prim_args.emplace_back(nullptr);
|
|
|
|
|
value_node_args.emplace_back(nullptr);
|
|
|
|
|
tensor_value_args.emplace_back(nullptr);
|
|
|
|
|
hasVNode = true;
|
|
|
|
|
} else if (IsValueNode<Primitive>(inputs[i])) {
|
|
|
|
|
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
|
|
|
|
|
graph_args.emplace_back(nullptr);
|
|
|
|
|
prim_args.push_back(p_vnode);
|
|
|
|
|
value_node_args.emplace_back(nullptr);
|
|
|
|
|
tensor_value_args.emplace_back(nullptr);
|
|
|
|
|
hasVNode = true;
|
|
|
|
|
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
|
|
|
|
|
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
|
|
|
|
|
graph_args.emplace_back(nullptr);
|
|
|
|
|
prim_args.emplace_back(nullptr);
|
|
|
|
|
value_node_args.emplace_back(t_vnode);
|
|
|
|
|
tensor_value_args.emplace_back(t_vnode);
|
|
|
|
|
hasVNode = true;
|
|
|
|
|
} else {
|
|
|
|
|
graph_args.emplace_back(nullptr);
|
|
|
|
|
prim_args.emplace_back(nullptr);
|
|
|
|
|
value_node_args.emplace_back(nullptr);
|
|
|
|
|
tensor_value_args.emplace_back(nullptr);
|
|
|
|
|
new_xs.push_back(inputs[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -150,7 +153,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
|
|
|
|
|
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args);
|
|
|
|
|
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
|
|
|
|
|
|
|
|
|
|
return node->func_graph()->NewCNode(new_xs);
|
|
|
|
|