!11290 Add dynamic shape support & testcases to MatMul, BatchMatMul gpu

From: @TFbunny
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
pull/11290/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit be14046b0e

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H
#define MINDSPORE_MATMUL_GPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
@ -30,19 +30,7 @@ namespace kernel {
template <typename T>
class MatMulGpuKernel : public GpuKernel {
public:
MatMulGpuKernel()
: batch_(0),
m_(0),
n_(0),
k_(0),
is_null_input_(false),
transpose_x1_(CUBLAS_OP_N),
transpose_x2_(CUBLAS_OP_N),
handle_(nullptr),
dtype_a_(CUDA_R_32F),
dtype_b_(CUDA_R_32F),
dtype_c_(CUDA_R_32F),
algo_(CUBLAS_GEMM_DEFAULT) {}
MatMulGpuKernel() { ResetResource(); }
~MatMulGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -122,6 +110,24 @@ class MatMulGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
batch_ = 0;
m_ = 0;
n_ = 0;
k_ = 0;
is_null_input_ = false;
transpose_x1_ = CUBLAS_OP_N;
transpose_x2_ = CUBLAS_OP_N;
handle_ = nullptr;
dtype_a_ = CUDA_R_32F;
dtype_b_ = CUDA_R_32F;
dtype_c_ = CUDA_R_32F;
algo_ = CUBLAS_GEMM_DEFAULT;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
size_t unit_size = sizeof(T);
@ -158,4 +164,4 @@ class MatMulGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H

@ -289,7 +289,10 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

@ -317,6 +317,7 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
return ret;
}
AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
@ -326,5 +327,93 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
auto input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
return input->Broaden();
}
AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(y);
MS_EXCEPTION_IF_NULL(y->shape());
if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) {
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
}
ValuePtr TAptr = primitive->GetAttr("transpose_a");
ValuePtr TBptr = primitive->GetAttr("transpose_b");
bool TA = GetValue<bool>(TAptr);
bool TB = GetValue<bool>(TBptr);
ShapeVector x_min_shape = x->shape()->min_shape();
ShapeVector x_max_shape = x->shape()->max_shape();
ShapeVector y_min_shape = y->shape()->min_shape();
ShapeVector y_max_shape = y->shape()->max_shape();
(void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
(void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
ShapeVector ret_shape;
ShapeVector ret_min_shape;
ShapeVector ret_max_shape;
auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
output.push_back(xshp[(TA ? 1 : 0)]);
output.push_back(yshp[(TB ? 0 : 1)]);
return;
};
make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
make_shape(ret_min_shape, x_min_shape, y_min_shape);
make_shape(ret_max_shape, x_max_shape, y_max_shape);
return std::make_shared<AbstractTensor>(x->element(),
std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}
AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(y);
MS_EXCEPTION_IF_NULL(y->shape());
if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) {
MS_LOG(EXCEPTION)
<< "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
}
ValuePtr TAptr = primitive->GetAttr("transpose_a");
ValuePtr TBptr = primitive->GetAttr("transpose_b");
bool TA = GetValue<bool>(TAptr);
bool TB = GetValue<bool>(TBptr);
ShapeVector x_min_shape = x->shape()->min_shape();
ShapeVector x_max_shape = x->shape()->max_shape();
ShapeVector y_min_shape = y->shape()->min_shape();
ShapeVector y_max_shape = y->shape()->max_shape();
(void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
(void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
ShapeVector ret_shape;
ShapeVector ret_min_shape;
ShapeVector ret_max_shape;
auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
for (size_t i = 0; i < xshp.size() - 2; i++) {
if (xshp[i] != yshp[i]) {
if (xshp[i] > 0 && yshp[i] > 0) {
MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << ".";
}
output.push_back(Shape::SHP_ANY);
} else {
output.push_back(xshp[i]);
}
}
size_t offset = xshp.size() - 2;
output.push_back(xshp[offset + (TA ? 1 : 0)]);
output.push_back(yshp[offset + (TB ? 0 : 1)]);
return;
};
make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
make_shape(ret_min_shape, x_min_shape, y_min_shape);
make_shape(ret_max_shape, x_max_shape, y_max_shape);
return std::make_shared<AbstractTensor>(x->element(),
std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}
} // namespace abstract
} // namespace mindspore

@ -49,6 +49,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
{prim::kPrimAddN, {InferImplAddN, true}},
{prim::kPrimMatMul, {InferImplMatMul, true}},
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},

@ -689,7 +689,7 @@ class CumProd(PrimitiveWithInfer):
raise ValueError(f"For {self.name}, axis must be const.")
class MatMul(PrimitiveWithInfer):
class MatMul(PrimitiveWithCheck):
"""
Multiplies matrix `a` and matrix `b`.
@ -730,10 +730,10 @@ class MatMul(PrimitiveWithInfer):
def check_shape_size(self, x1, x2):
if len(x1) != 2 or len(x2) != 2:
raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
raise ValueError('P.MatMul inputs x1, x2 should have the same dimension size and '
+ f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
def infer_shape(self, x1, x2):
def check_shape(self, x1, x2):
self.check_shape_size(x1, x2)
cls_name = self.name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@ -747,23 +747,18 @@ class MatMul(PrimitiveWithInfer):
x2_last = x2[-2:]
x1_col = x1_last[not self.transpose_a]
x2_row = x2_last[self.transpose_b]
if x1_col != x2_row:
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+ f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+ f', x2 shape {x2}(transpose_b={self.transpose_b}).')
if np.all(np.array(x1) != -1) and np.all(np.array(x2) != -1):
if x1_col != x2_row:
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+ f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+ f', x2 shape {x2}(transpose_b={self.transpose_b}).')
# set attribute
self.add_prim_attr('transpose_x1', self.transpose_a)
self.add_prim_attr('transpose_x2', self.transpose_b)
ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
return ret_dims
def infer_dtype(self, x1, x2):
def check_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
if x1.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x1
class BatchMatMul(MatMul):

@ -21,11 +21,9 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
class BatchMatMulNet(nn.Cell):
def __init__(self, transpose_a=False, transpose_b=False):
super(BatchMatMulNet, self).__init__()
@ -34,7 +32,9 @@ class BatchMatMulNet(nn.Cell):
def construct(self, x, y):
return self.batch_matmul(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4d():
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
@ -140,3 +140,38 @@ def test_4D_fp16():
[[4340, 4396, 4456, 4510]],
[[5816, 5880, 5948, 6016]]]]).astype(np.float16)
assert (output.asnumpy() == expect).all()
class BatchMatMul_d(nn.Cell):
def __init__(self, transpose_a=False, transpose_b=False):
super(BatchMatMul_d, self).__init__()
self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
self.test_dynamic = inner.GpuConvertToDynamicShape()
def construct(self, x, y):
x = self.test_dynamic(x)
y = self.test_dynamic(y)
return self.batch_matmul(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchmatmul_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = BatchMatMul_d()
x1 = np.arange(8).reshape(2, 2, 2).astype(np.float32)
y1 = np.arange(28).reshape(2, 2, 7).astype(np.float32)
output1 = net(Tensor(x1), Tensor(y1))
expect1 = np.matmul(x1, y1)
assert (output1.asnumpy() == expect1).all()
x2 = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
y2 = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
output2 = net(Tensor(x2), Tensor(y2))
expect2 = np.matmul(x2, y2)
assert (output2.asnumpy() == expect2).all()

@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
class MatMul_d(nn.Cell):
def __init__(self):
super(MatMul_d, self).__init__()
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.matmul = P.MatMul()
def construct(self, x, y):
x = self.test_dynamic(x)
y = self.test_dynamic(y)
return self.matmul(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_MatMul_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = MatMul_d()
x1 = np.arange(2).reshape(1, 2).astype(np.float32)
y1 = np.arange(4).reshape(2, 2).astype(np.float32)
output1 = net(Tensor(x1), Tensor(y1))
expect1 = np.matmul(x1, y1)
np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
x2 = np.arange(102).reshape(34, 3).astype(np.float32)
y2 = np.arange(18).reshape(3, 6).astype(np.float32)
output2 = net(Tensor(x2), Tensor(y2))
expect2 = np.matmul(x2, y2)
np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)
Loading…
Cancel
Save