|
|
|
@ -91,6 +91,69 @@ class EnvGetitemTransform {
|
|
|
|
|
std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
|
|
|
|
|
cache_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class EnvGetitemTransformACrossGraph {
|
|
|
|
|
public:
|
|
|
|
|
EnvGetitemTransformACrossGraph() : cache_() {}
|
|
|
|
|
~EnvGetitemTransformACrossGraph() = default;
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
|
|
|
|
|
if (cache_.find(fg) == cache_.end()) {
|
|
|
|
|
cache_[fg] = {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &cache = cache_[fg];
|
|
|
|
|
auto hash_key = std::make_pair(key, default_node);
|
|
|
|
|
if (cache.find(hash_key) == cache.end()) {
|
|
|
|
|
std::ostringstream ss("env", std::ostringstream::app);
|
|
|
|
|
if (key->node() != nullptr) {
|
|
|
|
|
ss << key->node()->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
|
|
|
|
|
auto output_outer = new_fg_outer->output();
|
|
|
|
|
if (!IsValueNode<FuncGraph>(output_outer)) {
|
|
|
|
|
MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer);
|
|
|
|
|
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str()));
|
|
|
|
|
new_fg_outer->set_output(NewValueNode(new_fg));
|
|
|
|
|
|
|
|
|
|
auto env = new_fg->output();
|
|
|
|
|
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
|
|
|
|
|
// {prim::kPrimEnvSetItem, env, symbolickey, value}
|
|
|
|
|
auto &inputs = env->cast<CNodePtr>()->inputs();
|
|
|
|
|
if (inputs.size() != 4) {
|
|
|
|
|
MS_LOG(WARNING) << "Input size should be 4";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
|
|
|
|
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
env = inputs[1];
|
|
|
|
|
auto value = inputs[3];
|
|
|
|
|
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
|
|
|
|
|
if (*key2 == *key) {
|
|
|
|
|
new_fg->set_output(value);
|
|
|
|
|
cache[hash_key] = new_fg_outer;
|
|
|
|
|
return new_fg_outer;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
|
|
|
|
|
cache[hash_key] = new_fg_outer;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return cache[hash_key];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unordered_map<FuncGraphPtr,
|
|
|
|
|
std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
|
|
|
|
|
cache_;
|
|
|
|
|
};
|
|
|
|
|
} // namespace internal
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
|
|
|
|
@ -358,6 +421,78 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
|
|
|
|
bool is_match_{false};
|
|
|
|
|
internal::EnvGetitemTransform env_get_item_transform_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
|
|
|
|
|
class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {}
|
|
|
|
|
~IncorporateEnvGetitemSwitchLayer() override = default;
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
is_match_ = false;
|
|
|
|
|
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
|
|
|
|
if (!is_match_ || node->func_graph() == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// {prim::kPrimEnvGetItem, {...}, C, Y}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto inp1 = cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
|
|
|
|
|
auto default_v = cnode->input(3);
|
|
|
|
|
|
|
|
|
|
// {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}
|
|
|
|
|
auto &inputs_outer = inp1->inputs();
|
|
|
|
|
if (!inputs_outer[0]->isa<CNode>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> args_outer;
|
|
|
|
|
args_outer.insert(args_outer.end(), inputs_outer.begin() + 1, inputs_outer.end());
|
|
|
|
|
auto &input_switch_layer = inputs_outer[0]->cast<CNodePtr>()->inputs();
|
|
|
|
|
|
|
|
|
|
is_match_ = false;
|
|
|
|
|
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(input_switch_layer[0]);
|
|
|
|
|
if (!is_match_) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> args;
|
|
|
|
|
(void)args.insert(args.end(), input_switch_layer.begin() + 1, input_switch_layer.end());
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimSwitchLayers, X, {prim::kPrimMakeTuple, G1, G2...}}
|
|
|
|
|
auto sw = input_switch_layer[0]->cast<CNodePtr>();
|
|
|
|
|
std::vector<FuncGraphPtr> graphs{};
|
|
|
|
|
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
|
|
|
|
|
auto &graphs_inputs = graphs_cnode->inputs();
|
|
|
|
|
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
|
|
|
|
|
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
|
|
|
|
|
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
|
|
|
|
|
}
|
|
|
|
|
if (graphs.empty()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
std::vector<AnfNodePtr> layers;
|
|
|
|
|
for (auto &graph : graphs) {
|
|
|
|
|
auto fg_transform = env_get_item_transform_(graph, key, default_v);
|
|
|
|
|
if (fg_transform == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
layers.push_back(NewValueNode(fg_transform));
|
|
|
|
|
}
|
|
|
|
|
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers);
|
|
|
|
|
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitchLayer), sw->input(1), layers_node});
|
|
|
|
|
args.insert(args.begin(), new_sw);
|
|
|
|
|
auto inner_call = fg->NewCNode(args);
|
|
|
|
|
args_outer.insert(args_outer.begin(), inner_call);
|
|
|
|
|
return fg->NewCNode(args_outer);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &) override { is_match_ = true; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool is_match_{false};
|
|
|
|
|
internal::EnvGetitemTransformACrossGraph env_get_item_transform_;
|
|
|
|
|
};
|
|
|
|
|
} // namespace irpass
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|