|
|
|
@ -447,6 +447,10 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
|
|
|
|
|
return std::make_shared<TensorDataImpl<float>>(shape, args...);
|
|
|
|
|
case kNumberTypeFloat64:
|
|
|
|
|
return std::make_shared<TensorDataImpl<double>>(shape, args...);
|
|
|
|
|
case kObjectTypeString:
|
|
|
|
|
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
|
|
|
|
|
case kObjectTypeTensorType:
|
|
|
|
|
return std::make_shared<TensorDataImpl<int>>(shape, args...);
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -549,8 +553,8 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
|
|
|
|
|
abstract::AbstractBasePtr Tensor::ToAbstract() {
|
|
|
|
|
auto tens = shared_from_base<Tensor>();
|
|
|
|
|
auto dtype = tens->Dtype();
|
|
|
|
|
if (!IsSubType(dtype, kNumber)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << ".";
|
|
|
|
|
if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
auto tensor_shape = tens->shape();
|
|
|
|
|
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
|
|
|
|