From 1ce576831c72350196866563a07ca9a0f65bd686 Mon Sep 17 00:00:00 2001 From: wxl Date: Thu, 25 Feb 2021 21:05:25 +0800 Subject: [PATCH] fix slice constant folding bug --- ge/host_kernels/slice_kernel.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc index c3274465..6b91db1d 100644 --- a/ge/host_kernels/slice_kernel.cc +++ b/ge/host_kernels/slice_kernel.cc @@ -56,6 +56,8 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetTensorDesc().GetDataType(); + uint32_t type_size = 0; + (void)TypeUtils::GetDataTypeLength(data_type, type_size); // check data type of begin and size if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { GELOGW("Data type of begin and size for slice are not DT_INT32."); @@ -69,7 +71,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetData().size() / sizeof(int32_t); + size_t data_size = x_->GetData().size() / type_size; size_t begin_size = begin->GetData().size() / sizeof(int32_t); size_t size_size = size->GetData().size() / sizeof(int32_t); const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape();