!8416 Fix RealDiv type error when inserting a VirtualDiv operation

From: @huangxinjing
Reviewed-by: @stsuteng
Signed-off-by: @stsuteng
pull/8416/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 156980778b

@@ -225,7 +225,8 @@ def get_bprop_virtual_div_operator(self):
     def bprop(x, out, dout):
         if F.issubclass_(F.typeof(dout), mstype.tensor):
-            if F.issubclass_(F.dtype(dout), mstype.bool_):
+            if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
+                    or F.issubclass_(F.dtype(dout), mstype.int16):
                 return (dout,)
             dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
             return (dx,)
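Why the guard helps, in brief: VirtualDiv scales the incoming gradient by 1/device_num, but inserting a RealDiv on a bool or integer gradient raises the type error named in the title, and those dtypes carry no meaningful scaled gradient anyway. A minimal standalone sketch of the guard's effect, with NumPy standing in for MindSpore tensors (virtual_div_bprop is a hypothetical name for illustration, not the framework API):

import numpy as np

def virtual_div_bprop(dout, device_num):
    # Pass bool/int gradients through untouched, mirroring the early
    # return added above; only floating-point gradients are divided.
    if dout.dtype in (np.dtype(np.bool_), np.dtype(np.int16), np.dtype(np.int32)):
        return dout
    return dout / np.asarray(device_num, dtype=dout.dtype)

print(virtual_div_bprop(np.ones(4, np.int32), 8))    # [1 1 1 1], unchanged
print(virtual_div_bprop(np.ones(4, np.float32), 8))  # [0.125 0.125 0.125 0.125]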

@@ -21,7 +21,6 @@ from mindspore import context
 from mindspore.common.api import _executor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops.operations.comm_ops import _VirtualDataset
 from tests.ut.python.ops.test_math_ops import VirtualLoss

 context.set_context(mode=context.GRAPH_MODE)
@@ -33,7 +32,6 @@ grad_all = C.GradOperation(get_all=True)
 class Net(nn.Cell):
     def __init__(self, strategy1, strategy2, num_segments):
         super(Net, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2))
         self.num_segments = num_segments
@@ -54,8 +52,8 @@ class GradWrap(nn.Cell):
 class NetWithLoss(nn.Cell):
     def __init__(self, network):
         super(NetWithLoss, self).__init__()
-        self.loss = VirtualLoss()
         self.network = network
+        self.loss = VirtualLoss()

     def construct(self, x, y):
         predict = self.network(x, y)
@@ -63,13 +61,13 @@ class NetWithLoss(nn.Cell):
 def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
-    net.set_auto_parallel()
-    net.set_train()
     if auto:
         context.set_auto_parallel_context(parallel_mode="auto_parallel")
     else:
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
+    net.set_auto_parallel()
+    net.set_train()
     _executor.compile(net, x, y)
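A plausible reading of this reordering (inferred from the diff; the commit message does not say): the parallel mode should be decided before the network is constructed, marked auto-parallel, and compiled, so both branches now configure the context first and the net is built only afterwards.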
@@ -151,3 +149,13 @@ def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d():
     strategy1 = (2, 1, 2)
     strategy2 = (2, 1)
     compile_graph(x, y, num_segments, strategy1, strategy2)
+
+
+def test_unsortedsegmentsum_model_parallel_repeat_caculate():
+    context.set_auto_parallel_context(device_num=4, global_rank=0)
+    x = Tensor(np.ones((4, 4, 8)), ms.float32)
+    y = Tensor(np.ones((4, 4)), ms.int32)
+    num_segments = 16
+    strategy1 = (1, 1, 1)
+    strategy2 = (1, 1)
+    compile_graph(x, y, num_segments, strategy1, strategy2)
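The new test's name points at repeated calculation: with device_num=4 and shard strategies whose products are 1, every device holds a full replica and the operator's computation is repeated across devices, which is exactly the case where the VirtualDiv gradient scaling kicks in. A quick arithmetic sketch of the repeat factor (the strategy-product rule here is the usual one, assumed rather than quoted from MindSpore source):

import math

device_num = 4
strategy1 = (1, 1, 1)   # product 1: the input tensor is not split at all
repeat = device_num // math.prod(strategy1)
print(repeat)  # 4 -> the UnsortedSegmentSum is repeated on all 4 devices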
