!9140 [MS][GPU][CUDA][DynamicShape] - New GPU kernel -> UnsortedSegmentMin + DynamicShape support changes to API + inferImpl func (+SegMax ST correction)

From: @danishnxt
pull/9140/MERGE
Committed by mindspore-ci-bot (merged via Gitee)
commit dc62360eed

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMin,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentMinGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMin,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
UnsortedSegmentMinGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMin,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentMinGpuKernel, int)
// Dynamic Mode - num_segments is supplied as a third (int64) input
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentMinGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
UnsortedSegmentMinGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentMinGpuKernel, int)
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,131 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MIN_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MIN_H_
#include <vector>
#include <limits>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class UnsortedSegmentMinGpuKernel : public GpuKernel {
public:
UnsortedSegmentMinGpuKernel() { ResetResource(); }
~UnsortedSegmentMinGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
int *indices_addr = GetDeviceAddress<int>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
CalUnsortedSegmentMin(input_addr, indices_addr, num_segments_, outer_size_, inner_size_, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
if (is_null_input_) {
MS_LOG(WARNING) << "UnsortedSegmentMin input is null";
InitSizeLists();
return true;
}
auto segment_ids_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == 3) {
MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 3 - dynamic mode";
} else {
MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2";
}
num_segments_ = output_shapes[0];
input_size_ = 1;
for (size_t i = 0; i < input_shapes.size(); i++) {
input_size_ *= input_shapes[i];
}
segment_ids_size_ = 1;
for (size_t i = 0; i < segment_ids_shapes.size(); i++) {
segment_ids_size_ *= segment_ids_shapes[i];
}
output_size_ = 1;
for (size_t i = 0; i < output_shapes.size(); i++) {
output_size_ *= output_shapes[i];
}
outer_size_ = input_shapes[0];
inner_size_ = 1;
for (size_t i = 1; i < input_shapes.size(); i++) {
inner_size_ *= input_shapes[i];
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
num_segments_ = 1;
inner_size_ = 1;
outer_size_ = 1;
input_size_ = 1;
segment_ids_size_ = 1;
output_size_ = 1;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(segment_ids_size_ * sizeof(int));
output_size_list_.push_back(output_size_ * sizeof(T));
}
private:
int num_segments_;
size_t inner_size_;
size_t outer_size_;
size_t input_size_;
size_t segment_ids_size_;
size_t output_size_;
bool is_null_input_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MIN_H_

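For intuition, here is a minimal standalone sketch (not part of this PR) of how Init() above derives outer_size_, inner_size_ and num_segments_ from the input and output shapes; the struct and function names are illustrative only.

#include <cassert>
#include <cstddef>
#include <vector>

struct SegmentMinGeometry {
  size_t outer_size;    // input_shapes[0]: rows addressed by segment_ids
  size_t inner_size;    // product of the remaining input dims
  size_t num_segments;  // output_shapes[0]
};

SegmentMinGeometry ComputeGeometry(const std::vector<size_t> &input_shape,
                                   const std::vector<size_t> &output_shape) {
  SegmentMinGeometry g{input_shape[0], 1, output_shape[0]};
  for (size_t i = 1; i < input_shape.size(); ++i) {
    g.inner_size *= input_shape[i];
  }
  return g;
}

int main() {
  // Mirrors the 3-D test case: x of shape (4, 5, 3), num_segments = 4.
  auto g = ComputeGeometry({4, 5, 3}, {4, 5, 3});
  assert(g.outer_size == 4 && g.inner_size == 15 && g.num_segments == 4);
  return 0;
}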
@@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh"
#include <limits>
template<typename T>
__device__ __forceinline__ void max_val_init(T *init_val) {
*init_val = std::numeric_limits<T>::max();
}
// fp16 needs its own specialization: std::numeric_limits is not specialized for half
template<>
__device__ __forceinline__ void max_val_init(half *init_val) {
*init_val = __int2half_rd(65504); // Max value for Half
}
template <typename T>
__global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
size_t inner_size, T init_K, T *output) {
max_val_init(&init_K);
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
t_idx += blockDim.x * gridDim.x) {
int segment_id = t_idx / KWARPSIZE / inner_size;
int inner_id = t_idx / KWARPSIZE % inner_size;
int lane_id = threadIdx.x % KWARPSIZE;
T threadK = init_K;
for (int i = lane_id; i < outer_size; i += KWARPSIZE) {
if (segment_ids[i] != segment_id) continue;
T other_K = input[i * inner_size + inner_id];
if (threadK > other_K) {
threadK = other_K;
}
}
__syncwarp();
for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
if (threadK > other_K) {
threadK = other_K;
}
}
__syncwarp();
if (lane_id == 0) {
output[segment_id * inner_size + inner_id] = threadK;
}
__syncthreads();
}
}
template <typename T>
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
size_t inner_size, T *output, cudaStream_t stream) {
int size = (inner_size * KWARPSIZE * num_segments);
T init_K = std::numeric_limits<T>::lowest(); // placeholder only - overwritten by max_val_init() inside the kernel
UnsortedSegmentMin<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, segment_ids, num_segments, outer_size,
inner_size, init_K, output);
return;
}
template void CalUnsortedSegmentMin<float>(const float *input, const int *segment_ids, const int num_segments,
size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
template void CalUnsortedSegmentMin<half>(const half *input, const int *segment_ids, const int num_segments,
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
template void CalUnsortedSegmentMin<int>(const int *input, const int *segment_ids, const int num_segments,
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);

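The kernel above assigns one warp of KWARPSIZE lanes to each (segment, inner index) pair: each lane takes a strided minimum over the outer dimension, then the partial minima are combined with __shfl_down_sync so that lane 0 holds the segment minimum. Below is a self-contained CUDA sketch of just that reduction pattern (illustrative values only, not part of this PR).

#include <cfloat>
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

__global__ void WarpMin(const float *data, int n, float *out) {
  int lane = threadIdx.x % kWarpSize;
  float v = FLT_MAX;  // same "start at max, then take mins" idea as init_K
  // Each lane scans a strided slice of the data.
  for (int i = lane; i < n; i += kWarpSize) {
    if (data[i] < v) v = data[i];
  }
  // Tree reduction across the warp: after log2(32) steps lane 0 holds the min.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    float other = __shfl_down_sync(0xffffffff, v, offset);
    if (other < v) v = other;
  }
  if (lane == 0) *out = v;
}

int main() {
  const int n = 100;
  float h[n];
  for (int i = 0; i < n; ++i) h[i] = static_cast<float>(100 - i);  // minimum is 1
  float *d_in, *d_out, result;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h, n * sizeof(float), cudaMemcpyHostToDevice);
  WarpMin<<<1, kWarpSize>>>(d_in, n, d_out);
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp min = %f (expected 1.0)\n", result);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}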
@@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MIN_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MIN_H_
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
// Warp size used for the warp-level shuffle reduction
#define KWARPSIZE 32
template <typename T>
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
size_t inner_size, T *output, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MIN_H_

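A hypothetical host-side driver for CalUnsortedSegmentMin, assuming the file is compiled and linked together with unsorted_segment_min.cu above; buffer contents are example values only and the forward declaration simply mirrors the .cuh.

#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
                           size_t inner_size, T *output, cudaStream_t stream);  // declared in the .cuh

int main() {
  const int outer = 4, inner = 3, segments = 2;
  // x has shape (4, 3): outer_size = 4 rows, inner_size = 3.
  float h_x[outer * inner] = {9, 8, 7, 1, 2, 3, 4, 5, 6, 0, 0, 0};
  int h_ids[outer] = {1, 0, 1, -1};  // row 3 has a negative id and is skipped
  float h_y[segments * inner];

  float *d_x, *d_y;
  int *d_ids;
  cudaMalloc(&d_x, sizeof(h_x));
  cudaMalloc(&d_ids, sizeof(h_ids));
  cudaMalloc(&d_y, sizeof(h_y));
  cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);

  CalUnsortedSegmentMin<float>(d_x, d_ids, segments, outer, inner, d_y, nullptr);
  cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);

  // Expected: segment 0 -> row 1 = {1, 2, 3}; segment 1 -> elementwise min of rows 0 and 2 = {4, 5, 6}.
  for (int i = 0; i < segments * inner; ++i) printf("%g ", h_y[i]);
  printf("\n");

  cudaFree(d_x);
  cudaFree(d_ids);
  cudaFree(d_y);
  return 0;
}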
@@ -115,6 +115,8 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -341,6 +341,74 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(segment_ids);
MS_EXCEPTION_IF_NULL(segment_ids->shape());
auto segment_ids_shape = segment_ids->shape()->shape();
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMin should be %s");
(void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMin should be %s");
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn && ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin";
}
if (num_segments_value <= 0) {
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMin";
}
shape.emplace_back(num_segments_value);
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
if (!op_is_dynamic) {
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
}
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
}
// is dynamic
ShapeVector min_shape;
ShapeVector max_shape;
min_shape.emplace_back(num_segments_value);
max_shape.emplace_back(num_segments_value);
// only run validation if shape values are known
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
bool ids_any_shape =
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
if (!x_any_shape && !ids_any_shape) {
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
}
}
ShapeVector x_shape_min;
ShapeVector x_shape_max;
x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();

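The new InferImplUnsortedSegmentMin builds the output shape as [num_segments] followed by the dims of x that remain after dropping rank(segment_ids) leading dims, and applies the same rule to min_shape/max_shape in the dynamic case. A standalone sketch of that rule (names are illustrative, not MindSpore API):

#include <cassert>
#include <cstdint>
#include <vector>

using ShapeVec = std::vector<int64_t>;

// out = [num_segments] ++ x_shape[rank(segment_ids):]
ShapeVec SegmentMinOutShape(const ShapeVec &x_shape, size_t segment_ids_rank, int64_t num_segments) {
  ShapeVec out{num_segments};
  out.insert(out.end(), x_shape.begin() + segment_ids_rank, x_shape.end());
  return out;
}

int main() {
  // Static case from the tests: x (4, 5, 3), segment_ids rank 1, num_segments 4 -> (4, 5, 3).
  assert((SegmentMinOutShape({4, 5, 3}, 1, 4) == ShapeVec{4, 5, 3}));
  // Dynamic case: min_shape and max_shape go through the same rule,
  // with the leading dim pinned to num_segments in both bounds.
  assert((SegmentMinOutShape({2, 5, 3}, 1, 4) == ShapeVec{4, 5, 3}));  // from x min_shape
  assert((SegmentMinOutShape({8, 5, 3}, 1, 4) == ShapeVec{4, 5, 3}));  // from x max_shape
  return 0;
}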
@@ -59,6 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},

@@ -1922,7 +1922,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
return out
class UnsortedSegmentMin(PrimitiveWithInfer):
class UnsortedSegmentMin(PrimitiveWithCheck):
"""
Computes the minimum of a tensor along segments.
@@ -1959,26 +1959,19 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
"""Initialize UnsortedSegmentMin"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
x_shape = x['shape']
def __check__(self, x, segment_ids, num_segments):
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)
segment_ids_shape_len = len(segment_ids_shape)
out_shape = [num_segments_v]
out_shape += x_shape[segment_ids_shape_len:]
out = {'shape': out_shape,
'dtype': x_type,
'value': None}
return out
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64],
self.name)
else:
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
class UnsortedSegmentMax(PrimitiveWithCheck):

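For reference, the semantics the reworked primitive keeps (a restatement of the docstring and the CUDA kernel's init value, not part of the diff):

\[
\text{output}[i, j_1, \dots, j_k] = \min_{r \,:\, \text{segment\_ids}[r] = i} x[r, j_1, \dots, j_k],
\qquad 0 \le i < \text{num\_segments}.
\]

Rows whose segment id is negative or at least num_segments contribute to no segment, and a segment that receives no rows is filled with the maximum representable value of the input dtype (65504 for float16), matching max_val_init in the CUDA implementation.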
@@ -222,39 +222,12 @@ class UnsortedSegmentMaxDynNet(nn.Cell):
@pytest.mark.env_onecard
def test_3d_float32_dyn():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
num_segments = 3
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
context.set_context(device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
@@ -278,7 +251,15 @@ def test_3d_single_init_dyn():
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
num_segments = 6
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
@@ -300,15 +281,40 @@ def test_3d_single_init_dyn():
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# change the input shape for the same net to exercise the dynamic shape path
input_x = Tensor(np.arange(
4 * 7 * 2, dtype=np.float32).reshape(4, 7, 2), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.4000000e+01, 1.5000000e+01],
[1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01],
[2.0000000e+01, 2.1000000e+01],
[2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01],
[2.6000000e+01, 2.7000000e+01]],
[[2.8000000e+01, 2.9000000e+01],
[3.0000000e+01, 3.1000000e+01],
[3.2000000e+01, 3.3000000e+01],
[3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01],
[3.8000000e+01, 3.9000000e+01],
[4.0000000e+01, 4.1000000e+01]],
[[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00],
[2.0000000e+00, 3.0000000e+00],
[4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00],
[8.0000000e+00, 9.0000000e+00],
[1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)

File diff suppressed because it is too large.