Add float64 support to slice ops

pull/12347/head
TFBunny 4 years ago
parent f9a2b2004f
commit 799e51cff0

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,16 +18,18 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// Float64 variant: registers the Slice op (one float64 input, one float64 output)
// with SliceGpuFwdKernel instantiated for double.
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SliceGpuFwdKernel, double)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SliceGpuFwdKernel, float) SliceGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SliceGpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SliceGpuFwdKernel, half) SliceGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
SliceGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SliceGpuFwdKernel, int64_t) SliceGpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SliceGpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
SliceGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SliceGpuFwdKernel, uchar) SliceGpuFwdKernel, uchar)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
#include <vector> #include <vector>
#include <utility> #include <utility>
@ -134,4 +134,4 @@ class SliceGpuFwdKernel : public GpuKernel {
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,17 +18,21 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// Float64 variant: registers the SliceGrad op (two float64 inputs, one float64 output)
// with SliceGradGpuKernel instantiated for double.
MS_REG_GPU_KERNEL_ONE(
SliceGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SliceGradGpuKernel, double)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
SliceGrad, SliceGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SliceGradGpuKernel, float) SliceGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SliceGradGpuKernel, int)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
SliceGrad, SliceGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SliceGradGpuKernel, half) SliceGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SliceGradGpuKernel, int)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
SliceGradGpuKernel, int16_t) SliceGradGpuKernel, int16_t)

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
@ -143,4 +143,4 @@ class SliceGradGpuKernel : public GpuKernel {
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_

@ -18,6 +18,8 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// Float64 variant: registers the StridedSlice op (one float64 input, one float64 output)
// with StridedSliceGpuKernel instantiated for double.
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
StridedSliceGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
StridedSliceGpuKernel, float) StridedSliceGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_
#include <vector> #include <vector>
#include <bitset> #include <bitset>
@ -210,4 +210,4 @@ class StridedSliceGpuKernel : public GpuKernel {
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_

@ -18,6 +18,8 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// Float64 variant: registers the StridedSliceGrad op (one float64 input, one float64 output)
// with StridedSliceGradGpuKernel instantiated for double.
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
StridedSliceGradGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
StridedSliceGradGpuKernel, float) StridedSliceGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_
#include <vector> #include <vector>
#include <bitset> #include <bitset>
@ -211,4 +211,4 @@ class StridedSliceGradGpuKernel : public GpuKernel {
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -159,57 +159,57 @@ void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int
dy, dx); dy, dx);
} }
// Explicit instantiation of Slice4DKernel for double input/output buffers.
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const double *input, double *output, cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream); const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream); const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
cudaStream_t cuda_stream); const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream); const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const short *input, short *output, // NOLINT const size_t d3, const size_t d4, const short *input, short *output, // NOLINT
cudaStream_t stream); cudaStream_t stream);
template void CalSliceGrad<short>(const size_t input_size, const short *dy, // NOLINT
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size,
short *output, // NOLINT
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output, const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
cudaStream_t stream); cudaStream_t stream);
template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size, unsigned char *output,
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int64_t *input, int64_t *output, const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
cudaStream_t stream);
// Explicit instantiation of CalSliceGrad for double gradients.
template void CalSliceGrad<double>(const size_t input_size, const double *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, double *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape, template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output, const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, cudaStream_t cuda_stream);
const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream); template void CalSliceGrad<short>(const size_t input_size, const short *dy, // NOLINT
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size, short *output, // NOLINT
cudaStream_t cuda_stream);
template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size, unsigned char *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape, template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output, const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
@ -232,10 +232,15 @@ template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned c
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream); template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream); template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
// Explicit instantiation of FillDeviceArray for double buffers.
// NOTE(review): the fill value parameter stays `const float` even for the double
// instantiation, so only float-representable fill values can be passed — confirm this is intended.
template void FillDeviceArray<double>(const size_t input_size, double *addr, const float value,
cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin, template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape, const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const bool *input, bool *output, cudaStream_t cuda_stream); const bool *input, bool *output, cudaStream_t cuda_stream);
// Explicit instantiation of StridedSlice for double input/output buffers.
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const double *input, double *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin, template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape, const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const float *input, float *output, cudaStream_t cuda_stream); const float *input, float *output, cudaStream_t cuda_stream);
@ -270,6 +275,9 @@ template void StridedSlice(const std::vector<size_t> &input_shape, const std::ve
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin, template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy, const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
bool *dx, cudaStream_t cuda_stream); bool *dx, cudaStream_t cuda_stream);
// Explicit instantiation of StridedSliceGrad for double dy/dx buffers.
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const double *dy, double *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin, template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const float *dy, float *dx, cudaStream_t cuda_stream); const float *dy, float *dx, cudaStream_t cuda_stream);

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -1859,7 +1859,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
ellipsis_mask=0, ellipsis_mask=0,
new_axis_mask=0, new_axis_mask=0,
shrink_axis_mask=0): shrink_axis_mask=0):
"""Initialize StrideSliceGrad""" """Initialize StridedSliceGrad"""
validator.check_value_type('begin_mask', begin_mask, [int], self.name) validator.check_value_type('begin_mask', begin_mask, [int], self.name)
validator.check_value_type('end_mask', end_mask, [int], self.name) validator.check_value_type('end_mask', end_mask, [int], self.name)
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)

@ -2792,7 +2792,7 @@ class StridedSlice(PrimitiveWithInfer):
ellipsis_mask=0, ellipsis_mask=0,
new_axis_mask=0, new_axis_mask=0,
shrink_axis_mask=0): shrink_axis_mask=0):
"""Initialize StrideSlice""" """Initialize StridedSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name) validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
validator.check_non_negative_int(end_mask, 'end_mask', self.name) validator.check_non_negative_int(end_mask, 'end_mask', self.name)

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2019-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -55,6 +55,9 @@ class SliceNet(nn.Cell):
return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224)) return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_slice_4d(): def test_slice_4d():
x_np = np.random.randn(32, 24, 224, 224).astype(np.float32) x_np = np.random.randn(32, 24, 224, 224).astype(np.float32)
output_np = x_np[:, 11:18, :, :] output_np = x_np[:, 11:18, :, :]
@ -64,3 +67,18 @@ def test_slice_4d():
output_ms = net(x_ms) output_ms = net(x_ms)
assert (output_ms.asnumpy() == output_np).all() assert (output_ms.asnumpy() == output_np).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_slice_float64():
    """Run the file's Slice cell on a (3, 2, 3) float64 tensor in GPU graph mode
    and compare the result element-wise with the expected (2, 1, 3) array."""
    source_values = np.array([[[1, -1, 1], [2, -2, 2]],
                              [[3, -3, 3], [4, -4, 4]],
                              [[5, -5, 5], [6, -6, 6]]])
    x = Tensor(source_values.astype(np.float64))
    expected = np.array([[[2., -2., 2.]], [[4., -4., 4.]]]).astype(np.float64)

    # The execution context is configured after the inputs are built, as in the original.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    result = Slice()(x)

    elementwise_match = result.asnumpy() == expected
    assert elementwise_match.all()

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2019-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -50,5 +50,21 @@ def test_slice():
[4., 1., 4.]], [4., 1., 4.]],
[[0., 0., 0.], [[0., 0., 0.],
[0., 0., 0.]]] [0., 0., 0.]]]
print(output) assert (output.asnumpy() == expect).all()
# GPU test (level0, single card): feeds a float64 `dy` of shape (2, 1, 3) and a float64
# `x` of shape (3, 2, 3) to SliceGrad and compares the result element-wise with `expect`,
# a (3, 2, 3) array holding dy's rows at [0][1] and [1][1] and zeros everywhere else.
# NOTE(review): no context.set_context call is visible inside this test — confirm the
# device target / mode is configured elsewhere in the file.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_slice_float64():
x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]).astype(np.float64))
dy = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]).astype(np.float64))
slicegrad = SliceGrad()
output = slicegrad(dy, x)
expect = np.array([[[0., 0., 0.],
[3., 1., 2.]],
[[0., 0., 0.],
[4., 1., 4.]],
[[0., 0., 0.],
[0., 0., 0.]]]).astype(np.float64)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()

@ -239,6 +239,12 @@ def strided_slice_grad(nptype):
[0., 0., 0., 0., 0.]]]]]]]).astype(nptype) [0., 0., 0., 0., 0.]]]]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect) assert np.allclose(dx[0].asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_float64():
    """Run the shared strided_slice_grad check with float64 data."""
    strided_slice_grad(nptype=np.float64)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard

@ -102,6 +102,12 @@ def strided_slice(nptype):
[[[2122.]]]]]]]).astype(nptype) [[[2122.]]]]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect) assert np.allclose(y.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_float64():
    """Run the shared strided_slice check with float64 data."""
    strided_slice(nptype=np.float64)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard

Loading…
Cancel
Save