fix dynamic shape and add testcase for SparseGatherV2 GPU

pull/9132/head
TFbunny 4 years ago
parent 0856639fc5
commit 419f8bf72a

@ -48,6 +48,19 @@ MS_REG_GPU_KERNEL_TWO(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2GpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int)
} // namespace kernel
} // namespace mindspore

@ -794,6 +794,9 @@ class SparseGatherV2(GatherV2):
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Supported Platforms:
``GPU``
Examples:
>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save