add ExpandDims whitelist

add comment for control_depend
pull/2631/head
huangdongrun 5 years ago
parent f65756162e
commit 96b38f72b2

@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimApplyRMSProp, {6, 7, 8}}, {prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}}, {prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}}, {prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}},
{prim::kPrimHistogramSummary, {1}}}); {prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) { for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {

@ -30,6 +30,8 @@ class ControlDepend(Primitive):
tells the engine that the destination operations should depend on the source operation which means the source tells the engine that the destination operations should depend on the source operation which means the source
operations should be executed before the destination. operations should be executed before the destination.
Note:
This operation does not work in `PYNATIVE_MODE`.
Args: Args:
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0. depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.

@ -19,6 +19,8 @@ import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as ms
from mindspore.common.api import _executor from mindspore.common.api import _executor
@ -116,3 +118,28 @@ def test_parser_map_0002():
net = NetMap0002() net = NetMap0002()
with pytest.raises(TypeError): with pytest.raises(TypeError):
net(input_me_x) net(input_me_x)
def test_fix_expanddims_loss_scale():
class ControlOneIfOneScaleOneScale(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.ExpandDims()
def construct(self, x, y, data):
if x > y:
out = 1
else:
out = 2
if x > y:
out = self.op(data, out)
else:
out = self.op(data, out)
return out
net = ControlOneIfOneScaleOneScale()
x = Tensor(1, ms.float32)
y = Tensor(0, ms.float32)
input_shape = (1024, 512, 7, 7)
input_data = np.random.randn(*input_shape).astype(np.float32)
net = ControlOneIfOneScaleOneScale()
net(x, y, Tensor(input_data))

Loading…
Cancel
Save