|
|
|
@ -203,6 +203,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
|
|
|
|
|
input_ptr = input.mutable_data<int64_t>(ddim, place_);
|
|
|
|
|
} else if (inputs[i].dtype == PaddleDType::FLOAT32) {
|
|
|
|
|
input_ptr = input.mutable_data<float>(ddim, place_);
|
|
|
|
|
} else if (inputs[i].dtype == PaddleDType::INT32) {
|
|
|
|
|
input_ptr = input.mutable_data<int32_t>(ddim, place_);
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
|
|
|
|
|
return false;
|
|
|
|
@ -281,8 +283,11 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
|
|
|
|
|
} else if (type == framework::DataTypeTrait<int64_t>::DataType) {
|
|
|
|
|
GetFetchOne<int64_t>(fetch, output);
|
|
|
|
|
output->dtype = PaddleDType::INT64;
|
|
|
|
|
} else if (type == framework::DataTypeTrait<int32_t>::DataType) {
|
|
|
|
|
GetFetchOne<int32_t>(fetch, output);
|
|
|
|
|
output->dtype = PaddleDType::INT32;
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "unknown type, only support float32 and int64 now.";
|
|
|
|
|
LOG(ERROR) << "unknown type, only support float32, int64 and int32 now.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|