From 6ae8573803cfd2b391b868477a1883a5a4df3166 Mon Sep 17 00:00:00 2001
From: jonwe
Date: Fri, 27 Nov 2020 14:59:53 -0500
Subject: [PATCH] Equal op dynamic shape

---
 mindspore/core/abstract/prim_maths.cc | 27 +++++++++++++++++----------
 tests/st/ops/gpu/test_equal_op.py     | 29 +++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc
index d5da5d94e8..6d71f01bf0 100644
--- a/mindspore/core/abstract/prim_maths.cc
+++ b/mindspore/core/abstract/prim_maths.cc
@@ -80,24 +80,31 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
                                const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 2);
-  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(input_x);
-  MS_EXCEPTION_IF_NULL(input_x->shape());
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector x_shape = x->shape()->shape();
+  ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
+  ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
+
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  ShapeVector y_shape = y->shape()->shape();
+  ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
+  ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
 
-  auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
-  MS_EXCEPTION_IF_NULL(input_y);
-  MS_EXCEPTION_IF_NULL(input_y->shape());
-
-  auto x_shape = input_x->shape()->shape();
-  auto y_shape = input_y->shape()->shape();
   auto out_shape = BroadcastShape(x_shape, y_shape);
   if (out_shape.empty()) {
     MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
                       << args_spec_list[1]->ToString();
   }
+  auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
+  auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
 
   auto output_type = std::make_shared<Bool>();
-  auto ret = std::make_shared<AbstractTensor>(output_type, out_shape);
+  auto ret =
+    std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
   return ret;
 }
 
diff --git a/tests/st/ops/gpu/test_equal_op.py b/tests/st/ops/gpu/test_equal_op.py
index adb12728dd..5b0ba815bf 100644
--- a/tests/st/ops/gpu/test_equal_op.py
+++ b/tests/st/ops/gpu/test_equal_op.py
@@ -20,6 +20,7 @@ import mindspore.context as context
 from mindspore.common.tensor import Tensor
 from mindspore.nn import Cell
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
 
 
 class NetEqual(Cell):
@@ -30,6 +31,17 @@ class NetEqual(Cell):
     def construct(self, x, y):
         return self.Equal(x, y)
 
+class NetEqualDynamic(Cell):
+    def __init__(self):
+        super(NetEqualDynamic, self).__init__()
+        self.conv = inner.GpuConvertToDynamicShape()
+        self.Equal = P.Equal()
+
+    def construct(self, x, y):
+        x_conv = self.conv(x)
+        y_conv = self.conv(y)
+        return self.Equal(x_conv, y_conv)
+
 class NetNotEqual(Cell):
     def __init__(self):
         super(NetNotEqual, self).__init__()
@@ -211,3 +223,20 @@ def test_greaterqual():
     output2 = gequal(x2, y2)
     assert np.all(output2.asnumpy() == expect2)
     assert output2.shape == expect2.shape
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_equal_dynamic_shape():
+    x0_np = np.arange(24).reshape((4, 3, 2)).astype(np.float32)
+    x0 = Tensor(x0_np)
+    y0_np = np.arange(24).reshape((4, 3, 2)).astype(np.float32)
+    y0 = Tensor(y0_np)
+    expect0 = np.equal(x0_np, y0_np)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    equal = NetEqualDynamic()
+    output0 = equal(x0, y0)
+    assert np.all(output0.asnumpy() == expect0)
+    assert output0.shape == expect0.shape
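
Note (not part of the patch): the substance of the C++ change is that InferImplEqual now broadcasts the min_shape and max_shape vectors with the same rule it already used for the nominal shape, so the inferred output carries a full (shape, min, max) triple for dynamic-shape inputs. Below is a minimal Python sketch of that broadcast rule; broadcast_shape is a hypothetical stand-in for the C++ BroadcastShape helper, it omits MindSpore's unknown-dimension sentinel, and the shape values are illustrative only.

    # Sketch only (assumption): a simplified Python analogue of the C++
    # BroadcastShape helper; the real code also handles the -1 "unknown
    # dimension" sentinel, which is omitted here.
    from itertools import zip_longest

    def broadcast_shape(x_shape, y_shape):
        out = []
        # Walk both shapes from the trailing dimension, padding the shorter
        # one with 1s, and apply the usual NumPy-style broadcast rule.
        for xd, yd in zip_longest(reversed(x_shape), reversed(y_shape), fillvalue=1):
            if xd == yd or yd == 1:
                out.append(xd)
            elif xd == 1:
                out.append(yd)
            else:
                return []  # incompatible shapes -> empty result, as in the C++ code
        return out[::-1]

    # The infer function applies the same rule to each vector of the
    # (shape, min_shape, max_shape) triple; values here are made up.
    x_shape, x_min, x_max = [4, 3, 2], [2, 3, 2], [8, 3, 2]
    y_shape, y_min, y_max = [1, 3, 2], [1, 3, 2], [1, 3, 2]
    assert broadcast_shape(x_shape, y_shape) == [4, 3, 2]
    assert broadcast_shape(x_min, y_min) == [2, 3, 2]
    assert broadcast_shape(x_max, y_max) == [8, 3, 2]

Broadcasting min with min and max with max is sufficient here because broadcasting is monotone in each dimension: if every concrete runtime shape lies between its min and max bounds, the broadcast result lies between the broadcast of the bounds.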