!12152 Rewrite sequence_mask as a composite op

From: @TFbunny
Reviewed-by: @robingrosman
Signed-off-by:
pull/12152/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit adf934c567

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,7 +26,7 @@ namespace kernel {
template <typename T, typename S>
class ArgmaxWithValueGpuKernel : public GpuKernel {
public:
ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {}
ArgmaxWithValueGpuKernel() { ResetResource(); }
~ArgmaxWithValueGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -75,6 +75,17 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
return true;
}
// Puts the kernel back into the state the constructor establishes: all three
// size lists are emptied and every cached size/dimension counter is zeroed.
void ResetResource() noexcept override {
  // Drop the per-launch buffer size bookkeeping first.
  workspace_size_list_.clear();
  output_size_list_.clear();
  input_size_list_.clear();

  // Then zero the scalar members.
  innerSize_ = 0;
  outerSize_ = 0;
  bound_ = 0;
  output_size_ = 0;
  input_size_ = 0;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);

@ -1,35 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include "backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// The op is registered through the two-type macro even though the output dtype
// is always bool: if the framework later supports passing an optional output
// dtype, the second type parameter is already in place and the op can then be
// made identical to TensorFlow (tf).
//
// Registration for int32 inputs (both inputs int32) producing a bool output.
MS_REG_GPU_KERNEL_TWO(
SequenceMask,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SequenceMaskGpuKernel, int32_t, bool)
// Registration for int64 inputs (both inputs int64) producing a bool output.
MS_REG_GPU_KERNEL_TWO(
SequenceMask,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SequenceMaskGpuKernel, int64_t, bool)
} // namespace kernel
} // namespace mindspore

@ -1,101 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
#include "backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh"
#include <cuda_runtime.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
// GPU kernel wrapper for the SequenceMask op.
//
// Template parameters:
//   T - element type of both inputs (input 0: lengths, input 1: maxlen).
//   S - element type of the output mask.
template <typename T, typename S>
class SequenceMaskGpuKernel : public GpuKernel {
 public:
  SequenceMaskGpuKernel() { ResetResource(); }
  ~SequenceMaskGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Fetches the device addresses of the two inputs and the output and enqueues
  // the mask computation on the given stream. Always returns true; no launch
  // error is checked here.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *lengths_device_address = GetDeviceAddress<T>(inputs, 0);
    T *maxlen_device_address = GetDeviceAddress<T>(inputs, 1);
    S *output_device_address = GetDeviceAddress<S>(outputs, 0);

    // output_size_ is an element count (product of the inferred output dims).
    CalSequenceMask(lengths_device_address, maxlen_device_address, output_device_address, output_size_,
                    reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  // Validates the input count and derives the element counts of the lengths
  // input and of the output from the inferred shapes. Raises through
  // MS_LOG(EXCEPTION) when the node does not have exactly two inputs.
  bool Init(const CNodePtr &kernel_node) override {
    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_count != 2) {
      MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SequenceMaskGpuKernel expects 2.";
    }

    input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    // Iterate with the container's own element type; binding `const int &`
    // to a size_t element would narrow each dimension through a temporary.
    for (const size_t &e : input_shape_) {
      lengths_size_ *= e;
    }

    std::vector<size_t> inferred_output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    for (const size_t &e : inferred_output_shape) {
      output_size_ *= e;
    }

    InitSizeLists();
    return true;
  }

  // Resets the element counters to the multiplicative identity (Init
  // accumulates them by multiplication) and empties all size lists.
  void ResetResource() noexcept override {
    output_size_ = 1;
    lengths_size_ = 1;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  // Publishes buffer sizes in bytes: the lengths tensor, the single maxlen
  // element, and the output mask.
  void InitSizeLists() override {
    input_size_list_.push_back(lengths_size_ * sizeof(T));
    input_size_list_.push_back(sizeof(T));
    // Scale by sizeof(S) like the inputs so the entry is a byte size for any S
    // (identical to the element count when S is bool).
    output_size_list_.push_back(output_size_ * sizeof(S));
  }

 private:
  std::vector<size_t> input_shape_;  // inferred shape of input 0 (lengths)
  size_t lengths_size_;              // number of elements in input 0
  size_t output_size_;               // number of elements in the output
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_

@ -1,50 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include "sequence_mask_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
// Device-side argument check for the sequence mask computation.
//
// Reads the maxlen value through the `maxlen` pointer and executes the PTX
// `trap` instruction when it is negative or when lengths_size * maxlen would
// exceed max_output_size.
//
// NOTE(review): no launch of this kernel is visible in this file section;
// confirm whether it is still meant to be called before SequenceMask.
__global__ void ValidateArgs(int *maxlen, const int lengths_size, const int max_output_size) {
  const int maxlen_value = *maxlen;
  // Widen before multiplying: the product of two ints can exceed INT_MAX, and
  // signed overflow is undefined behaviour that could let an oversized request
  // slip past the bound check.
  const long long required_size = static_cast<long long>(lengths_size) * maxlen_value;
  if (maxlen_value < 0 || required_size > max_output_size) {
    asm("trap;");
  }
}
// Fills `output` with the sequence mask, one flat output element per loop step.
//
// For flat index i the output is viewed as rows of *maxlen entries:
//   column = i % *maxlen, row = (i - column) / *maxlen,
//   output[i] = column < input[row].
// Threads walk the flat index space with a grid-stride loop, so any launch
// configuration covers all `output_size` elements.
template <typename T, typename S>
__global__ void SequenceMask(const T *input, T *maxlen, S *output, const size_t output_size) {
  T row_width = *maxlen;
  const size_t stride = gridDim.x * blockDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < output_size) {
    T column = pos % row_width;
    T row = (pos - column) / row_width;
    S in_sequence = column < input[row];
    output[pos] = in_sequence;
    pos += stride;
  }
}
// Host-side entry point: enqueues the SequenceMask kernel on `cuda_stream`.
//
// The launch uses GET_BLOCKS(output_size) blocks of GET_THREADS threads and no
// dynamic shared memory. `lengths`, `maxlen` and `output` are forwarded
// unchanged to the kernel, which dereferences them on the device.
// `output_size` is the number of output elements. The function does not
// synchronize the stream or check for launch errors.
template <typename T, typename S>
void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream) {
SequenceMask<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(lengths, maxlen, output, output_size);
}
// Explicit instantiations: the template is defined in this translation unit,
// so the two supported type combinations (int -> bool, int64_t -> bool) are
// instantiated here for callers that only see the declaration.
template void CalSequenceMask<int, bool>(const int *lengths, int *maxlen, bool *output, const size_t output_size,
cudaStream_t cuda_stream);
template void CalSequenceMask<int64_t, bool>(const int64_t *lengths, int64_t *maxlen, bool *output,
const size_t output_size, cudaStream_t cuda_stream);

@ -1,25 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
#include <cuda_runtime.h>
// Enqueues the sequence-mask computation on `cuda_stream`.
//
//   lengths     - input sequence lengths (read only).
//   maxlen      - pointer to the single maxlen value.
//   output      - destination mask buffer holding `output_size` elements of S.
//   output_size - number of elements to write to `output`.
//   cuda_stream - stream the work is launched on.
//
// NOTE(review): the pointers are presumably device addresses (the GPU kernel
// class obtains them via GetDeviceAddress) - confirm for any new caller.
template <typename T, typename S>
void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_

@ -304,6 +304,10 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

@ -1068,5 +1068,64 @@ AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &pr
return std::make_shared<AbstractTensor>(range_start_type, shape);
}
// Shape/type inference for ArgMaxWithValue.
//
// Expects exactly one tensor argument plus the primitive attributes
// "keep_dims" (bool) and "axis" (int). Returns an AbstractTuple of two
// tensors with identical shape information: an int32 index tensor and a value
// tensor with the input's element type. The reduced axis is either kept with
// extent 1 (keep_dims) or removed; the same transformation is applied to the
// static, minimum and maximum shapes.
//
// NOTE(review): "axis" is accepted as Int32Imm or Int64Imm but is read with
// GetValue<int64_t> - confirm that this works for an Int32Imm attribute.
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());

  // Attribute "keep_dims": must exist and hold a bool.
  ValuePtr keep_dims = primitive->GetAttr("keep_dims");
  MS_EXCEPTION_IF_NULL(keep_dims);
  if (!keep_dims->isa<BoolImm>()) {
    MS_LOG(EXCEPTION) << "keep_dims should be Bool.";
  }
  bool keep_reduced_dim = GetValue<bool>(keep_dims);

  // Attribute "axis": must exist and hold a 32- or 64-bit integer.
  ValuePtr axis = primitive->GetAttr("axis");
  MS_EXCEPTION_IF_NULL(axis);
  if (!(axis->isa<Int32Imm>() || axis->isa<Int64Imm>())) {
    MS_LOG(EXCEPTION) << "axis should be Int.";
  }

  // Range-checks the axis against the rank and maps a negative axis onto its
  // non-negative equivalent.
  auto normalize_axis = [](int64_t &axis_ref, const size_t rank) -> void {
    int64_t signed_rank = static_cast<int64_t>(rank);
    const bool out_of_range = axis_ref < -signed_rank || axis_ref >= signed_rank;
    if (out_of_range) {
      MS_LOG(EXCEPTION) << "axis should be in [" << -signed_rank << ", " << signed_rank
                        << "). But got axis = " << axis_ref << ".";
    } else if (axis_ref < 0) {
      axis_ref += signed_rank;
    }
  };

  // Appends `src` to `dst`, then collapses (keep_dims) or removes the reduced axis.
  auto reduce_shape = [axis, keep_reduced_dim, normalize_axis](ShapeVector &dst, const ShapeVector &src) -> void {
    dst.insert(dst.end(), src.begin(), src.end());
    int64_t reduce_axis = GetValue<int64_t>(axis);
    normalize_axis(reduce_axis, src.size());
    if (!keep_reduced_dim) {
      dst.erase(dst.begin() + reduce_axis);
      return;
    }
    dst[reduce_axis] = 1;
  };

  ShapeVector x_shape = x->shape()->shape();
  ShapeVector x_min_shape = x->shape()->min_shape();
  ShapeVector x_max_shape = x->shape()->max_shape();
  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);

  ShapeVector out_shape;
  ShapeVector out_min_shape;
  ShapeVector out_max_shape;
  reduce_shape(out_shape, x_shape);
  reduce_shape(out_min_shape, x_min_shape);
  reduce_shape(out_max_shape, x_max_shape);

  // Both outputs share the reduced shape; only the element type differs.
  TypePtr index_type = kInt32;
  auto index_tensor =
    std::make_shared<AbstractTensor>(index_type, std::make_shared<Shape>(out_shape, out_min_shape, out_max_shape));
  auto value_tensor =
    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape, out_min_shape, out_max_shape));
  AbstractBasePtrList outputs = {index_tensor, value_tensor};
  return std::make_shared<AbstractTuple>(outputs);
}
} // namespace abstract
} // namespace mindspore

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -454,5 +454,37 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP
}
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}
// Shape/type inference for Less.
//
// Expects exactly two tensor arguments. The result is a Bool tensor whose
// static, minimum and maximum shapes are the broadcast of the corresponding
// input shapes; an input without min/max shape information contributes its
// static shape instead. Raises through MS_LOG(EXCEPTION) when broadcasting
// the static shapes yields an empty shape.
//
// NOTE(review): two rank-0 inputs would presumably also produce an empty
// broadcast shape and hit the failure branch - confirm against BroadcastShape.
AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);

  // Left operand: static shape plus min/max bounds (falling back to the static shape).
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  ShapeVector x_shape = x->shape()->shape();
  ShapeVector x_shape_min = x->shape()->min_shape();
  if (x_shape_min.empty()) {
    x_shape_min = x_shape;
  }
  ShapeVector x_shape_max = x->shape()->max_shape();
  if (x_shape_max.empty()) {
    x_shape_max = x_shape;
  }

  // Right operand: same treatment.
  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(y);
  MS_EXCEPTION_IF_NULL(y->shape());
  ShapeVector y_shape = y->shape()->shape();
  ShapeVector y_shape_min = y->shape()->min_shape();
  if (y_shape_min.empty()) {
    y_shape_min = y_shape;
  }
  ShapeVector y_shape_max = y->shape()->max_shape();
  if (y_shape_max.empty()) {
    y_shape_max = y_shape;
  }

  auto broadcast_shape = BroadcastShape(x_shape, y_shape);
  if (broadcast_shape.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
                      << args_spec_list[1]->ToString();
  }
  auto broadcast_min = BroadcastShape(x_shape_min, y_shape_min);
  auto broadcast_max = BroadcastShape(x_shape_max, y_shape_max);

  auto bool_type = std::make_shared<Bool>();
  auto result_shape = std::make_shared<Shape>(broadcast_shape, broadcast_min, broadcast_max);
  return std::make_shared<AbstractTensor>(bool_type, result_shape);
}
} // namespace abstract
} // namespace mindspore

@ -74,6 +74,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimAddN, {InferImplAddN, true}},
{prim::kPrimMatMul, {InferImplMatMul, true}},
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
{prim::kPrimLess, {InferImplLess, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@ -108,6 +109,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
{prim::kPrimConcat, {InferImplConcat, true}},
{prim::kPrimRange, {InferImplRange, true}},
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,7 +20,6 @@ from mindspore._checkparam import Rel
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P
from ..operations import _inner_ops as inner
@constexpr
@ -105,7 +104,15 @@ def repeat_elements(x, rep, axis=0):
return x_rep
def sequence_mask(lengths, maxlen):
@constexpr
def _check_sequence_mask_input_len(input_shape):
    """
    Validate the shape of the `lengths` input of sequence_mask.

    Args:
        input_shape: Shape of `lengths`; any falsy value (for example an empty
            tuple) is rejected.

    Raises:
        ValueError: If `input_shape` is falsy.
    """
    if input_shape:
        return
    raise ValueError(f"sequence_mask input lengths_shape should be > 0. "
                     f"current lengths_shape is {input_shape}.")
def sequence_mask(lengths, maxlen=None):
"""
Returns a mask tensor representing the first N positions of each cell.
@ -135,4 +142,29 @@ def sequence_mask(lengths, maxlen):
[[True, True, False],
[False, False, False]]]
"""
return inner.SequenceMask()(lengths, maxlen)
argmax_op = P.ArgMaxWithValue()
reshape_op = P.Reshape()
range_op = P.Range()
expand_op = P.ExpandDims()
cast_op = P.Cast()
shape_op = P.Shape()
to_tensor_op = P.ScalarToArray()
const_utils.check_type_valid(F.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
_check_sequence_mask_input_len(shape_op(lengths))
if maxlen is None:
flatten_data = reshape_op(lengths, (-1,))
flatten_data = cast_op(flatten_data, mstype.float32)
_, value = argmax_op(flatten_data)
maxlen = cast_op(value, mstype.int32)
else:
maxlen = _check_positive_int(maxlen, "maxlen", "sequence_mask")
maxlen = to_tensor_op(maxlen)
range_vector = range_op(to_tensor_op(0), maxlen
, to_tensor_op(1))
mask = expand_op(lengths, -1)
result = range_vector < mask
return result

@ -1,3 +1,17 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
@ -16,7 +30,6 @@ def sequence_mask(x, maxlen):
def test_sequence_mask_1d():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.array([2, 3, 1])
maxlen = 4
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[True, True, False, False],
@ -30,7 +43,6 @@ def test_sequence_mask_1d():
def test_sequence_mask_2d():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.array([[0, 1, 3, 2], [1, 4, 4, 2]])
maxlen = 6
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[[False, False, False, False, False, False],
@ -51,7 +63,6 @@ def test_sequence_mask_3d():
a = np.array([[[2, 2], [1, 1]],
[[2, 0], [2, 1]],
[[0, 0], [0, 0]]])
maxlen = 2
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[[[True, True], [True, True]], [[True, False], [True, False]]],
@ -68,7 +79,6 @@ def test_sequence_mask_maxlen_1():
a = np.array([[[0, 1], [1, 1]],
[[1, 0], [1, 1]],
[[0, 1], [0, 1]]])
maxlen = 1
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[[[False], [True]], [[True], [True,]]],
@ -81,9 +91,9 @@ def test_sequence_mask_maxlen_1():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_dynamic():
class SequenceMaskDynamicNet(nn.Cell):
class SequenceMaskDynamicNet1(nn.Cell):
def __init__(self, maxlen):
super(SequenceMaskDynamicNet, self).__init__()
super(SequenceMaskDynamicNet1, self).__init__()
self.maxlen = maxlen
self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
@ -91,9 +101,18 @@ def test_sequence_mask_dynamic():
converted_to_dynamic_shape = self.convert_to_dynamic_shape(x)
return C.sequence_mask(converted_to_dynamic_shape, self.maxlen)
class SequenceMaskDynamicNet2(nn.Cell):
    """Cell that runs C.sequence_mask on a dynamic-shape input without passing maxlen."""

    def __init__(self):
        super(SequenceMaskDynamicNet2, self).__init__()
        self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()

    def construct(self, x):
        # Route the input through GpuConvertToDynamicShape, then build the mask
        # with maxlen left at its default.
        return C.sequence_mask(self.convert_to_dynamic_shape(x))
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sequence_mask_net = SequenceMaskDynamicNet(4)
sequence_mask_net = SequenceMaskDynamicNet1(4)
a = Tensor(np.array([0, 1, 0, 2, 0, 5]))
ms_out = sequence_mask_net(a)
@ -113,3 +132,39 @@ def test_sequence_mask_dynamic():
[[False, False, False, False],
[True, False, False, False],
[True, True, True, False]]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
net_without_maxlen = SequenceMaskDynamicNet2()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.array([2, 3, 1])
ms_out = net_without_maxlen(Tensor(a))
expected_out = Tensor(np.array([[True, True, False],
[True, True, True],
[True, False, False]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
def sequence_mask_optional(x):
    """Call C.sequence_mask without maxlen on `x` converted to an int32 Tensor."""
    lengths = Tensor(x.astype(np.int32))
    return C.sequence_mask(lengths)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_optional_maxlen():
    """
    Check sequence_mask with maxlen omitted, first in GRAPH_MODE and then in
    PYNATIVE_MODE, against the same expected 3x3 mask.
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="GPU")
        a = np.array([2, 3, 1])
        ms_out = sequence_mask_optional(a)
        # With maxlen omitted the mask is 3 columns wide for these lengths.
        expected_out = Tensor(np.array([[True, True, False],
                                        [True, True, True],
                                        [True, False, False]]))
        np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())

Loading…
Cancel
Save