tensor data type support string && tensor

pull/14353/head
jianghui58 4 years ago
parent e408c9efc8
commit 3b7c96afa4

@ -447,6 +447,10 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
return std::make_shared<TensorDataImpl<float>>(shape, args...); return std::make_shared<TensorDataImpl<float>>(shape, args...);
case kNumberTypeFloat64: case kNumberTypeFloat64:
return std::make_shared<TensorDataImpl<double>>(shape, args...); return std::make_shared<TensorDataImpl<double>>(shape, args...);
case kObjectTypeString:
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
case kObjectTypeTensorType:
return std::make_shared<TensorDataImpl<int>>(shape, args...);
default: default:
break; break;
} }
@ -549,8 +553,8 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
abstract::AbstractBasePtr Tensor::ToAbstract() { abstract::AbstractBasePtr Tensor::ToAbstract() {
auto tens = shared_from_base<Tensor>(); auto tens = shared_from_base<Tensor>();
auto dtype = tens->Dtype(); auto dtype = tens->Dtype();
if (!IsSubType(dtype, kNumber)) { if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << "."; MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
} }
auto tensor_shape = tens->shape(); auto tensor_shape = tens->shape();
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);

Loading…
Cancel
Save