!12879 fix bug of construct tensor

From: @lianliguang
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
pull/12879/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8707cd4f39

@ -19,7 +19,7 @@
namespace mindspore {
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
size_t mem_size = IntToSize(tensor->ElementsNum());
auto tensor_data = tensor->data_c();
char *data = reinterpret_cast<char *>(tensor_data);
MS_EXCEPTION_IF_NULL(data);
@ -30,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
size_t mem_size = IntToSize(tensor->ElementsNum());
if (tensor->data_type() == kNumberTypeFloat32) {
SetTensorData(tensor->data_c(), 1.0, mem_size);
SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
} else if (tensor->data_type() == kNumberTypeInt) {
SetTensorData(tensor->data_c(), 1, mem_size);
SetTensorData<int>(tensor->data_c(), 1, mem_size);
}
return tensor;
}

Loading…
Cancel
Save