Resolve a specialize error when transforming the 'after' block in PyNative mode.

pull/8293/head
Zhang Qinghua 5 years ago
parent e3b852ed9e
commit 4e6e68f187
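
The change below widens the specialization cache key in SpecializeTransform from a (graph_args, prim_args) pair to a (graph_args, prim_args, tensor_value_args) tuple, and renames value_args / value_node_args to tensor_value_args to match. The failure mode this guards against (my reading of the diff, not spelled out in the commit message) is that two call sites differing only in a constant tensor argument shared one cache entry, so the second call reused a graph specialized for the first tensor. A minimal, self-contained C++ sketch of that idea, using hypothetical stand-in string keys rather than the MindSpore API:

#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>

int main() {
  // Old-style key: ignores the constant tensor argument entirely.
  using NarrowKey = std::pair<std::string, std::string>;              // (graph, prim)
  // New-style key: the constant tensor argument is part of the identity.
  using WideKey = std::tuple<std::string, std::string, std::string>;  // (graph, prim, tensor)

  std::map<NarrowKey, int> narrow_cache;
  std::map<WideKey, int> wide_cache;

  // Two calls that differ only in their constant tensor input.
  narrow_cache[{"fg", "prim"}] = 1;                         // specialized for tensor_a
  bool stale_hit = narrow_cache.count({"fg", "prim"}) > 0;  // call with tensor_b reuses it

  wide_cache[std::make_tuple("fg", "prim", "tensor_a")] = 1;
  bool fresh_miss = wide_cache.count(std::make_tuple("fg", "prim", "tensor_b")) == 0;

  std::cout << stale_hit << " " << fresh_miss << "\n";  // prints "1 1"
  return 0;
}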

@@ -23,6 +23,7 @@
 #include <utility>
 #include <unordered_map>
 #include <unordered_set>
+#include <tuple>
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/optimizer.h"
@@ -42,13 +43,13 @@ class SpecializeTransform {
   ~SpecializeTransform() = default;
   FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
-                          std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
+                          std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> tensor_value_args) {
     if (cache_.count(func_graph) == 0) {
       cache_[func_graph] = {};
     }
     auto &cache = cache_[func_graph];
-    auto key = std::make_pair(graph_args, prim_args);
+    auto key = std::make_tuple(graph_args, prim_args, tensor_value_args);
     if (cache.count(key) == 0) {
       auto mng = func_graph->manager();
       MS_EXCEPTION_IF_NULL(mng);
@@ -70,8 +71,8 @@
           (void)mng->Replace(params[i], arg);
           continue;
         }
-        if (value_args[i] != nullptr) {
-          auto &const_tensor = *value_args[i];
+        if (tensor_value_args[i] != nullptr) {
+          auto &const_tensor = *tensor_value_args[i];
           auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
           AnfNodePtr arg = NewValueNode(const_tensor_ptr);
           (void)mng->Replace(params[i], arg);
@@ -87,8 +88,10 @@ class SpecializeTransform {
   }
  private:
-  std::unordered_map<FuncGraphPtr,
-                     std::map<std::pair<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>>, FuncGraphPtr>>
+  std::unordered_map<
+    FuncGraphPtr,
+    std::map<std::tuple<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>, std::vector<tensor::TensorPtr>>,
+             FuncGraphPtr>>
     cache_;
 };
 }  // namespace internal
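
A detail worth noting about the new cache_ type (an observation about standard C++ behavior, not something stated in the patch): std::map only needs a strict weak ordering on its key, and std::tuple, std::vector, and std::shared_ptr all provide operator<, so the widened key works without a custom comparator. Because shared_ptr compares by pointer identity, two distinct tensor objects count as different keys even if their contents are equal, which at worst re-specializes a graph rather than reusing a wrong one. A small sketch with a stand-in Tensor type:

#include <cassert>
#include <map>
#include <memory>
#include <tuple>
#include <vector>

struct Tensor {};  // stand-in for tensor::Tensor (illustration only)
using TensorPtr = std::shared_ptr<Tensor>;

int main() {
  // Same shape as the new cache key: (graph_args, prim_args, tensor_value_args),
  // with ints standing in for FuncGraphPtr / PrimitivePtr.
  using Key = std::tuple<std::vector<int>, std::vector<int>, std::vector<TensorPtr>>;
  std::map<Key, int> cache;

  auto t1 = std::make_shared<Tensor>();
  auto t2 = std::make_shared<Tensor>();

  Key k1 = std::make_tuple(std::vector<int>{1}, std::vector<int>{2}, std::vector<TensorPtr>{t1});
  Key k2 = std::make_tuple(std::vector<int>{1}, std::vector<int>{2}, std::vector<TensorPtr>{t2});

  cache[k1] = 100;
  // Different tensor identity => different cache entry, even though the
  // graph/prim components are identical.
  assert(cache.count(k2) == 0);
  return 0;
}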
@@ -116,7 +119,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
     std::vector<FuncGraphPtr> graph_args;
     std::vector<PrimitivePtr> prim_args;
-    std::vector<tensor::TensorPtr> value_node_args;
+    std::vector<tensor::TensorPtr> tensor_value_args;
     std::vector<AnfNodePtr> new_xs;
     bool hasVNode = false;
     for (size_t i = 1; i < inputs.size(); i++) {
@@ -124,24 +127,24 @@ class SpecializeOnGraphArguments : public AnfVisitor {
         auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
         graph_args.push_back(fg_vnode);
         prim_args.emplace_back(nullptr);
-        value_node_args.emplace_back(nullptr);
+        tensor_value_args.emplace_back(nullptr);
         hasVNode = true;
       } else if (IsValueNode<Primitive>(inputs[i])) {
         auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
         graph_args.emplace_back(nullptr);
         prim_args.push_back(p_vnode);
-        value_node_args.emplace_back(nullptr);
+        tensor_value_args.emplace_back(nullptr);
         hasVNode = true;
       } else if (IsValueNode<tensor::Tensor>(inputs[i])) {
         tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
         graph_args.emplace_back(nullptr);
         prim_args.emplace_back(nullptr);
-        value_node_args.emplace_back(t_vnode);
+        tensor_value_args.emplace_back(t_vnode);
         hasVNode = true;
       } else {
         graph_args.emplace_back(nullptr);
         prim_args.emplace_back(nullptr);
-        value_node_args.emplace_back(nullptr);
+        tensor_value_args.emplace_back(nullptr);
         new_xs.push_back(inputs[i]);
       }
     }
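
For readers following the loop above: graph_args, prim_args, and tensor_value_args are kept as parallel vectors with one slot per call-site input, and at most one of the three slots at a given index is non-null; the all-null case marks a non-constant input that stays a runtime argument. A minimal sketch (stand-in types only, not the MindSpore classes) of that invariant:

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct Graph {};   // stand-ins for FuncGraph / Primitive / tensor::Tensor
struct Prim {};
struct Tensor {};

int main() {
  std::vector<std::shared_ptr<Graph>> graph_args;
  std::vector<std::shared_ptr<Prim>> prim_args;
  std::vector<std::shared_ptr<Tensor>> tensor_value_args;

  // Input 0: a constant graph; input 1: a constant tensor; input 2: dynamic.
  graph_args.push_back(std::make_shared<Graph>());
  prim_args.emplace_back(nullptr);
  tensor_value_args.emplace_back(nullptr);

  graph_args.emplace_back(nullptr);
  prim_args.emplace_back(nullptr);
  tensor_value_args.push_back(std::make_shared<Tensor>());

  graph_args.emplace_back(nullptr);
  prim_args.emplace_back(nullptr);
  tensor_value_args.emplace_back(nullptr);

  // Invariant 1: the three vectors stay the same length (one slot per input).
  assert(graph_args.size() == prim_args.size());
  assert(prim_args.size() == tensor_value_args.size());
  // Invariant 2: at most one slot per index is non-null.
  for (std::size_t i = 0; i < graph_args.size(); ++i) {
    int filled = (graph_args[i] != nullptr) + (prim_args[i] != nullptr) +
                 (tensor_value_args[i] != nullptr);
    assert(filled <= 1);
  }
  return 0;
}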
@@ -150,7 +153,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
       return nullptr;
     }
-    auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
+    auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args);
     (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
     return node->func_graph()->NewCNode(new_xs);

@@ -17,6 +17,7 @@
 #include "pipeline/jit/static_analysis/evaluator.h"
 #include <algorithm>
 #include <utility>
+#include <unordered_set>
 #include "ir/func_graph_cloner.h"
