initial commit

fix ci

fix ci

remove old sequence mask api

fix ci

fix ci

remove old seuqence mask tests
pull/9241/head
Peilin Wang 4 years ago
parent 689f102f86
commit f7dc2432a0

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include "backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// keep this as TWO but output is always bool, just in case framework can
// support passing optional dtype and then we can be identical to tf
MS_REG_GPU_KERNEL_TWO(
SequenceMask,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SequenceMaskGpuKernel, int32_t, bool)
MS_REG_GPU_KERNEL_TWO(
SequenceMask,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SequenceMaskGpuKernel, int64_t, bool)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,101 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
#include "backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh"
#include <cuda_runtime.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class SequenceMaskGpuKernel : public GpuKernel {
public:
SequenceMaskGpuKernel() { ResetResource(); }
~SequenceMaskGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *lengths_device_address = GetDeviceAddress<T>(inputs, 0);
T *maxlen_device_address = GetDeviceAddress<T>(inputs, 1);
S *output_device_address = GetDeviceAddress<S>(outputs, 0);
CalSequenceMask(lengths_device_address, maxlen_device_address, output_device_address, output_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 2) {
MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SequenceMaskGpuKernel expects 2.";
}
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (const int &e : input_shape_) {
lengths_size_ *= e;
}
std::vector<size_t> inferred_output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (const size_t &e : inferred_output_shape) {
output_size_ *= e;
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
output_size_ = 1;
lengths_size_ = 1;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(lengths_size_ * sizeof(T));
input_size_list_.push_back(sizeof(T));
output_size_list_.push_back(output_size_);
}
private:
std::vector<size_t> input_shape_;
size_t lengths_size_;
size_t output_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include "sequence_mask_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
__global__ void ValidateArgs(int *maxlen, const int lengths_size, const int max_output_size) {
int maxlen_value = *maxlen;
if (maxlen_value < 0 || lengths_size * maxlen_value > max_output_size) {
asm("trap;");
}
}
template <typename T, typename S>
__global__ void SequenceMask(
const T *input, T *maxlen, S *output, const size_t output_size) {
T maxlen_value = *maxlen;
for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
T mask_comparison_value = gt_id % maxlen_value;
T input_comparison_index = (gt_id - mask_comparison_value) / maxlen_value;
S result = mask_comparison_value < input[input_comparison_index];
output[gt_id] = result;
}
}
template <typename T, typename S>
void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream) {
SequenceMask<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(lengths, maxlen, output, output_size);
}
template void CalSequenceMask<int, bool>(const int *lengths, int *maxlen, bool *output, const size_t output_size,
cudaStream_t cuda_stream);
template void CalSequenceMask<int64_t, bool>(const int64_t *lengths, int64_t *maxlen, bool *output,
const size_t output_size, cudaStream_t cuda_stream);

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
#include <cuda_runtime.h>
template <typename T, typename S>
void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_

@ -263,6 +263,9 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

@ -771,5 +771,56 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr
}
return std::make_shared<AbstractTuple>(output_list);
}
AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr lengths = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
(void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 1 (lengths) for SequenceMask should be one of: %s");
int64_t maxlen_value = 0;
if (args_spec_list[1]->isa<AbstractScalar>()) {
AbstractScalarPtr maxlen = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
(void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 0 (maxlen) for SequenceMask should be one of: %s");
TypePtr maxlen_type = nullptr;
maxlen_type = maxlen->GetTypeTrack();
MS_EXCEPTION_IF_NULL(maxlen_type);
if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) {
maxlen_value = static_cast<int64_t>(GetValue<int32_t>(maxlen->BuildValue()));
} else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) {
maxlen_value = GetValue<int64_t>(maxlen->BuildValue());
}
} else if (args_spec_list[1]->isa<AbstractTensor>()) {
auto maxlen_tensor_ptr = args_spec_list[1]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr);
auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue();
MS_EXCEPTION_IF_NULL(maxlen_value_ptr);
auto maxlen_tensor = maxlen_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(maxlen_tensor);
maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c());
}
ShapeVector lengths_shape = lengths->shape()->shape();
ShapeVector lengths_shape_min = lengths->shape()->min_shape();
if (lengths_shape_min.empty()) {
lengths_shape_min = lengths_shape;
}
ShapeVector lengths_shape_max = lengths->shape()->max_shape();
if (lengths_shape_max.empty()) {
lengths_shape_max = lengths_shape;
}
lengths_shape.push_back(maxlen_value);
lengths_shape_min.push_back(maxlen_value);
lengths_shape_max.push_back(maxlen_value);
ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max);
return std::make_shared<AbstractTensor>(kBool, output_shape);
}
} // namespace abstract
} // namespace mindspore

@ -71,6 +71,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},

@ -119,6 +119,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("D
inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions.
from .image_ops import (CropAndResize)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, Ones, Zeros, SequenceMask, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique, GatherD, Identity)
Unique, GatherD, Identity, SequenceMask)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, Send, Receive,
@ -394,6 +394,7 @@ __all__ = [
"Pull",
"ReLUV2",
"SparseToDense",
"SequenceMask",
]
__all__.sort()

@ -1216,68 +1216,6 @@ class Zeros(PrimitiveWithInfer):
return out
class SequenceMask(PrimitiveWithInfer):
r"""
Generates sequence mask according to input lengths.
Creates a mask tensor which retains the first N elements in tensor by setting the values
to be True or one. The rest values in mask are set to False or zero.
Args:
max_length (int): Nonnegative integer, size of the last dimension in mask. Default: None.
Inputs:
- **lengths** (Union[tuple[int], list[int]]) - Defines the first N elements that are retained.
Only constant value is allowed.
- **dtype** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs:
Tensor.
If max_length is set, the shape of the output is (lengths.shape, max_length).
If max_length is not set and the biggest value in lengths is x. Then, the shape of
the output is (lengths.shape, x).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import operations as P
>>> sequence_mask = P.SequenceMask()
>>> mask = sequence_mask([2, 2, 4], mindspore.int32)
>>> print(mask)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1]]
"""
@prim_attr_register
def __init__(self):
"""Initialize SequenceMask"""
def __infer__(self, lengths, dtype, max_length=None):
validator.check_value_type("shape", lengths['value'], [tuple, list], self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_subclass("dtype", dtype['value'], valid_types, self.name)
nptype = mstype.dtype_to_nptype(dtype['value'])
if max_length is None:
max_length = np.max(lengths['value'])
else:
validator.check_non_negative_int(max_length['value'])
max_length = max_length['value']
row_vector = np.arange(0, max_length)
col_matrix = np.expand_dims(lengths['value'], -1)
result = (row_vector < col_matrix).astype(nptype)
out = {
'value': Tensor(result),
'shape': result.shape,
'dtype': dtype['value']
}
return out
class OnesLike(PrimitiveWithInfer):
"""
Creates a new tensor. The values of all elements are 1.
@ -4655,3 +4593,47 @@ class Identity(PrimitiveWithInfer):
'dtype': x['dtype'],
'value': None}
return out
class SequenceMask(PrimitiveWithCheck):
"""
Returns a mask tensor representing the first N positions of each cell.
If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
Inputs:
- **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be
less than `maxlen`. Must be type int32 or int64.
- **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
tyupe as elements in `lengths`.
Outputs:
One mask tensor of shape lengths.shape + (maxlen,).
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor(np.array([[1, 3], [2, 0]])
>>> sequence_mask = P.SequenceMask()
>>> output = sequence_mask(x, 3)
>>> print(output)
[[[True, False, False],
[True, True, True]],
[[True, True, False],
[False, False, False]]]
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])
def check_shape(self, lengths_shape, maxlen_shape):
validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)
def check_dtype(self, lengths_dtype, maxlen_dtype):
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)

@ -0,0 +1,117 @@
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
import mindspore.nn as nn
import mindspore.context as context
def sequence_mask(x, maxlen):
sequence_mask_op = P.SequenceMask()
return sequence_mask_op(Tensor(x.astype(np.int32)), maxlen)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_1d():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.array([2, 3, 1])
maxlen = 4
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[True, True, False, False],
[True, True, True, False],
[True, False, False, False]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_2d():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.array([[0, 1, 3, 2], [1, 4, 4, 2]])
maxlen = 6
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[[False, False, False, False, False, False],
[True, False, False, False, False, False],
[True, True, True, False, False, False],
[True, True, False, False, False, False]],
[[True, False, False, False, False, False],
[True, True, True, True, False, False],
[True, True, True, True, False, False],
[True, True, False, False, False, False]]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_3d():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.array([[[2, 2], [1, 1]],
[[2, 0], [2, 1]],
[[0, 0], [0, 0]]])
maxlen = 2
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[[[True, True], [True, True]], [[True, False], [True, False]]],
[[[True, True], [False, False]], [[True, True], [True, False]]],
[[[False, False], [False, False]], [[False, False], [False, False]]]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_maxlen_1():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.array([[[0, 1], [1, 1]],
[[1, 0], [1, 1]],
[[0, 1], [0, 1]]])
maxlen = 1
ms_out = sequence_mask(a, maxlen)
expected_out = Tensor(np.array([[[[False], [True]], [[True], [True,]]],
[[[True], [False]], [[True], [True]]],
[[[False], [True]], [[False], [True]]]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_dynamic():
class SequenceMaskDynamicNet(nn.Cell):
def __init__(self, maxlen):
super(SequenceMaskDynamicNet, self).__init__()
self.maxlen = maxlen
self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
self.sequence_mask = P.SequenceMask()
def construct(self, x):
converted_to_dynamic_shape = self.convert_to_dynamic_shape(x)
return self.sequence_mask(converted_to_dynamic_shape, self.maxlen)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sequence_mask_net = SequenceMaskDynamicNet(4)
a = Tensor(np.array([0, 1, 0, 2, 0, 5]))
ms_out = sequence_mask_net(a)
expected_out = Tensor(np.array([[False, False, False, False],
[True, False, False, False],
[False, False, False, False],
[True, True, False, False],
[False, False, False, False],
[True, True, True, True]]))
np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
a = Tensor(np.array([[4, 3, 0], [0, 1, 3]]))
ms_out = sequence_mask_net(a)
expected_out = Tensor(np.array([[[True, True, True, True],
[True, True, True, False],
[False, False, False, False]],
[[False, False, False, False],
[True, False, False, False],
[True, True, True, False]]]))

@ -42,28 +42,6 @@ def test_expand_dims():
assert output.asnumpy().shape == (1, 2, 2)
def test_sequence_mask():
list_ = [2, 2, 4]
sequence_mask = P.SequenceMask()
mask1 = sequence_mask(list_, mstype.int32)
mask2 = sequence_mask(list_, mstype.int32, 5)
assert mask1.shape == (3, 4)
assert mask1.dtype == mstype.int32
assert mask2.shape == (3, 5)
assert mask2.dtype == mstype.int32
def test_sequence_mask_1():
list_ = [[2, 2, 4], [3, 4, 4]]
sequence_mask = P.SequenceMask()
mask1 = sequence_mask(list_, mstype.bool_)
mask2 = sequence_mask(list_, mstype.bool_, 5)
assert mask1.shape == (2, 3, 4)
assert mask1.dtype == mstype.bool_
assert mask2.shape == (2, 3, 5)
assert mask2.dtype == mstype.bool_
def test_cast():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_x = Tensor(input_np)

Loading…
Cancel
Save