update lite subgraph. (#30056)

revert-31562-mean
Wilber 5 years ago committed by GitHub
parent a64822589f
commit 66e16b7e99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG 68e64e0eb74cdd13383ae78caf889973499ebd14)
set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204)
endif()
if(NOT CUDA_ARCH_NAME)

@ -272,6 +272,8 @@ void LiteSubgraphPass::SetUpEngine(
paddle::lite_api::Place({target_type, PRECISION(kInt64)}),
paddle::lite_api::Place({target_type, PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kX86), precision_type}),
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
};
config.cpu_math_library_num_threads = cpu_math_library_num_threads;
config.xpu_l3_workspace_size = xpu_l3_workspace_size;

@ -195,10 +195,8 @@ void InitDstTensor(paddle::lite_api::Tensor* dst,
void InitDstTensor(framework::LoDTensor* dst,
const paddle::lite_api::Tensor& src) {
constexpr framework::proto::VarType::Type dtype =
framework::proto::VarType_Type_FP32;
dst->mutable_data(inference::lite::utils::GetNativePlace(src.target()),
dtype);
GetNativePrecisionType(src.precision()));
SetLoD(dst->mutable_lod(), src.lod());
}
@ -254,17 +252,17 @@ void TensorDataShare(paddle::lite_api::Tensor* dst, framework::LoDTensor* src) {
template <>
void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) {
constexpr framework::proto::VarType::Type dtype =
framework::proto::VarType_Type_FP32;
void* src_raw_data =
GetLiteTensorDataPtr(src, GetLitePrecisionType(dtype), src->target());
size_t memory_size = GetLiteTensorNumel(*src) * sizeof(float);
GetLiteTensorDataPtr(src, src->precision(), src->target());
size_t memory_size =
GetLiteTensorNumel(*src) *
framework::SizeOfType(GetNativePrecisionType(src->precision()));
std::shared_ptr<memory::allocation::Allocation> holder(
new memory::allocation::Allocation(src_raw_data, memory_size,
GetNativePlace(src->target())));
dst->Resize(paddle::framework::make_ddim(src->shape()));
SetLoD(dst->mutable_lod(), src->lod());
dst->ResetHolderWithType(holder, dtype);
dst->ResetHolderWithType(holder, GetNativePrecisionType(src->precision()));
}
} // namespace utils

Loading…
Cancel
Save