fix slice op bug

pull/13693/head
wangyanling 4 years ago
parent f0016f5574
commit fb64e14265

@@ -1,5 +1,5 @@
 /**
- * Copyright 2021Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

@@ -102,8 +102,10 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     ret = LaunchKernel<float>(inputs, outputs);
   } else if (dtype_ == kNumberTypeBool) {
     ret = LaunchKernel<bool>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat64) {
+    ret = LaunchKernel<double>(inputs, outputs);
   } else {
-    MS_LOG(ERROR) << "Slice op only support input_x int32 and float32";
+    MS_LOG(ERROR) << "Slice op only supports input_x of bool, int32, float32 and float64";
     return false;
   }
   return ret;

@@ -55,9 +55,14 @@ class SliceCPUKernel : public CPUKernel {
   TypeId dtype_{kTypeUnknown};
 };

+MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                  SliceCPUKernel);
 MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   SliceCPUKernel);
 MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel);
+MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel);
+MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                  SliceCPUKernel);
 MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   SliceCPUKernel);
 MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
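
With the float64 registrations above in place, a double-precision tensor should now dispatch Slice to this CPU kernel. A minimal Python sketch of the expected behavior (illustrative, not part of this commit):

import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# np.float64 maps to kNumberTypeFloat64, matching the new registration above
x = Tensor(np.arange(24).reshape(2, 3, 4).astype(np.float64))
out = P.Slice()(x, (0, 1, 0), (1, 2, 4))  # args: input, begin, size
print(out.shape)  # expected: (1, 2, 4)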

@@ -86,8 +86,10 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     ret = LaunchKernel<float>(inputs, outputs);
   } else if (dtype_ == kNumberTypeBool) {
     ret = LaunchKernel<bool>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat64) {
+    ret = LaunchKernel<double>(inputs, outputs);
   } else {
-    MS_LOG(ERROR) << "Slice op only support input_x int32 and float32";
+    MS_LOG(ERROR) << "SliceGrad op only supports input_x of bool, int32, float32 and float64";
     return false;
   }
   return ret;

@@ -60,10 +60,23 @@ MS_REG_CPU_KERNEL(
   SliceGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  SliceGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  SliceGradCPUKernel);
 MS_REG_CPU_KERNEL(
   SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  SliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+  SliceGradCPUKernel);
 MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   SliceGradCPUKernel);
 MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                   SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                  SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                  SliceGradCPUKernel);
 }  // namespace kernel
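
A quick way to exercise the new SliceGrad registrations is to differentiate Slice on a float64 input. A minimal sketch, assuming CPU graph mode (the SliceNet wrapper and shapes are illustrative, not from this commit):

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class SliceNet(nn.Cell):
    def __init__(self):
        super(SliceNet, self).__init__()
        self.slice = P.Slice()

    def construct(self, x):
        return self.slice(x, (0, 1, 0), (2, 1, 3))

x = Tensor(np.random.rand(3, 2, 3).astype(np.float64))
dx = C.GradOperation()(SliceNet())(x)  # backward pass should hit the float64 SliceGrad CPU kernel
print(dx.shape)  # expected: (3, 2, 3)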

@@ -293,7 +293,7 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
         TypeError: If `quant_delay` is not greater than or equal to 0.

     Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
+        ``Ascend`` ``GPU``

     Examples:
         >>> fake_quant = nn.FakeQuantWithMinMaxObserver()
@@ -448,7 +448,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
         ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

     Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
+        ``Ascend`` ``GPU``

     Examples:
         >>> qconfig = compression.quant.create_quant_config()

@@ -4572,7 +4572,7 @@ class BroadcastTo(PrimitiveWithInfer):
             target shape is in an invalid location.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> shape = (2, 3)
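
Since the docstring now advertises CPU support, BroadcastTo should run on the CPU backend as well. A minimal sketch (illustrative, not part of this commit):

import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

shape = (2, 3)
broadcast_to = P.BroadcastTo(shape)
output = broadcast_to(Tensor(np.array([1, 2, 3]).astype(np.float32)))
print(output)  # expected: [[1. 2. 3.] [1. 2. 3.]]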

@@ -78,6 +78,21 @@ def test_slice_grad2():
               [[0., 0.], [8., 9.], [10., 11.]]]
     assert (output.asnumpy() == expect).all()

+def test_slice_grad3():
+    x = Tensor(np.array([[[1.0, 3.5, 5.8], [2.5, 4, 1]], [[3.5, 15.3, 3.1], [2.2, 4.0, 1.1]],
+                         [[43.4, 1.1, 12.1], [2.4, 6.5, 6.3]]]), mstype.float64)
+    dy = Tensor(np.array([[[3.1, 1.1, 2.2]], [[4.4, 1.2, 4.2]]]), mstype.float64)
+    slicegrad = SliceGrad()
+    output = slicegrad(dy, x)
+    expect = [[[0., 0., 0.],
+               [3.1, 1.1, 2.2]],
+              [[0., 0., 0.],
+               [4.4, 1.2, 4.2]],
+              [[0., 0., 0.],
+               [0., 0., 0.]]]
+    print("output:\n", output)
+    assert (output.asnumpy() == expect).all()
+
 class StridedSliceGrad(nn.Cell):
     def __init__(self, x, begin, end, stride):
         super(StridedSliceGrad, self).__init__()

@@ -69,6 +69,14 @@ def test_slice2():
     output = slice_op(x)
     assert (output.asnumpy() == expect).all()

+def test_slice_float64():
+    data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
+                            [[3, 3, 3], [4, 4, 4]],
+                            [[5, 5, 5], [6, 6, 6]]]).astype(np.float64))
+    slice_op = P.Slice()
+    output = slice_op(data, (1, 0, 0), (1, 1, 3))
+    expect = [[[3.0, 3.0, 3.0]]]
+    assert (output.asnumpy() == expect).all()
+
 class Slice3(nn.Cell):
     def __init__(self):
