diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index acdcb93083..f7e7027664 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -95,6 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { incorporate_env_getitem_ = MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_switch_layer_ = + MakeSubstitution(std::make_shared(), "incorporate_env_getitem_switch_layer", + prim::kPrimEnvGetItem); + // Ref eliminate make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 90ef450d22..afb485ead8 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -58,6 +58,7 @@ class OptimizeIRPassLib { SubstitutionPtr incorporate_env_getitem_; SubstitutionPtr incorporate_env_getitem_bypass_recursive_; SubstitutionPtr incorporate_env_getitem_switch_; + SubstitutionPtr incorporate_env_getitem_switch_layer_; // Ref eliminate SubstitutionPtr make_ref_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index 6fa1304dd7..9286379a75 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -91,6 +91,69 @@ class EnvGetitemTransform { std::unordered_map, FuncGraphPtr, PairHasher>> cache_; }; + +class EnvGetitemTransformACrossGraph { + public: + EnvGetitemTransformACrossGraph() : cache_() {} + ~EnvGetitemTransformACrossGraph() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache 
= cache_[fg]; + auto hash_key = std::make_pair(key, default_node); + if (cache.find(hash_key) == cache.end()) { + std::ostringstream ss("env", std::ostringstream::app); + if (key->node() != nullptr) { + ss << key->node()->ToString(); + } + + auto new_fg_outer = TransformableClone(fg, std::make_shared(ss.str())); + auto output_outer = new_fg_outer->output(); + if (!IsValueNode(output_outer)) { + MS_LOG(WARNING) << "Output of outer graph should be a func_graph"; + return nullptr; + } + auto fg_inner = GetValueNode(output_outer); + auto new_fg = TransformableClone(fg_inner, std::make_shared(ss.str())); + new_fg_outer->set_output(NewValueNode(new_fg)); + + auto env = new_fg->output(); + while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { + // {prim::kPrimEnvSetItem, env, symbolickey, value} + auto &inputs = env->cast()->inputs(); + if (inputs.size() != 4) { + MS_LOG(WARNING) << "Input size should be 4"; + return nullptr; + } + if (!IsValueNode(inputs[2])) { + MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; + return nullptr; + } + + env = inputs[1]; + auto value = inputs[3]; + auto key2 = GetValueNode(inputs[2]); + if (*key2 == *key) { + new_fg->set_output(value); + cache[hash_key] = new_fg_outer; + return new_fg_outer; + } + } + new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); + cache[hash_key] = new_fg_outer; + } + + return cache[hash_key]; + } + + private: + std::unordered_map, FuncGraphPtr, PairHasher>> + cache_; +}; } // namespace internal // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y @@ -358,6 +421,78 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { bool is_match_{false}; internal::EnvGetitemTransform env_get_item_transform_; }; + +// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y} +class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { + public: + IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {} + 
~IncorporateEnvGetitemSwitchLayer() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode, IsNode})(node); + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + // {prim::kPrimEnvGetItem, {...}, C, Y} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key = GetValueNode(cnode->input(2)); + auto default_v = cnode->input(3); + + // {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys} + auto &inputs_outer = inp1->inputs(); + if (!inputs_outer[0]->isa()) { + return nullptr; + } + std::vector args_outer; + args_outer.insert(args_outer.end(), inputs_outer.begin() + 1, inputs_outer.end()); + auto &input_switch_layer = inputs_outer[0]->cast()->inputs(); + + is_match_ = false; + AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(input_switch_layer[0]); + if (!is_match_) { + return nullptr; + } + std::vector args; + (void)args.insert(args.end(), input_switch_layer.begin() + 1, input_switch_layer.end()); + + // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}} + auto sw = input_switch_layer[0]->cast(); + std::vector graphs{}; + auto graphs_cnode = sw->input(2)->cast(); + auto &graphs_inputs = graphs_cnode->inputs(); + if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode(graphs_inputs[1])) { + (void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs), + [](const AnfNodePtr &vnode) { return GetValueNode(vnode); }); + } + if (graphs.empty()) { + return nullptr; + } + + auto fg = node->func_graph(); + std::vector layers; + for (auto &graph : graphs) { + auto fg_transform = env_get_item_transform_(graph, key, default_v); + if (fg_transform == nullptr) { + return nullptr; + } + layers.push_back(NewValueNode(fg_transform)); + } + auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); + auto 
new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitchLayer), sw->input(1), layers_node}); + args.insert(args.begin(), new_sw); + auto inner_call = fg->NewCNode(args); + args_outer.insert(args_outer.begin(), inner_call); + return fg->NewCNode(args_outer); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + internal::EnvGetitemTransformACrossGraph env_get_item_transform_; +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 9c02df6b2f..a4a2f494e6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -72,6 +72,52 @@ class GetitemTransform { private: std::unordered_map> cache_; }; + +class GetItemTransformACrossGraph { + public: + GetItemTransformACrossGraph() : cache_() {} + ~GetItemTransformACrossGraph() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + if (cache.find(idx) == cache.end()) { + std::ostringstream ss("tp", std::ostringstream::app); + ss << idx; + + auto new_fg_outer = TransformableClone(fg, std::make_shared(ss.str())); + auto output_outer = new_fg_outer->output(); + if (!IsValueNode(output_outer)) { + MS_LOG(WARNING) << "Output of outer graph should be a func_graph"; + return nullptr; + } + auto fg_inner = GetValueNode(output_outer); + auto new_fg = TransformableClone(fg_inner, std::make_shared(ss.str())); + new_fg_outer->set_output(NewValueNode(new_fg)); + auto output = new_fg->output(); + if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto cnode = output->cast(); + auto ids = IntToSize(idx + 1); + // Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. 
+ if (ids >= cnode->size()) { + MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size(); + } + new_fg->set_output(cnode->input(ids)); + } else { + new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)})); + } + + cache[idx] = new_fg_outer; + } + return cache[idx]; + } + + private: + std::unordered_map> cache_; +}; } // namespace internal // {prim::kPrimTupleGetItem, {G, Xs}, C} @@ -385,13 +431,199 @@ class IncorporateGetitemSwitch : public AnfVisitor { internal::GetitemTransform getitem_transform_; }; +// {prim::kPrimTupleGetItem, {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, C} +class IncorporateGetitemSwitchLayerA : public AnfVisitor { + public: + IncorporateGetitemSwitchLayerA() : getitem_transform_() {} + ~IncorporateGetitemSwitchLayerA() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + is_in_get_ = true; + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + is_in_get_ = false; + + auto fg = node->func_graph(); + if (idx_ == -1 || switch_layer_ == nullptr || fg == nullptr) { + return nullptr; + } + + is_in_switch_ = true; + AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_); + is_in_switch_ = false; + + if (graphs_.empty()) { + return nullptr; + } + + std::vector layers; + for (auto &graph : graphs_) { + auto fg_transform = getitem_transform_(graph, idx_); + if (fg_transform == nullptr) { + return nullptr; + } + layers.push_back(NewValueNode(fg_transform)); + } + auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); + std::vector sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node}; + auto sw_node = fg->NewCNode(sw_args); + (void)args_.insert(args_.begin(), sw_node); + + return fg->NewCNode(args_); + } + + void Visit(const AnfNodePtr &node) override { + if (is_in_switch_ && x_ == nullptr) { + x_ = node; + return; + } + 
AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { + if (is_in_get_ && cnode->size() != 0) { + auto &inputs = cnode->inputs(); + switch_layer_ = inputs[0]; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); + } + if (is_in_switch_ && cnode->size() > 2) { + auto &inputs = cnode->inputs(); + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode(inputs[1])) { + (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_), + [](const AnfNodePtr &vnode) { return GetValueNode(vnode); }); + } + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (is_in_get_) { + idx_ = GetValue(vnode->value()); + } + } + + void Reset() { + x_ = nullptr; + graphs_.clear(); + switch_layer_ = nullptr; + args_.clear(); + is_in_get_ = false; + is_in_switch_ = false; + } + + private: + int idx_{-1}; + AnfNodePtr switch_layer_{nullptr}, x_{nullptr}; + std::vector graphs_{}; + bool is_in_get_{false}, is_in_switch_{false}; + std::vector args_{}; + internal::GetitemTransform getitem_transform_; +}; + +// {prim::kPrimTupleGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C} +class IncorporateGetitemSwitchLayerB : public AnfVisitor { + public: + IncorporateGetitemSwitchLayerB() : getitem_transform_() {} + ~IncorporateGetitemSwitchLayerB() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + is_in_get_ = true; + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + is_in_get_ = false; + + auto fg = node->func_graph(); + if (idx_ == -1 || switch_layer_call_ == nullptr || !switch_layer_call_->isa() || fg == nullptr) { + return nullptr; + } + + auto &switch_layer_call_inputs = switch_layer_call_->cast()->inputs(); + (void)std::copy(switch_layer_call_inputs.begin() + 1, switch_layer_call_inputs.end(), std::back_inserter(args_)); + + is_in_switch_ = true; + 
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_call_inputs[0]); + is_in_switch_ = false; + + if (graphs_.empty()) { + return nullptr; + } + + std::vector layers; + for (auto &graph : graphs_) { + auto fg_transform = getitem_transform_(graph, idx_); + if (fg_transform == nullptr) { + return nullptr; + } + layers.push_back(NewValueNode(fg_transform)); + } + auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); + std::vector sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node}; + auto sw_node = fg->NewCNode(sw_args); + (void)args_.insert(args_.begin(), sw_node); + auto call_switch_layer = fg->NewCNode(args_); + (void)outer_call_args_.insert(outer_call_args_.begin(), call_switch_layer); + return fg->NewCNode(outer_call_args_); + } + + void Visit(const AnfNodePtr &node) override { + if (is_in_switch_ && x_ == nullptr) { + x_ = node; + return; + } + AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { + if (is_in_get_ && cnode->size() != 0) { + auto &inputs = cnode->inputs(); + switch_layer_call_ = inputs[0]; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_)); + } + if (is_in_switch_ && cnode->size() > 2) { + auto &inputs = cnode->inputs(); + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode(inputs[1])) { + (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_), + [](const AnfNodePtr &vnode) { return GetValueNode(vnode); }); + } + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (is_in_get_) { + idx_ = GetValue(vnode->value()); + } + } + + void Reset() { + x_ = nullptr; + graphs_.clear(); + switch_layer_call_ = nullptr; + args_.clear(); + outer_call_args_.clear(); + is_in_get_ = false; + is_in_switch_ = false; + } + + private: + int idx_{-1}; + AnfNodePtr switch_layer_call_{nullptr}, x_{nullptr}; + std::vector graphs_{}; + bool is_in_get_{false}, is_in_switch_{false}; + std::vector args_{}; + std::vector 
outer_call_args_{}; + internal::GetItemTransformACrossGraph getitem_transform_; +}; + class IncorporateGetitemSet : public OptimizerCaller { public: IncorporateGetitemSet() : incorporate_getitem_(std::make_shared()), - incorporate_getitem_switch_(std::make_shared()) { + incorporate_getitem_switch_(std::make_shared()), + incorporate_getitem_switch_layer_a_(std::make_shared()), + incorporate_getitem_switch_layer_b_(std::make_shared()) { eliminaters_.emplace_back(incorporate_getitem_); eliminaters_.emplace_back(incorporate_getitem_switch_); + eliminaters_.emplace_back(incorporate_getitem_switch_layer_a_); + eliminaters_.emplace_back(incorporate_getitem_switch_layer_b_); } ~IncorporateGetitemSet() = default; @@ -407,7 +639,8 @@ class IncorporateGetitemSet : public OptimizerCaller { } private: - OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; + OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_, incorporate_getitem_switch_layer_a_, + incorporate_getitem_switch_layer_b_; std::vector eliminaters_{}; }; } // namespace irpass diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 29d2358cd0..966275b964 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, - irpass.value_based_eliminate_}); + irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, diff --git a/tests/ut/python/ops/test_control_ops.py 
b/tests/ut/python/ops/test_control_ops.py index 53f0222abc..85cc97f4c1 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -464,6 +464,36 @@ def test_switch_layer_with_single_prim(): C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) +def test_switch_layer_env_eliminate(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.conv = nn.Conv2d(1, 1, 3, pad_mode='same') + self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same') + self.funs = (self.conv, self.conv2) + + def construct(self, x, index): + x = self.funs[index](x) + return x + + class NetGrad(nn.Cell): + def __init__(self, net): + super(NetGrad, self).__init__() + self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False) + self.net = net + self.weights = ParameterTuple(self.net.trainable_params()) + + def construct(self, x, index): + weights = self.weights + grad = self.grad_op(self.net, weights)(x, index) + return grad + net = Net() + net2 = NetGrad(net) + x = Tensor(np.ones((3, 1, 12, 12)), ms.float32) + i = Tensor(1, ms.int32) + net2(x, i) + + def test_control_depend_check(): with pytest.raises(TypeError) as e: P.ControlDepend(0.0)