|
|
|
@ -41,7 +41,9 @@ inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
|
|
|
|
|
// NOTE: Converting int64 to int32 may cause data overflow.
|
|
|
|
|
vec_new_data = std::vector<T>(data, data + x->numel());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("The dtype of Tensor must be int32 or int64.");
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The dtype of Tensor must be int32 or int64, but received: %s",
|
|
|
|
|
x->type()));
|
|
|
|
|
}
|
|
|
|
|
return vec_new_data;
|
|
|
|
|
}
|
|
|
|
@ -53,10 +55,11 @@ inline std::vector<T> GetDataFromTensorList(
|
|
|
|
|
for (size_t i = 0; i < list_tensor.size(); ++i) {
|
|
|
|
|
auto tensor = list_tensor[i];
|
|
|
|
|
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
|
|
|
|
|
"ShapeError: The shape of Tensor in list must be [1]. "
|
|
|
|
|
"But received the shape "
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Tensor in list must be [1]. "
|
|
|
|
|
"But received its shape "
|
|
|
|
|
"is [%s]",
|
|
|
|
|
tensor->dims());
|
|
|
|
|
tensor->dims()));
|
|
|
|
|
|
|
|
|
|
if (tensor->type() == framework::proto::VarType::INT32) {
|
|
|
|
|
if (platform::is_gpu_place(tensor->place())) {
|
|
|
|
@ -76,7 +79,10 @@ inline std::vector<T> GetDataFromTensorList(
|
|
|
|
|
vec_new_data.push_back(static_cast<T>(*tensor->data<int64_t>()));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("The dtype of Tensor in list must be int32 or int64.");
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The dtype of Tensor in list must be int32 or int64, but received: "
|
|
|
|
|
"%s",
|
|
|
|
|
tensor->type()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return vec_new_data;
|
|
|
|
|