switch_layer incorporate env_get and tuple_get

pull/4663/head
panyifeng 5 years ago
parent 4f46c4277a
commit 22a9d02e9f

@ -95,6 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_layer_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitchLayer>(), "incorporate_env_getitem_switch_layer",
prim::kPrimEnvGetItem);
// Ref eliminate
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);

@ -58,6 +58,7 @@ class OptimizeIRPassLib {
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_;
SubstitutionPtr incorporate_env_getitem_switch_layer_;
// Ref eliminate
SubstitutionPtr make_ref_eliminate_;

@ -91,6 +91,69 @@ class EnvGetitemTransform {
std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
cache_;
};
class EnvGetitemTransformACrossGraph {
public:
EnvGetitemTransformACrossGraph() : cache_() {}
~EnvGetitemTransformACrossGraph() = default;
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
if (cache_.find(fg) == cache_.end()) {
cache_[fg] = {};
}
auto &cache = cache_[fg];
auto hash_key = std::make_pair(key, default_node);
if (cache.find(hash_key) == cache.end()) {
std::ostringstream ss("env", std::ostringstream::app);
if (key->node() != nullptr) {
ss << key->node()->ToString();
}
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
auto output_outer = new_fg_outer->output();
if (!IsValueNode<FuncGraph>(output_outer)) {
MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
return nullptr;
}
auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer);
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str()));
new_fg_outer->set_output(NewValueNode(new_fg));
auto env = new_fg->output();
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
auto value = inputs[3];
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
if (*key2 == *key) {
new_fg->set_output(value);
cache[hash_key] = new_fg_outer;
return new_fg_outer;
}
}
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
cache[hash_key] = new_fg_outer;
}
return cache[hash_key];
}
private:
std::unordered_map<FuncGraphPtr,
std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
cache_;
};
} // namespace internal
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
@ -358,6 +421,78 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_;
};
// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
public:
IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {}
~IncorporateEnvGetitemSwitchLayer() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C, Y}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
auto default_v = cnode->input(3);
// {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}
auto &inputs_outer = inp1->inputs();
if (!inputs_outer[0]->isa<CNode>()) {
return nullptr;
}
std::vector<AnfNodePtr> args_outer;
args_outer.insert(args_outer.end(), inputs_outer.begin() + 1, inputs_outer.end());
auto &input_switch_layer = inputs_outer[0]->cast<CNodePtr>()->inputs();
is_match_ = false;
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(input_switch_layer[0]);
if (!is_match_) {
return nullptr;
}
std::vector<AnfNodePtr> args;
(void)args.insert(args.end(), input_switch_layer.begin() + 1, input_switch_layer.end());
// {prim::kPrimSwitchLayers, X, {prim::kPrimMakeTuple, G1, G2...}}
auto sw = input_switch_layer[0]->cast<CNodePtr>();
std::vector<FuncGraphPtr> graphs{};
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
auto &graphs_inputs = graphs_cnode->inputs();
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
}
if (graphs.empty()) {
return nullptr;
}
auto fg = node->func_graph();
std::vector<AnfNodePtr> layers;
for (auto &graph : graphs) {
auto fg_transform = env_get_item_transform_(graph, key, default_v);
if (fg_transform == nullptr) {
return nullptr;
}
layers.push_back(NewValueNode(fg_transform));
}
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers);
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitchLayer), sw->input(1), layers_node});
args.insert(args.begin(), new_sw);
auto inner_call = fg->NewCNode(args);
args_outer.insert(args_outer.begin(), inner_call);
return fg->NewCNode(args_outer);
}
void Visit(const AnfNodePtr &) override { is_match_ = true; }
private:
bool is_match_{false};
internal::EnvGetitemTransformACrossGraph env_get_item_transform_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.value_based_eliminate_});
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,

@ -464,6 +464,36 @@ def test_switch_layer_with_single_prim():
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_switch_layer_env_eliminate():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same')
self.funs = (self.conv, self.conv2)
def construct(self, x, index):
x = self.funs[index](x)
return x
class NetGrad(nn.Cell):
def __init__(self, net):
super(NetGrad, self).__init__()
self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
self.net = net
self.weights = ParameterTuple(self.net.trainable_params())
def construct(self, x, index):
weights = self.weights
grad = self.grad_op(self.net, weights)(x, index)
return grad
net = Net()
net2 = NetGrad(net)
x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
i = Tensor(1, ms.int32)
net2(x, i)
def test_control_depend_check():
with pytest.raises(TypeError) as e:
P.ControlDepend(0.0)

Loading…
Cancel
Save