From ad3d737d2953d2499c3630c7ca243a4fc31468fb Mon Sep 17 00:00:00 2001 From: TFBunny Date: Fri, 5 Feb 2021 18:01:08 -0500 Subject: [PATCH] Rewrite sequence_mask as a composite op --- .../gpu/arrays/argmaxwithvalue_gpu_kernel.h | 15 ++- .../gpu/arrays/sequence_mask_gpu_kernel.cc | 35 ------ .../gpu/arrays/sequence_mask_gpu_kernel.h | 101 ------------------ .../gpu/cuda_impl/sequence_mask_impl.cu | 50 --------- .../gpu/cuda_impl/sequence_mask_impl.cuh | 25 ----- mindspore/core/abstract/infer_functions.h | 4 + mindspore/core/abstract/prim_arrays.cc | 59 ++++++++++ mindspore/core/abstract/prim_maths.cc | 34 +++++- .../core/abstract/primitive_infer_map.cc | 2 + mindspore/ops/composite/array_ops.py | 40 ++++++- tests/st/ops/gpu/test_sequence_mask_op.py | 69 ++++++++++-- 11 files changed, 209 insertions(+), 225 deletions(-) delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h index 2862715508..0859ad528c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ namespace kernel { template class ArgmaxWithValueGpuKernel : public GpuKernel { public: - ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {} + ArgmaxWithValueGpuKernel() { ResetResource(); } ~ArgmaxWithValueGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -75,6 +75,17 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { return true; } + void ResetResource() noexcept override { + input_size_ = 0; + output_size_ = 0; + bound_ = 0; + outerSize_ = 0; + innerSize_ = 0; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + protected: void InitSizeLists() override { input_size_list_.push_back(input_size_); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc deleted file mode 100644 index c42927eab2..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h" - -namespace mindspore { -namespace kernel { - -// keep this as TWO but output is always bool, just in case framework can -// support passing optional dtype and then we can be identical to tf -MS_REG_GPU_KERNEL_TWO( - SequenceMask, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), - SequenceMaskGpuKernel, int32_t, bool) - -MS_REG_GPU_KERNEL_TWO( - SequenceMask, - KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), - SequenceMaskGpuKernel, int64_t, bool) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h deleted file mode 100644 index 314d4e102e..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ - -#include "backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh" - -#include - -#include - -#include "backend/kernel_compiler/gpu/gpu_kernel.h" -#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class SequenceMaskGpuKernel : public GpuKernel { - public: - SequenceMaskGpuKernel() { ResetResource(); } - ~SequenceMaskGpuKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *lengths_device_address = GetDeviceAddress(inputs, 0); - T *maxlen_device_address = GetDeviceAddress(inputs, 1); - S *output_device_address = GetDeviceAddress(outputs, 0); - - CalSequenceMask(lengths_device_address, maxlen_device_address, output_device_address, output_size_, - reinterpret_cast(stream_ptr)); - - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_count != 2) { - MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SequenceMaskGpuKernel expects 2."; - } - - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (const int &e : input_shape_) { - lengths_size_ *= e; - } - - std::vector inferred_output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (const size_t &e : inferred_output_shape) { - output_size_ *= e; - } - - InitSizeLists(); - - return true; - } - - void ResetResource() noexcept override { - output_size_ = 1; - lengths_size_ = 1; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(lengths_size_ * sizeof(T)); - input_size_list_.push_back(sizeof(T)); - output_size_list_.push_back(output_size_); - } - - private: - std::vector input_shape_; - size_t lengths_size_; - size_t output_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu deleted file mode 100644 index 1bdc72b4ef..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "sequence_mask_impl.cuh" -#include "runtime/device/gpu/cuda_common.h" - -__global__ void ValidateArgs(int *maxlen, const int lengths_size, const int max_output_size) { - int maxlen_value = *maxlen; - if (maxlen_value < 0 || lengths_size * maxlen_value > max_output_size) { - asm("trap;"); - } -} - -template -__global__ void SequenceMask( - const T *input, T *maxlen, S *output, const size_t output_size) { - T maxlen_value = *maxlen; - - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - T mask_comparison_value = gt_id % maxlen_value; - T input_comparison_index = (gt_id - mask_comparison_value) / maxlen_value; - S result = mask_comparison_value < input[input_comparison_index]; - output[gt_id] = result; - } -} - -template -void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream) { - SequenceMask<<>>(lengths, maxlen, output, output_size); -} - -template void CalSequenceMask(const int *lengths, int *maxlen, bool *output, const size_t output_size, - cudaStream_t cuda_stream); - -template void CalSequenceMask(const int64_t *lengths, int64_t *maxlen, bool *output, - const size_t output_size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh deleted file mode 100644 index 241c0134d1..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ - -#include - -template -void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 09fc78a3ae..b03757c1aa 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -297,6 +297,10 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 4d4e1ec9b0..9ecef84235 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -1068,5 +1068,64 @@ AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &pr return std::make_shared(range_start_type, shape); } + +AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + // check keep_dims + ValuePtr keep_dims = primitive->GetAttr("keep_dims"); + MS_EXCEPTION_IF_NULL(keep_dims); + if (!keep_dims->isa()) { + MS_LOG(EXCEPTION) << "keep_dims should be Bool."; + } + bool keep_dims_value = GetValue(keep_dims); + // check axis + ValuePtr axis = primitive->GetAttr("axis"); + MS_EXCEPTION_IF_NULL(axis); + if (!axis->isa() && !axis->isa()) { + MS_LOG(EXCEPTION) << "axis should be Int."; + } + // check axis convert negative to positive value + auto check_axis = [](int64_t &axis, const size_t dim) -> void { + int64_t dim_ = static_cast(dim); + if (axis < -dim_ || axis >= dim_) { + MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << "."; + } + if (axis >= -dim_ && axis < 0) { + axis += dim_; + } + return; + }; + // main calculate shape func + auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { + shape.insert(shape.end(), x_shape.begin(), x_shape.end()); + int64_t axis_value = GetValue(axis); + check_axis(axis_value, x_shape.size()); + if (keep_dims_value) { + shape[axis_value] = 1; + } else { + shape.erase(std::begin(shape) + axis_value); + } + }; + ShapeVector shape = {}; + ShapeVector min_shape = {}; + ShapeVector max_shape = {}; + ShapeVector x_shape = x->shape()->shape(); + ShapeVector x_min_shape = x->shape()->min_shape(); + ShapeVector x_max_shape = x->shape()->max_shape(); + (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); + cal_shape(shape, x_shape); + cal_shape(min_shape, x_min_shape); + cal_shape(max_shape, x_max_shape); + TypePtr idx_type = kInt32; + auto index = std::make_shared(idx_type, std::make_shared(shape, min_shape, max_shape)); + auto value = std::make_shared(x->element(), std::make_shared(shape, min_shape, max_shape)); + AbstractBasePtrList result = {index, value}; + return std::make_shared(result); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index fa3e9f20bd..7bcd7759d2 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -454,5 +454,37 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP } return std::make_shared(x_type, std::make_shared(ret_shape, ret_min_shape, ret_max_shape)); } + +AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + ShapeVector x_shape = x->shape()->shape(); + ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape(); + ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape(); + + auto y = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(y); + MS_EXCEPTION_IF_NULL(y->shape()); + ShapeVector y_shape = y->shape()->shape(); + ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape(); + ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape(); + + auto out_shape = BroadcastShape(x_shape, y_shape); + if (out_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min); + auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max); + + auto output_type = std::make_shared(); + auto ret = + std::make_shared(output_type, std::make_shared(out_shape, out_shape_min, out_shape_max)); + return ret; +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index dd94ebfc57..f1579c691a 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -55,6 +55,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimAddN, {InferImplAddN, true}}, {prim::kPrimMatMul, {InferImplMatMul, true}}, {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, + {prim::kPrimLess, {InferImplLess, true}}, // Array {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, @@ -89,6 +90,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, {prim::kPrimConcat, {InferImplConcat, true}}, {prim::kPrimRange, {InferImplRange, true}}, + {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}}, diff --git a/mindspore/ops/composite/array_ops.py b/mindspore/ops/composite/array_ops.py index a64d5faac0..ebe5aa8013 100644 --- a/mindspore/ops/composite/array_ops.py +++ b/mindspore/ops/composite/array_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ from mindspore._checkparam import Rel from mindspore.ops.primitive import constexpr from mindspore.ops import functional as F from .. import operations as P -from ..operations import _inner_ops as inner @constexpr @@ -105,7 +104,15 @@ def repeat_elements(x, rep, axis=0): return x_rep -def sequence_mask(lengths, maxlen): + +@constexpr +def _check_sequence_mask_input_len(input_shape): + if not input_shape: + raise ValueError(f"sequence_mask input lengths_shape should be > 0. " + f"current lengths_shape is {input_shape}.") + + +def sequence_mask(lengths, maxlen=None): """ Returns a mask tensor representing the first N positions of each cell. @@ -135,4 +142,29 @@ def sequence_mask(lengths, maxlen): [[True, True, False], [False, False, False]]] """ - return inner.SequenceMask()(lengths, maxlen) + + argmax_op = P.ArgMaxWithValue() + reshape_op = P.Reshape() + range_op = P.Range() + expand_op = P.ExpandDims() + cast_op = P.Cast() + shape_op = P.Shape() + to_tensor_op = P.ScalarToArray() + + const_utils.check_type_valid(F.dtype(lengths), [mstype.int64, mstype.int32], 'lengths') + _check_sequence_mask_input_len(shape_op(lengths)) + + if maxlen is None: + flatten_data = reshape_op(lengths, (-1,)) + flatten_data = cast_op(flatten_data, mstype.float32) + _, value = argmax_op(flatten_data) + maxlen = cast_op(value, mstype.int32) + else: + maxlen = _check_positive_int(maxlen, "maxlen", "sequence_mask") + maxlen = to_tensor_op(maxlen) + + range_vector = range_op(to_tensor_op(0), maxlen + , to_tensor_op(1)) + mask = expand_op(lengths, -1) + result = range_vector < mask + return result diff --git a/tests/st/ops/gpu/test_sequence_mask_op.py b/tests/st/ops/gpu/test_sequence_mask_op.py index 42e5328862..d5394dd34f 100644 --- a/tests/st/ops/gpu/test_sequence_mask_op.py +++ b/tests/st/ops/gpu/test_sequence_mask_op.py @@ -1,3 +1,17 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ import numpy as np import pytest @@ -16,7 +30,6 @@ def sequence_mask(x, maxlen): def test_sequence_mask_1d(): context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") a = np.array([2, 3, 1]) - maxlen = 4 ms_out = sequence_mask(a, maxlen) expected_out = Tensor(np.array([[True, True, False, False], @@ -30,7 +43,6 @@ def test_sequence_mask_1d(): def test_sequence_mask_2d(): context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") a = np.array([[0, 1, 3, 2], [1, 4, 4, 2]]) - maxlen = 6 ms_out = sequence_mask(a, maxlen) expected_out = Tensor(np.array([[[False, False, False, False, False, False], @@ -51,7 +63,6 @@ def test_sequence_mask_3d(): a = np.array([[[2, 2], [1, 1]], [[2, 0], [2, 1]], [[0, 0], [0, 0]]]) - maxlen = 2 ms_out = sequence_mask(a, maxlen) expected_out = Tensor(np.array([[[[True, True], [True, True]], [[True, False], [True, False]]], @@ -68,7 +79,6 @@ def test_sequence_mask_maxlen_1(): a = np.array([[[0, 1], [1, 1]], [[1, 0], [1, 1]], [[0, 1], [0, 1]]]) - maxlen = 1 ms_out = sequence_mask(a, maxlen) expected_out = Tensor(np.array([[[[False], [True]], [[True], [True,]]], @@ -81,9 +91,9 @@ def test_sequence_mask_maxlen_1(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_sequence_mask_dynamic(): - class SequenceMaskDynamicNet(nn.Cell): + class SequenceMaskDynamicNet1(nn.Cell): def __init__(self, maxlen): - super(SequenceMaskDynamicNet, self).__init__() + super(SequenceMaskDynamicNet1, self).__init__() self.maxlen = maxlen self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() @@ -91,9 +101,18 @@ def test_sequence_mask_dynamic(): converted_to_dynamic_shape = self.convert_to_dynamic_shape(x) return C.sequence_mask(converted_to_dynamic_shape, self.maxlen) + class SequenceMaskDynamicNet2(nn.Cell): + def __init__(self): + super(SequenceMaskDynamicNet2, self).__init__() + self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + + def construct(self, x): + converted_to_dynamic_shape = self.convert_to_dynamic_shape(x) + return C.sequence_mask(converted_to_dynamic_shape) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - sequence_mask_net = SequenceMaskDynamicNet(4) + sequence_mask_net = SequenceMaskDynamicNet1(4) a = Tensor(np.array([0, 1, 0, 2, 0, 5])) ms_out = sequence_mask_net(a) @@ -113,3 +132,39 @@ def test_sequence_mask_dynamic(): [[False, False, False, False], [True, False, False, False], [True, True, True, False]]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + + net_without_maxlen = SequenceMaskDynamicNet2() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.array([2, 3, 1]) + ms_out = net_without_maxlen(Tensor(a)) + expected_out = Tensor(np.array([[True, True, False], + [True, True, True], + [True, False, False]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + + +def sequence_mask_optional(x): + return C.sequence_mask(Tensor(x.astype(np.int32))) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sequence_mask_optional_maxlen(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.array([2, 3, 1]) + ms_out = sequence_mask_optional(a) + expected_out = Tensor(np.array([[True, True, False], + [True, True, True], + [True, False, False]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + a = np.array([2, 3, 1]) + ms_out = sequence_mask_optional(a) + expected_out = Tensor(np.array([[True, True, False], + [True, True, True], + [True, False, False]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())