From 3a319a97e92e45b4c53d64d40171f130ff5f5495 Mon Sep 17 00:00:00 2001 From: huangmengxi Date: Mon, 1 Mar 2021 17:21:05 +0800 Subject: [PATCH] move matmul from numpy to ops --- mindspore/numpy/math_ops.py | 57 +---------- mindspore/numpy/utils_const.py | 11 --- mindspore/ops/composite/__init__.py | 5 +- mindspore/ops/composite/math_ops.py | 146 ++++++++++++++++++++++++++++ tests/st/ops/gpu/test_matmul_op.py | 44 +++++++++ 5 files changed, 196 insertions(+), 67 deletions(-) diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 7654008375..4bf0d331e9 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -31,8 +31,8 @@ from .array_ops import ravel, expand_dims from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ - _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ - _max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range + _raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \ + _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ _check_input_tensor @@ -1285,44 +1285,7 @@ def matmul(x1, x2, dtype=None): [ 550. 620. 690. 760. 830.] [ 670. 756. 842. 928. 1014.]]] """ - # performs type promotion - dtype1 = F.dtype(x1) - dtype2 = F.dtype(x2) - dtype_out = _promote(dtype1, dtype2) - if not _check_same_type(dtype1, dtype_out): - x1 = F.cast(x1, dtype_out) - if not _check_same_type(dtype2, dtype_out): - x2 = F.cast(x2, dtype_out) - - ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2) - shape1_orig, shape2_orig = F.shape(x1), F.shape(x2) - _check_matmul_shapes(shape1_orig, shape2_orig) - ndim_aligned = _max(ndim1_orig, ndim2_orig) - transpose_b = ndim2_orig == 1 - shape_backbone = _infer_out_shape( - shape1_orig[:-2], shape2_orig[:-2]) - # infers the shape of the output - shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig, - ndim1_orig, ndim2_orig, transpose_b) - - x1 = _expand(x1, _max(ndim_aligned, 2)) - x2 = _expand(x2, _max(ndim_aligned, 2)) - shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2) - - if ndim_aligned <= 2: - res = P.MatMul(False, transpose_b)(x1, x2) - else: - # broadcasts x1.shape[:-2] with x2.shape[:-2] - shape_aligned = shape_backbone + _infer_shape_rem(shape1_aligned, shape2_aligned, - ndim_aligned, ndim_aligned, - transpose_b) - x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_aligned[:-2], ndim_aligned) - x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_aligned[:-2], ndim_aligned) - res = P.BatchMatMul(False, transpose_b)(x1, x2) - - if dtype is not None and not _check_same_type(dtype_out, dtype): - res = F.cast(res, dtype) - return F.reshape(res, shape_out) + return C.matmul(x1, x2, dtype=dtype) def square(x, out=None, where=True, dtype=None): @@ -2256,20 +2219,6 @@ def _shape_reduced(shape, axes): return tuple(shape_out) -def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b): - """Infers the shape of the last two dimensions after performing matmul.""" - shape_rem = () - if ndim1 >= 2: - shape_rem += (shape1[-2],) - if transpose_b: - if ndim2 >= 2: - shape_rem += (shape2[-2],) - else: - if ndim1 >= 1: - shape_rem += (shape2[-1],) - return shape_rem - - def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where=True): """Applies comparison based on cmp_fn and reduction based on reduce_fn""" 
_check_input_tensor(a) diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py index e8ccd223e5..35e96ebcc6 100644 --- a/mindspore/numpy/utils_const.py +++ b/mindspore/numpy/utils_const.py @@ -278,17 +278,6 @@ def _check_is_int(dtype): return isinstance(dtype, typing.Int) -@constexpr -def _check_matmul_shapes(shape1, shape2): - """Checks shape1 and shape2 are valid shapes to perform matmul""" - ndim1, ndim2 = len(shape1), len(shape2) - if ndim1 < 1 or ndim2 < 1: - raise ValueError('input operands must have at least 1 dimension') - if ndim2 >= 2 and shape1[-1] != shape2[-2]: - raise ValueError(f'mismatch in core dimension of input operands (size ' - f'{shape1[-1]} is different from {shape2[-2]})') - - @constexpr def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True): """Check axis argument type.""" diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index e9c09c2c14..bbca233beb 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial -from .math_ops import count_nonzero, tensor_dot, dot, batch_dot +from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul from .array_ops import repeat_elements, sequence_mask @@ -56,4 +56,5 @@ __all__ = [ 'dot', 'batch_dot', 'repeat_elements', - 'sequence_mask'] + 'sequence_mask', + 'matmul'] diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index c2cf89833c..3a08415c8b 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -13,6 +13,8 @@ # limitations under the License. 
# ============================================================================ """math Operations.""" +from itertools import zip_longest +from collections import deque import numpy as np from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils from mindspore.common import dtype as mstype @@ -486,3 +488,147 @@ def batch_dot(x1, x2, axes=None): final_result = squeeze_minus_one_op(final_result) return final_result + +@constexpr +def _check_same_type(dtype1, dtype2): + return dtype1 == dtype2 + +@constexpr +def _max(*args): + """Returns the maximum value.""" + return max(*args) + +@constexpr +def _min(*args): + """Returns the minimum value.""" + return min(*args) + +@constexpr +def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b): + """Infers the shape of the last two dimensions after performing matmul.""" + shape_rem = [] + if ndim1 >= 2: + shape_rem.append(shape1[-2]) + if transpose_b: + if ndim2 >= 2: + shape_rem.append(shape2[-2]) + else: + if ndim1 >= 1: + shape_rem.append(shape2[-1]) + return tuple(shape_rem) + +@constexpr +def _check_matmul_shapes(shape1, shape2): + """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting.""" + ndim1, ndim2 = len(shape1), len(shape2) + if ndim1 < 1 or ndim2 < 1: + raise ValueError('input operands must have at least 1 dimension') + if ndim2 >= 2 and shape1[-1] != shape2[-2]: + raise ValueError(f'mismatch in core dimension of input operands (size ' + f'{shape1[-1]} is different from {shape2[-2]})') + shape_out = deque() + for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1): + max_size = max(items) + if any(item not in (1, max_size) for item in items): + raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}') + shape_out.appendleft(max_size) + return tuple(shape_out) + +@constexpr +def _tile_size(shape, out_shape, ndim): + """Returns tile_size such that shape*tile_size = out_shape""" + size = [1]*ndim + for idx, (i, j) in enumerate(zip(shape, out_shape)): + if i != j: + size[idx] = j + return tuple(size) + +@constexpr +def _check_need_broadcast(shape1, shape2): + """Returns True if broadcast is necessary for batchmatmul.""" + return shape1[:-2] != shape2[:-2] + +def _expand(x, ndim): + """Expand x to ndim from axis, which can be 0 or -1.""" + while F.rank(x) < ndim: + x = F.expand_dims(x, 0) + return x + +def _broadcast_to(x, shape_cur, shape_to, ndim_to): + """Broadcasts x from shape_cur to shape_to.""" + size = _tile_size(shape_cur, shape_to, ndim_to) + return F.tile(x, size) + +def matmul(x1, x2, dtype=None): + """ + Returns the matrix product of two arrays. + + Note: + Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + On GPU, the supported dtypes are np.float16 and np.float32. + On CPU, the supported dtypes are np.float16 and np.float32. + + Args: + x1 (Tensor): Input tensor, scalar not allowed. + x2 (Tensor): Input tensor, scalar not allowed. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the matrix product of the inputs. This is a scalar only + when both `x1`, `x2` are 1-d vectors. + + Raises: + ValueError: If the last dimension of `x1` is not the same size as the + second-to-last dimension of `x2`, or if a scalar value is passed in. 
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4).astype('float32'))
+        >>> x2 = Tensor(np.arange(4*5).reshape(4, 5).astype('float32'))
+        >>> output = C.matmul(x1, x2)
+        >>> print(output)
+        [[[  70.   76.   82.   88.   94.]
+         [ 190.  212.  234.  256.  278.]
+         [ 310.  348.  386.  424.  462.]]
+        [[ 430.  484.  538.  592.  646.]
+         [ 550.  620.  690.  760.  830.]
+         [ 670.  756.  842.  928. 1014.]]]
+    """
+    # if the input dtypes differ, cast both to float32 so MatMul/BatchMatMul see a single dtype
+    dtype1 = F.dtype(x1)
+    dtype2 = F.dtype(x2)
+    if not _check_same_type(dtype1, dtype2):
+        x1 = x1.astype(mstype.float32)
+        x2 = x2.astype(mstype.float32)
+
+    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
+    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
+    transpose_b = ndim2_orig == 1
+    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
+    # infers the shape of the output
+    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
+                                                  ndim1_orig, ndim2_orig, transpose_b)
+
+    x1 = _expand(x1, 2)
+    x2 = _expand(x2, 2)
+    if F.rank(x2) == 2:
+        if F.rank(x1) > 2:
+            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
+        res = P.MatMul(False, transpose_b)(x1, x2)
+    else:
+        # broadcasts x1.shape[:-2] with x2.shape[:-2]
+        ndim_aligned = _max(ndim1_orig, ndim2_orig)
+        x1 = _expand(x1, ndim_aligned)
+        x2 = _expand(x2, ndim_aligned)
+        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
+        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
+        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
+        res = P.BatchMatMul(False, transpose_b)(x1, x2)
+
+    if dtype is not None:
+        res = res.astype(dtype)
+    return F.reshape(res, shape_out)
diff --git a/tests/st/ops/gpu/test_matmul_op.py b/tests/st/ops/gpu/test_matmul_op.py
index 3289df0356..b3fcba147c 100644
--- a/tests/st/ops/gpu/test_matmul_op.py
+++ b/tests/st/ops/gpu/test_matmul_op.py
@@ -20,6 +20,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.ops.operations import _inner_ops as inner
 
 class MatMulNet(nn.Cell):
@@ -43,6 +44,15 @@ class MatMul_d(nn.Cell):
         return self.matmul(x, y)
 
 
+class MatMulComposite(nn.Cell):
+    def __init__(self):
+        super(MatMulComposite, self).__init__()
+        self.matmul = C.matmul
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -77,3 +87,37 @@ def test_matmul_float64():
     output = net(Tensor(x), Tensor(y))
     expect = np.matmul(x, y)
     np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_matmul_composite():
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = MatMulComposite()
+
+    scalars = [np.random.randn(1).astype(np.float32), np.random.randn(1).astype(np.float32),
+               np.random.randn(1, 1).astype(np.float32),
+               np.random.randn(1, 1, 1).astype(np.float32)]
+    for x in scalars:
+        for y in scalars:
+            output = net(Tensor(x), Tensor(y))
+            expect = np.matmul(x, y)
+            np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+    broadcastables = [
+        np.random.randn(3).astype(np.float32), np.random.randn(3).astype(np.float32),
+        np.random.randn(6).astype(np.float32), np.random.randn(6, 4).astype(np.float32),
+        np.random.randn(5, 2).astype(np.float32), np.random.randn(2).astype(np.float32),
+        np.random.randn(2, 9).astype(np.float32), np.random.randn(9, 8).astype(np.float32),
+        np.random.randn(6).astype(np.float32), np.random.randn(2, 6, 5).astype(np.float32),
+        np.random.randn(9, 2, 7).astype(np.float32), np.random.randn(7).astype(np.float32),
+        np.random.randn(5, 2, 4).astype(np.float32), np.random.randn(6, 1, 4, 9).astype(np.float32),
+        np.random.randn(7, 1, 5, 3, 2).astype(np.float32), np.random.randn(8, 1, 6, 1, 2, 9).astype(np.float32)
+    ]
+    for i in range(8):
+        x = broadcastables[2*i]
+        y = broadcastables[2*i + 1]
+        output = net(Tensor(x), Tensor(y))
+        expect = np.matmul(x, y)
+        np.testing.assert_array_almost_equal(output.asnumpy(), expect)
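
Shape rule used by the patch: the output shape of the composite matmul is the broadcast of the two operands' batch dimensions (all axes except the last two) followed by the remaining core dimensions, which is what `_check_matmul_shapes` and `_infer_shape_rem` compute. Below is a minimal standalone Python sketch of that rule, with shapes taken from the docstring example and the broadcast test cases above; the helper name `infer_matmul_shape` is illustrative only and not part of the patch.

    from itertools import zip_longest

    def infer_matmul_shape(shape1, shape2):
        """Sketch of the output-shape rule used by the composite matmul."""
        ndim1, ndim2 = len(shape1), len(shape2)
        if ndim1 < 1 or ndim2 < 1:
            raise ValueError('input operands must have at least 1 dimension')
        if ndim2 >= 2 and shape1[-1] != shape2[-2]:
            raise ValueError('mismatch in core dimension of input operands')
        # batch (non-core) dimensions broadcast right-to-left; a size of 1 stretches to match
        batch = []
        for sizes in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
            max_size = max(sizes)
            if any(size not in (1, max_size) for size in sizes):
                raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
            batch.insert(0, max_size)
        # core dimensions: rows of x1 and columns of x2; a 1-D operand contributes none
        core = []
        if ndim1 >= 2:
            core.append(shape1[-2])
        if ndim2 >= 2:
            core.append(shape2[-1])
        return tuple(batch) + tuple(core)

    # shapes taken from the docstring example and the broadcast test cases above
    assert infer_matmul_shape((2, 3, 4), (4, 5)) == (2, 3, 5)
    assert infer_matmul_shape((5, 2), (2,)) == (5,)              # 1-D x2 drops the column axis
    assert infer_matmul_shape((3,), (3,)) == ()                  # two 1-D vectors -> scalar result
    assert infer_matmul_shape((5, 2, 4), (6, 1, 4, 9)) == (6, 5, 2, 9)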