@@ -21,15 +21,21 @@
 namespace paddle {
 
 void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
-  PADDLE_ENFORCE(!name_.empty(),
-                 "Need to SetName first, so that the corresponding tensor can "
-                 "be retrieved.");
-  PADDLE_ENFORCE(input_or_output_,
-                 "Can't reshape the output tensor, it is readonly");
-  PADDLE_ENFORCE(scope_);
+  PADDLE_ENFORCE_EQ(
+      name_.empty(), false,
+      platform::errors::PreconditionNotMet(
+          "Need to SetName first, so that the corresponding tensor can "
+          "be retrieved."));
+  PADDLE_ENFORCE_EQ(input_or_output_, true,
+                    platform::errors::PermissionDenied(
+                        "Can't reshape the output tensor, it is readonly"));
+  PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
+                                      "The scope should not be nullptr."));
   auto *scope = static_cast<framework::Scope *>(scope_);
   auto *var = scope->FindVar(name_);
-  PADDLE_ENFORCE(var, "No tensor called [%s] in the runtime scope", name_);
+  PADDLE_ENFORCE_NOT_NULL(
+      var, platform::errors::PreconditionNotMet(
+               "No tensor called [%s] in the runtime scope", name_));
   auto *tensor = var->GetMutable<framework::LoDTensor>();
   tensor->Resize(framework::make_ddim(shape));
 }
@@ -45,8 +51,10 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
   EAGER_GET_TENSOR;
   PADDLE_ENFORCE_GT(
       tensor->numel(), 0,
-      "You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
-      "function before retrieving mutable_data from input tensor.");
+      platform::errors::PreconditionNotMet(
+          "You should call ZeroCopyTensor::Reshape(const std::vector<int> "
+          "&shape)"
+          "function before retrieving mutable_data from input tensor."));
   switch (static_cast<int>(place)) {
     case static_cast<int>(PaddlePlace::kCPU): {
       return tensor->mutable_data<T>(platform::CPUPlace());
@@ -55,7 +63,8 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
       return tensor->mutable_data<T>(platform::CUDAPlace(device_));
     }
     default:
-      PADDLE_THROW("Unsupported place: %d", static_cast<int>(place));
+      PADDLE_THROW(platform::errors::Unavailable("Unsupported place: %d",
+                                                 static_cast<int>(place)));
       break;
   }
   return nullptr;
@@ -96,10 +105,11 @@ PaddleDType ZeroCopyTensor::type() const {
 template <typename T>
 void ZeroCopyTensor::copy_from_cpu(const T *data) {
   EAGER_GET_TENSOR;
-  PADDLE_ENFORCE_GE(
-      tensor->numel(), 0,
-      "You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
-      "function before copying data from cpu.");
+  PADDLE_ENFORCE_GE(tensor->numel(), 0,
+                    platform::errors::PreconditionNotMet(
+                        "You should call ZeroCopyTensor::Reshape(const "
+                        "std::vector<int> &shape)"
+                        "function before copying data from cpu."));
   size_t ele_size = tensor->numel() * sizeof(T);
 
   if (place_ == PaddlePlace::kCPU) {
@@ -116,7 +126,8 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
     memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
                  data, ele_size, dev_ctx->stream());
 #else
-    PADDLE_THROW("Not compiled with CUDA, should not reach here.");
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Not compiled with CUDA, should not reach here."));
 #endif
   }
 }
@@ -141,7 +152,8 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
 
     cudaStreamSynchronize(dev_ctx->stream());
 #else
-    PADDLE_THROW("Not compile with CUDA, should not reach here.");
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Not compile with CUDA, should not reach here."));
 #endif
   }
 }
@@ -176,20 +188,27 @@ template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
     PaddlePlace place);
 
 void *ZeroCopyTensor::FindTensor() const {
-  PADDLE_ENFORCE(!name_.empty(),
-                 "Need to SetName first, so that the corresponding tensor can "
-                 "be retrieved.");
-  PADDLE_ENFORCE(scope_);
+  PADDLE_ENFORCE_EQ(
+      name_.empty(), false,
+      platform::errors::PreconditionNotMet(
+          "Need to SetName first, so that the corresponding tensor can "
+          "be retrieved."));
+  PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
+                                      "The scope should not be nullptr."));
   auto *scope = static_cast<framework::Scope *>(scope_);
   auto *var = scope->FindVar(name_);
-  PADDLE_ENFORCE(var, "No tensor called [%s] in the runtime scope", name_);
+  PADDLE_ENFORCE_NOT_NULL(
+      var, platform::errors::PreconditionNotMet(
+               "No tensor called [%s] in the runtime scope", name_));
   auto *tensor = var->GetMutable<framework::LoDTensor>();
   return tensor;
 }
 
 std::vector<int> ZeroCopyTensor::shape() const {
   EAGER_GET_TENSOR;
-  PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_);
+  PADDLE_ENFORCE_NOT_NULL(
+      tensor_, platform::errors::PreconditionNotMet(
+                   "Not found tensor called %s in the scope", name_));
   return framework::vectorize<int>(tensor->dims());
 }
 