[Feature] Lite subgraph (#22114)
parent
7d10edc5ee
commit
ad0dfb17c1
@ -0,0 +1,87 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Paddle-Lite integration is only supported on Linux with MKL enabled;
# bail out early (and force WITH_LITE off) on any other configuration.
if(NOT LINUX OR NOT WITH_MKL)
  # STATUS keeps this on stdout, consistent with the other messages in this file.
  message(STATUS "Paddle-lite will not build because the required Linux and MKL do not exist.")
  set(WITH_LITE OFF)
  return()
endif()
|
||||
|
||||
# Build Paddle-Lite from source via ExternalProject unless the caller has
# already supplied pre-built LITE_SOURCE_DIR and LITE_BINARY_DIR paths.
if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
  include(ExternalProject)
  set(LITE_PROJECT extern_lite)
  set(LITE_SOURCES_DIR ${THIRD_PARTY_PATH}/lite)
  set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)

  # No quotes, so cmake can resolve it as a command with arguments.
  set(LITE_BUILD_COMMAND $(MAKE) -j)
  # Configure options forwarded to the Paddle-Lite sub-build.
  set(LITE_OPTIONAL_ARGS -DWITH_MKL=ON
                         -DLITE_WITH_CUDA=${WITH_GPU}
                         -DWITH_MKLDNN=OFF
                         -DLITE_WITH_X86=ON
                         -DLITE_WITH_PROFILE=OFF
                         -DWITH_LITE=OFF
                         -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF
                         -DWITH_PYTHON=OFF
                         -DWITH_TESTING=ON
                         -DLITE_BUILD_EXTRA=ON
                         -DCUDNN_ROOT=${CUDNN_ROOT}
                         -DLITE_WITH_ARM=OFF)

  ExternalProject_Add(
    ${LITE_PROJECT}
    ${EXTERNAL_PROJECT_LOG_ARGS}
    # Pinned to a fixed commit for reproducible builds.
    GIT_REPOSITORY  "https://github.com/PaddlePaddle/Paddle-Lite.git"
    GIT_TAG         947cda26637d46dc23f4e39d2b52e7d9a1fa6eef
    PREFIX          ${LITE_SOURCES_DIR}
    UPDATE_COMMAND  ""
    BUILD_COMMAND   ${LITE_BUILD_COMMAND}
    INSTALL_COMMAND ""
    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
                    -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                    -DCMAKE_CXX_FLAGS=${LITE_CMAKE_CXX_FLAGS}
                    -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
                    -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
                    -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
                    -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
                    -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
                    -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                    ${EXTERNAL_OPTIONAL_ARGS}
                    ${LITE_OPTIONAL_ARGS}
  )
  # Recover the directories chosen by ExternalProject so the rest of the
  # build can reference headers and libraries from the Lite sub-build.
  ExternalProject_Get_property(${LITE_PROJECT} BINARY_DIR)
  ExternalProject_Get_property(${LITE_PROJECT} SOURCE_DIR)
  set(LITE_BINARY_DIR ${BINARY_DIR})
  set(LITE_SOURCE_DIR ${SOURCE_DIR})

endif()
|
||||
|
||||
message(STATUS "Paddle-lite BINARY_DIR: ${LITE_BINARY_DIR}")
message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}")
# Lite headers are included relative to both its source and build trees.
include_directories(${LITE_SOURCE_DIR})
include_directories(${LITE_BINARY_DIR})
|
||||
|
||||
# Declares an imported static library `alias` located at `path`, and makes it
# depend on the Paddle-Lite external project when we are building it ourselves
# (LITE_PROJECT is only set in that case).
function(external_lite_static_libs alias path)
  add_library(${alias} STATIC IMPORTED GLOBAL)
  # Lowercase for consistency with the rest of this file's command style.
  set_property(TARGET ${alias} PROPERTY IMPORTED_LOCATION ${path})
  if (LITE_PROJECT)
    add_dependencies(${alias} ${LITE_PROJECT})
  endif()
endfunction()
|
||||
|
||||
# The single archive inference links against: Lite's full static API library.
external_lite_static_libs(lite_full_static ${LITE_BINARY_DIR}/lite/api/libapi_full_static.a)

# Let the C++ sources know the Lite engine is available.
add_definitions(-DPADDLE_WITH_LITE)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_util.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
// Analysis pass that extracts Lite-supported subgraphs of the inference graph
// and replaces each with a single node executed by the Paddle-Lite engine.
class LiteSubgraphPass : public framework::ir::FusePassBase {
 public:
  void ApplyImpl(framework::ir::Graph* graph) const override;

 private:
  // Builds the fused engine op for `merged_node` inside `global_program`.
  // `repetitive_params` is an output collecting parameter names that the
  // engine shares with the host program (so they are not duplicated).
  void BuildOperator(framework::ir::Node* merged_node,
                     framework::ProgramDesc* global_program,
                     std::vector<std::string>* repetitive_params) const;

  // Creates and registers the Lite engine for the extracted `program`,
  // keyed by `unique_key`. When `dump_model` is true the serialized model
  // is also written out (presumably for debugging — see the .cc).
  void SetUpEngine(framework::ProgramDesc* program,
                   const std::vector<std::string>& repetitive_params,
                   const std::string& unique_key,
                   bool dump_model = false) const;
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/inference/io.h"
|
||||
#include "paddle/fluid/inference/lite/op_teller.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
namespace lite {
// Forward declarations of internal helpers defined in lite_subgraph_pass.cc,
// re-declared here so the test can drive them directly.
void StrToBinaryFile(const std::string& path, const std::string& str);
void ModifyHostSubgraphOps(framework::ProgramDesc* host_program,
                           framework::BlockDesc* host_sub_block,
                           const std::vector<framework::OpDesc*>& subgraph_ops);
void AppendLiteSubBlocks(const std::vector<framework::OpDesc*>& subgraph_ops,
                         framework::ProgramDesc* engine_program,
                         framework::ProgramDesc* host_program,
                         const int32_t host_sub_id);
}  // namespace lite
|
||||
|
||||
TEST(LiteSubgraphPass, basic) {
  // Build a minimal host program: a main block holding a `while` op whose
  // sub-block contains a single `leaky_relu` op.
  framework::ProgramDesc host_program;
  framework::ProgramDesc engine_program;
  framework::BlockDesc* host_main_block = host_program.MutableBlock(0);
  framework::BlockDesc* host_sub_block =
      host_program.AppendBlock(*host_main_block);
  framework::OpDesc* host_while_op = host_main_block->AppendOp();
  host_main_block->Var("var_main");
  host_sub_block->Var("var_sub");
  host_while_op->SetType("while");
  host_while_op->SetAttr("sub_block", host_sub_block);
  framework::OpDesc* host_sub_block_op = host_sub_block->AppendOp();
  host_sub_block_op->SetType("leaky_relu");

  // The teller should accept `while` because every op in its sub-block
  // (here just leaky_relu) is supported by Lite.
  CHECK(inference::lite::OpTeller::Global().Tell("while", *host_while_op))
      << "Lite operator teller test failed.";

  // Exercise the subgraph-rewriting helpers on the toy program.
  lite::AppendLiteSubBlocks({host_while_op}, &engine_program, &host_program,
                            host_sub_block->ID());
  lite::ModifyHostSubgraphOps(&host_program, host_sub_block, {host_while_op});
  lite::StrToBinaryFile("./", "test");
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,5 @@
|
||||
# Lite engine wrappers; every target links against Lite's full static archive.
cc_library(lite_op_teller SRCS op_teller.cc DEPS lite_full_static framework_proto device_context boost xxhash)
cc_library(lite_engine SRCS engine.cc DEPS lite_full_static framework_proto)
cc_library(lite_tensor_utils SRCS tensor_utils.cc DEPS memcpy lite_full_static framework_proto boost)
cc_test(test_lite_engine SRCS test_engine.cc DEPS lite_engine protobuf framework_proto glog gtest analysis)
cc_test(test_lite_tensor_utils SRCS test_tensor_utils.cc DEPS lite_engine lite_tensor_utils)
|
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#define LITE_WITH_CUDA 1
|
||||
#endif
|
||||
|
||||
#include "paddle/fluid/inference/lite/engine.h"
|
||||
#include "lite/core/context.h"
|
||||
#include "lite/core/device_info.h"
|
||||
|
||||
#include "lite/api/paddle_use_kernels.h"
|
||||
#include "lite/api/paddle_use_ops.h"
|
||||
#include "lite/api/paddle_use_passes.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace lite {
|
||||
|
||||
// True when no engine has ever been registered under any name.
bool EngineManager::Empty() const { return engines_.empty(); }
|
||||
|
||||
bool EngineManager::Has(const std::string& name) const {
|
||||
if (engines_.count(name) == 0) {
|
||||
return false;
|
||||
}
|
||||
return engines_.at(name).get() != nullptr;
|
||||
}
|
||||
|
||||
// Returns the raw predictor registered under `name`.
// Precondition: the engine exists — std::unordered_map::at throws
// std::out_of_range otherwise.
paddle::lite::Predictor* EngineManager::Get(const std::string& name) const {
  return engines_.at(name).get();
}
|
||||
|
||||
// Builds a new predictor from `cfg` and registers it under `name`, replacing
// any previous engine with that name. Returns a non-owning pointer to the
// registered predictor; ownership stays with the manager.
paddle::lite::Predictor* EngineManager::Create(const std::string& name,
                                               const EngineConfig& cfg) {
  // Own the predictor from the start so it cannot leak if Build() throws
  // (the original raw `new` leaked on any exception before reset()).
  std::unique_ptr<paddle::lite::Predictor> engine(
      new paddle::lite::Predictor());
#ifdef PADDLE_WITH_CUDA
  // The CUDA environment must be initialized before building a CUDA engine.
  paddle::lite::Env<TARGET(kCUDA)>::Init();
#endif
  engine->Build("", cfg.model, cfg.param, cfg.valid_places,
                cfg.neglected_passes, cfg.model_type, cfg.model_from_memory);
  paddle::lite::Predictor* raw = engine.get();
  engines_[name] = std::move(engine);
  return raw;
}
|
||||
|
||||
void EngineManager::DeleteAll() {
|
||||
for (auto& item : engines_) {
|
||||
item.second.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,55 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "lite/api/cxx_api.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace lite {
|
||||
|
||||
// Everything needed to build a Lite predictor (see EngineManager::Create).
struct EngineConfig {
  std::string model;  // serialized model program
  std::string param;  // serialized parameters
  paddle::lite::Place prefer_place;               // preferred kernel place
  std::vector<paddle::lite::Place> valid_places;  // places Lite may choose from
  std::vector<std::string> neglected_passes;      // optimization passes to skip
  // Encoding of `model`/`param`; defaults to protobuf.
  lite_api::LiteModelType model_type{lite_api::LiteModelType::kProtobuf};
  // NOTE(review): presumably true means `model`/`param` hold in-memory
  // buffers rather than file paths — confirm against Predictor::Build.
  bool model_from_memory{true};
};
|
||||
|
||||
// Registry that owns Lite predictors, keyed by a caller-chosen name.
// There is no internal locking; concurrent use must be synchronized by
// callers.
class EngineManager {
 public:
  // True when no engine has been created yet.
  bool Empty() const;
  // True when `name` maps to a live (non-null) predictor.
  bool Has(const std::string& name) const;
  // Returns the predictor for `name`; throws std::out_of_range if absent.
  paddle::lite::Predictor* Get(const std::string& name) const;
  // Builds and registers a predictor; returns a non-owning pointer to it.
  paddle::lite::Predictor* Create(const std::string& name,
                                  const EngineConfig& cfg);
  // Destroys all predictors (map entries remain, holding null).
  void DeleteAll();

 private:
  std::unordered_map<std::string, std::unique_ptr<paddle::lite::Predictor>>
      engines_;
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,92 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "paddle/fluid/framework/block_desc.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/inference/lite/op_teller.h"
|
||||
|
||||
#include "lite/core/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace lite {
|
||||
|
||||
// Just tell by the op_types.
|
||||
struct SimpleOpTeller : public Teller {
|
||||
SimpleOpTeller() {
|
||||
const std::map<std::string, std::string>& op2path =
|
||||
OpKernelInfoCollector::Global().GetOp2PathDict();
|
||||
auto is_non_inst = [](const std::string& op) -> bool {
|
||||
const std::vector<std::string> ops = {"feed", "fetch", "while"};
|
||||
return std::find(ops.begin(), ops.end(), op) != ops.end();
|
||||
};
|
||||
for (const auto& op : op2path) {
|
||||
if (!is_non_inst(op.first)) {
|
||||
ops_.insert(op.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool operator()(const std::string& op_type,
|
||||
const framework::OpDesc& op_desc) override {
|
||||
return ops_.count(op_type);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_set<std::string> ops_{};
|
||||
};
|
||||
|
||||
// Tells ops that carry a single sub-block (currently only `while`):
// such an op is supported iff every op inside its sub-block is supported,
// either directly (SimpleOpTeller) or, recursively, by this teller.
struct SingleBlockOpTeller : public Teller {
  SingleBlockOpTeller() { ops_.insert("while"); }

  bool operator()(const std::string& op_type,
                  const framework::OpDesc& op_desc) override {
    if (ops_.count(op_type)) {
      SimpleOpTeller supported;
      const int id = op_desc.GetBlockAttrId("sub_block");
      const framework::BlockDesc& block_desc =
          op_desc.Block()->Program()->Block(id);
      const std::vector<framework::OpDesc*>& ops_sub_block =
          block_desc.AllOps();
      for (auto* op : ops_sub_block) {
        // Recurse so nested block ops (e.g. a while inside a while) are
        // checked the same way.
        if (!supported(op->Type(), *op) && !this->operator()(op->Type(), *op)) {
          return false;
        }
      }
      return true;
    }
    return false;
  }

 private:
  std::unordered_set<std::string> ops_;
};
|
||||
|
||||
// An op is supported when any one of the registered tellers accepts it.
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
  for (const auto& teller : tellers_) {
    if ((*teller)(op_type, desc)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Tellers are consulted in registration order; the cheap type-only teller
// is registered first.
OpTeller::OpTeller() {
  // Wrap in unique_ptr before insertion: emplace_back(new X) could leak the
  // raw pointer if the vector's reallocation throws.
  tellers_.push_back(std::unique_ptr<Teller>(new SimpleOpTeller));
  tellers_.push_back(std::unique_ptr<Teller>(new SingleBlockOpTeller));
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,70 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace lite {
|
||||
|
||||
/*
 * Single Op teller definition.
 * One can override this and define a more complex tell logic, considering
 * more issues such as op_desc.
 */
struct Teller {
  // Returns true when the op described by (op_type, desc) can be handled.
  virtual bool operator()(const std::string& op_type,
                          const framework::OpDesc& desc) = 0;

  virtual ~Teller() = default;
};
|
||||
/*
|
||||
* A real example:
|
||||
*
|
||||
* struct SomeTeller : public Teller {
|
||||
* bool operator()(const std::string& op_type,
|
||||
* const framework::OpDesc& desc) override {
|
||||
* return op_type == "fc" && desc.Inputs().size() == 2;
|
||||
* }
|
||||
*};
|
||||
*/
|
||||
|
||||
/*
 * class OpTeller helps to tell whether a fluid
 * operator can be handled by the Lite engine.
 * (The original comment said "TensorRT layer" — a copy-paste from the
 * TensorRT teller; this one serves Lite.)
 */
class OpTeller {
 public:
  // Process-wide singleton; the private ctor registers the tellers once.
  static OpTeller& Global() {
    static std::unique_ptr<OpTeller> x(new OpTeller);
    return *x;
  }

  // Returns true when any registered teller accepts the op.
  bool Tell(const std::string& op_type, const framework::OpDesc& desc);

 private:
  OpTeller();

 private:
  std::vector<std::unique_ptr<Teller>> tellers_;
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,181 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/lite/tensor_utils.h"
|
||||
#include <map>
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/inference/lite/engine.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace lite {
|
||||
namespace utils {
|
||||
|
||||
using paddle::lite_api::TargetType;
|
||||
using paddle::lite_api::PrecisionType;
|
||||
using paddle::lite_api::DataLayoutType;
|
||||
|
||||
// Copies a level-of-detail structure between the fluid and Lite
// representations (both are nested sequence containers whose elements are
// convertible to each other).
template <typename DstLoD, typename SrcLoD>
void SetLoD(DstLoD* dst, const SrcLoD& src) {
  // Clear first, then reserve: the original reserved before clearing, which
  // works (clear keeps capacity) but states the intent backwards.
  dst->clear();
  dst->reserve(src.size());
  for (auto&& v : src) {
    dst->emplace_back(v);
  }
}
||||
// Explicit instantiations for the two conversion directions used below.
template void SetLoD<paddle::lite::LoD, framework::LoD>(
    paddle::lite::LoD* dst, const framework::LoD& src);
template void SetLoD<framework::LoD, paddle::lite::LoD>(
    framework::LoD* dst, const paddle::lite::LoD& src);
|
||||
|
||||
// Maps a Lite target type to the corresponding fluid place.
// `id` selects the CUDA device and is ignored for host/X86 targets.
// Unknown targets abort via LOG(FATAL).
platform::Place GetNativePlace(const TargetType& type, int id = 0) {
  switch (type) {
    case TargetType::kHost:
    case TargetType::kX86:  // both map to plain CPU memory
      return platform::CPUPlace();
    case TargetType::kCUDA:
      return platform::CUDAPlace(id);
    default:
      LOG(FATAL) << "Error target type.";
      return platform::Place();  // unreachable; silences the compiler
  }
}
|
||||
|
||||
// Maps a fluid place to the Lite target type; any non-CPU place is treated
// as CUDA.
TargetType GetLiteTargetType(const platform::Place& place) {
  return platform::is_cpu_place(place) ? TargetType::kHost : TargetType::kCUDA;
}
|
||||
|
||||
// Maps a fluid variable dtype to the Lite precision type.
// Only FP32/INT8/INT32/INT64 are supported; anything else aborts.
PrecisionType GetLitePrecisionType(framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType_Type_FP32:
      return PrecisionType::kFloat;
    case framework::proto::VarType_Type_INT8:
      return PrecisionType::kInt8;
    case framework::proto::VarType_Type_INT32:
      return PrecisionType::kInt32;
    case framework::proto::VarType_Type_INT64:
      return PrecisionType::kInt64;
    default:
      LOG(FATAL) << "Error precision type.";
      return PrecisionType::kUnk;  // unreachable; silences the compiler
  }
}
|
||||
|
||||
// Inverse of GetLitePrecisionType: maps a Lite precision back to the fluid
// variable dtype. Unsupported precisions abort.
framework::proto::VarType::Type GetNativePrecisionType(
    const PrecisionType& type) {
  switch (type) {
    case PrecisionType::kFloat:
      return framework::proto::VarType_Type_FP32;
    case PrecisionType::kInt8:
      return framework::proto::VarType_Type_INT8;
    case PrecisionType::kInt32:
      return framework::proto::VarType_Type_INT32;
    case PrecisionType::kInt64:
      return framework::proto::VarType_Type_INT64;
    default:
      LOG(FATAL) << "Error precision type.";
      // Unreachable sentinel; silences the missing-return warning.
      return static_cast<framework::proto::VarType::Type>(-1);
  }
}
|
||||
|
||||
// Maps a Lite data layout to the fluid layout. Only NCHW is supported;
// anything else aborts.
framework::DataLayout GetNativeLayoutType(const DataLayoutType& type) {
  switch (type) {
    case DataLayoutType::kNCHW:
      return framework::DataLayout::kNCHW;
    default:
      LOG(FATAL) << "Error layout type.";
      // Unreachable sentinel; silences the missing-return warning.
      return static_cast<framework::DataLayout>(-1);
  }
}
|
||||
|
||||
// Copies `size` bytes from `src_data` (in `src_place`) to `dst_data`
// (in `dst_place`). CPU<->CPU copies run synchronously; GPU<->GPU copies are
// enqueued on the CUDA stream of `ctx`. Cross-device CPU<->GPU transfers are
// not implemented yet and abort.
void MemoryCopyAsync(const platform::Place& dst_place, void* dst_data,
                     const platform::Place& src_place, const void* src_data,
                     const size_t size, const platform::DeviceContext& ctx) {
  const platform::CPUPlace cpu_place;
  if (platform::is_cpu_place(dst_place) && platform::is_cpu_place(src_place)) {
    memory::Copy(cpu_place, dst_data, cpu_place, src_data, size);
  } else {
#ifdef PADDLE_WITH_CUDA
    if (platform::is_cpu_place(dst_place) &&
        platform::is_gpu_place(src_place)) {
      LOG(FATAL) << "lite::MemoryCopy GPU->CPU is not yet implemented.";
    } else if (platform::is_gpu_place(dst_place) &&
               platform::is_cpu_place(src_place)) {
      LOG(FATAL) << "lite::MemoryCopy CPU->GPU is not yet implemented.";
    } else if (platform::is_gpu_place(dst_place) &&
               platform::is_gpu_place(src_place)) {
      // Fix: derive each side's place from its own argument. The original
      // used src_place for BOTH, which would mis-target a copy between
      // tensors on different CUDA devices.
      auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
      auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
      memory::Copy(
          dst_gpu_place, dst_data, src_gpu_place, src_data, size,
          static_cast<const platform::CUDADeviceContext&>(ctx).stream());
    }
#else
    LOG(FATAL) << "You must define PADDLE_WITH_CUDA for using CUDAPlace.";
#endif
  }
}
|
||||
|
||||
// Prepares a Lite tensor to receive data from a fluid tensor: fixes the
// target (derived from the source place), the precision, and the LoD.
void InitDstTensor(paddle::lite::Tensor* dst, const framework::LoDTensor& src) {
  // Currently, Lite needs to explicitly specify the target type of
  // the input tensor, even before any real allocation (hence size 0).
  constexpr int empty_size = 0;
  dst->mutable_data(GetLiteTargetType(src.place()), empty_size);
  dst->set_precision(GetLitePrecisionType(src.type()));
  SetLoD(dst->mutable_lod(), src.lod());
}
|
||||
|
||||
// Prepares a fluid tensor to receive data from a Lite tensor: allocates on
// the matching place and copies the LoD.
// NOTE(review): the dtype is hard-coded to FP32 here; presumably
// src.precision() should be honored once Lite reports it reliably — confirm.
void InitDstTensor(framework::LoDTensor* dst, const paddle::lite::Tensor& src) {
  constexpr framework::proto::VarType::Type dtype =
      framework::proto::VarType_Type_FP32;
  dst->mutable_data(inference::lite::utils::GetNativePlace(src.target()),
                    dtype);
  SetLoD(dst->mutable_lod(), src.lod());
}
|
||||
|
||||
// Copies a fluid LoDTensor into a Lite tensor: metadata (target, precision,
// LoD, shape) plus the raw payload.
template <>
void TensorCopyAsync(paddle::lite::Tensor* dst, const framework::LoDTensor& src,
                     const platform::DeviceContext& ctx) {
  InitDstTensor(dst, src);
  const platform::Place& src_place = src.place();
  const platform::Place& dst_place = GetNativePlace(dst->target());
  // Byte count follows the source dtype (set on dst by InitDstTensor).
  const size_t bytes =
      static_cast<size_t>(src.numel()) * framework::SizeOfType(src.type());
  dst->Resize(framework::vectorize(src.dims()));
  const void* src_data = src.data<void>();
  void* dst_data = dst->mutable_data(bytes);
  MemoryCopyAsync(dst_place, dst_data, src_place, src_data, bytes, ctx);
}
|
||||
|
||||
// Copies a Lite tensor back into a fluid LoDTensor: metadata (place, LoD,
// shape) plus the raw payload.
template <>
void TensorCopyAsync(framework::LoDTensor* dst, const paddle::lite::Tensor& src,
                     const platform::DeviceContext& ctx) {
  InitDstTensor(dst, src);
  const platform::Place& src_place = GetNativePlace(src.target());
  const platform::Place& dst_place = dst->place();
  dst->Resize(paddle::framework::make_ddim(src.dims().Vectorize()));
  // Byte count follows dst's dtype (FP32 for now, set by InitDstTensor).
  const size_t bytes =
      static_cast<size_t>(src.numel()) * framework::SizeOfType(dst->type());
  const void* src_data = src.raw_data();
  // When Lite is ready, the source type needs to be modified here.
  void* dst_data = dst->mutable_data(dst_place, dst->type());
  MemoryCopyAsync(dst_place, dst_data, src_place, src_data, bytes, ctx);
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace lite
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue