commit 07c7eaabb4 (branch: revert-15207-remove_op_handle_lock_and_fix_var, test=develop)
cmake/external/libmct.cmake
@ -0,0 +1,78 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_LIBMCT})
  return()
ENDIF(NOT ${WITH_LIBMCT})

IF(WIN32 OR APPLE)
  MESSAGE(WARNING
    "Windows or Mac is not supported with LIBMCT in Paddle yet. "
    "Forcing WITH_LIBMCT=OFF.")
  SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package on Windows and macOS" FORCE)
  return()
ENDIF()

INCLUDE(ExternalProject)

SET(LIBMCT_PROJECT "extern_libmct")
IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL))
  MESSAGE(STATUS "use pre-defined download url")
  SET(LIBMCT_VER "0.1.0" CACHE STRING "" FORCE)
  SET(LIBMCT_NAME "libmct" CACHE STRING "" FORCE)
  SET(LIBMCT_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${LIBMCT_VER}/${LIBMCT_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "LIBMCT_NAME: ${LIBMCT_NAME}, LIBMCT_URL: ${LIBMCT_URL}")
SET(LIBMCT_SOURCE_DIR   "${THIRD_PARTY_PATH}/libmct")
SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_SOURCE_DIR}/src/${LIBMCT_PROJECT}")
SET(LIBMCT_DST_DIR      "libmct")
SET(LIBMCT_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(LIBMCT_INSTALL_DIR  ${LIBMCT_INSTALL_ROOT}/${LIBMCT_DST_DIR})
SET(LIBMCT_ROOT         ${LIBMCT_INSTALL_DIR})
SET(LIBMCT_INC_DIR      ${LIBMCT_ROOT}/include)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${LIBMCT_ROOT}/lib")

INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})

FILE(WRITE ${LIBMCT_DOWNLOAD_DIR}/CMakeLists.txt
  "PROJECT(LIBMCT)\n"
  "cmake_minimum_required(VERSION 3.0)\n"
  "install(DIRECTORY ${LIBMCT_NAME}/include ${LIBMCT_NAME}/lib \n"
  "        DESTINATION ${LIBMCT_DST_DIR})\n")

ExternalProject_Add(
  ${LIBMCT_PROJECT}
  ${EXTERNAL_PROJECT_LOG_ARGS}
  PREFIX               ${LIBMCT_SOURCE_DIR}
  DOWNLOAD_DIR         ${LIBMCT_DOWNLOAD_DIR}
  DOWNLOAD_COMMAND     wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_NAME}.tar.gz
                       && tar zxvf ${LIBMCT_NAME}.tar.gz
  DOWNLOAD_NO_PROGRESS 1
  UPDATE_COMMAND       ""
  CMAKE_ARGS           -DCMAKE_INSTALL_PREFIX=${LIBMCT_INSTALL_ROOT}
  CMAKE_CACHE_ARGS     -DCMAKE_INSTALL_PREFIX:PATH=${LIBMCT_INSTALL_ROOT}
)

# Expose libmct as a buildable target: a dummy static library on older CMake
# or non-Windows platforms, otherwise an INTERFACE library.
if(${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32)
  set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/libmct_dummy.c)
  file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
  add_library(libmct STATIC ${dummyfile})
else()
  add_library(libmct INTERFACE)
endif()

ADD_DEPENDENCIES(libmct ${LIBMCT_PROJECT})
LIST(APPEND external_project_dependencies libmct)
cmake/external/pslib.cmake
@ -0,0 +1,77 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_PSLIB})
  return()
ENDIF(NOT ${WITH_PSLIB})

IF(WIN32 OR APPLE)
  MESSAGE(WARNING
    "Windows or Mac is not supported with PSLIB in Paddle yet. "
    "Forcing WITH_PSLIB=OFF.")
  SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package on Windows and macOS" FORCE)
  return()
ENDIF()

INCLUDE(ExternalProject)

SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
  MESSAGE(STATUS "use pre-defined download url")
  SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
  SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
  SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR   "${THIRD_PARTY_PATH}/pslib")
SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}")
SET(PSLIB_DST_DIR      "pslib")
SET(PSLIB_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_INSTALL_DIR  ${PSLIB_INSTALL_ROOT}/${PSLIB_DST_DIR})
SET(PSLIB_ROOT         ${PSLIB_INSTALL_DIR})
SET(PSLIB_INC_DIR      ${PSLIB_ROOT}/include)
SET(PSLIB_LIB_DIR      ${PSLIB_ROOT}/lib)
SET(PSLIB_LIB          ${PSLIB_LIB_DIR}/libps.so)
SET(PSLIB_IOMP_LIB     ${PSLIB_LIB_DIR}/libiomp5.so)  # TODO: document why pslib bundles libiomp5 (the Intel OpenMP runtime)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_ROOT}/lib")

INCLUDE_DIRECTORIES(${PSLIB_INC_DIR})

FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt
  "PROJECT(PSLIB)\n"
  "cmake_minimum_required(VERSION 3.0)\n"
  "install(DIRECTORY ${PSLIB_NAME}/include ${PSLIB_NAME}/lib \n"
  "        DESTINATION ${PSLIB_DST_DIR})\n")

ExternalProject_Add(
  ${PSLIB_PROJECT}
  ${EXTERNAL_PROJECT_LOG_ARGS}
  PREFIX               ${PSLIB_SOURCE_DIR}
  DOWNLOAD_DIR         ${PSLIB_DOWNLOAD_DIR}
  DOWNLOAD_COMMAND     wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
                       && tar zxvf ${PSLIB_NAME}.tar.gz
  DOWNLOAD_NO_PROGRESS 1
  UPDATE_COMMAND       ""
  CMAKE_ARGS           -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
  CMAKE_CACHE_ARGS     -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_INSTALL_ROOT}
)

ADD_LIBRARY(pslib SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib PROPERTY IMPORTED_LOCATION ${PSLIB_LIB})
ADD_DEPENDENCIES(pslib ${PSLIB_PROJECT})
LIST(APPEND external_project_dependencies pslib)

IF(WITH_C_API)
  INSTALL(FILES ${PSLIB_LIB} ${PSLIB_IOMP_LIB} DESTINATION lib)
ENDIF()
cmake/external/pslib_brpc.cmake
@ -0,0 +1,77 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_PSLIB_BRPC})
  return()
ENDIF(NOT ${WITH_PSLIB_BRPC})

IF(WIN32 OR APPLE)
  MESSAGE(WARNING
    "Windows or Mac is not supported with PSLIB_BRPC in Paddle yet. "
    "Forcing WITH_PSLIB_BRPC=OFF.")
  SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package on Windows and macOS" FORCE)
  return()
ENDIF()

INCLUDE(ExternalProject)

SET(PSLIB_BRPC_PROJECT "extern_pslib_brpc")
IF((NOT DEFINED PSLIB_BRPC_NAME) OR (NOT DEFINED PSLIB_BRPC_URL))
  MESSAGE(STATUS "use pre-defined download url")
  SET(PSLIB_BRPC_VER "0.1.0" CACHE STRING "" FORCE)
  SET(PSLIB_BRPC_NAME "pslib_brpc" CACHE STRING "" FORCE)
  SET(PSLIB_BRPC_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_BRPC_VER}/${PSLIB_BRPC_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_BRPC_NAME: ${PSLIB_BRPC_NAME}, PSLIB_BRPC_URL: ${PSLIB_BRPC_URL}")
SET(PSLIB_BRPC_SOURCE_DIR   "${THIRD_PARTY_PATH}/pslib_brpc")
SET(PSLIB_BRPC_DOWNLOAD_DIR "${PSLIB_BRPC_SOURCE_DIR}/src/${PSLIB_BRPC_PROJECT}")
SET(PSLIB_BRPC_DST_DIR      "pslib_brpc")
SET(PSLIB_BRPC_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_BRPC_INSTALL_DIR  ${PSLIB_BRPC_INSTALL_ROOT}/${PSLIB_BRPC_DST_DIR})
SET(PSLIB_BRPC_ROOT         ${PSLIB_BRPC_INSTALL_DIR})
SET(PSLIB_BRPC_INC_DIR      ${PSLIB_BRPC_ROOT}/include)
SET(PSLIB_BRPC_LIB_DIR      ${PSLIB_BRPC_ROOT}/lib)
SET(PSLIB_BRPC_LIB          ${PSLIB_BRPC_LIB_DIR}/libbrpc.a)
SET(PSLIB_BRPC_IOMP_LIB     ${PSLIB_BRPC_LIB_DIR}/libiomp5.so)  # TODO: document why pslib_brpc bundles libiomp5 (the Intel OpenMP runtime)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_BRPC_ROOT}/lib")

INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})

FILE(WRITE ${PSLIB_BRPC_DOWNLOAD_DIR}/CMakeLists.txt
  "PROJECT(PSLIB_BRPC)\n"
  "cmake_minimum_required(VERSION 3.0)\n"
  "install(DIRECTORY ${PSLIB_BRPC_NAME}/include ${PSLIB_BRPC_NAME}/lib \n"
  "        DESTINATION ${PSLIB_BRPC_DST_DIR})\n")

ExternalProject_Add(
  ${PSLIB_BRPC_PROJECT}
  ${EXTERNAL_PROJECT_LOG_ARGS}
  PREFIX               ${PSLIB_BRPC_SOURCE_DIR}
  DOWNLOAD_DIR         ${PSLIB_BRPC_DOWNLOAD_DIR}
  DOWNLOAD_COMMAND     wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_NAME}.tar.gz
                       && tar zxvf ${PSLIB_BRPC_NAME}.tar.gz
  DOWNLOAD_NO_PROGRESS 1
  UPDATE_COMMAND       ""
  CMAKE_ARGS           -DCMAKE_INSTALL_PREFIX=${PSLIB_BRPC_INSTALL_ROOT}
  CMAKE_CACHE_ARGS     -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_BRPC_INSTALL_ROOT}
)

ADD_LIBRARY(pslib_brpc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib_brpc PROPERTY IMPORTED_LOCATION ${PSLIB_BRPC_LIB})
ADD_DEPENDENCIES(pslib_brpc ${PSLIB_BRPC_PROJECT})
LIST(APPEND external_project_dependencies pslib_brpc)

IF(WITH_C_API)
  INSTALL(FILES ${PSLIB_BRPC_LIB} ${PSLIB_BRPC_IOMP_LIB} DESTINATION lib)
ENDIF()
File diff suppressed because it is too large.
paddle/fluid/framework/details/analysis_var_pass.h
@ -0,0 +1,120 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";

std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// Sort ops in BFS order.
std::vector<ir::Node*> BFSSortGraphOps(const ir::Graph& graph);

class ControlFlowGraph;

class AnalysisVarPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

 private:
  // Fill the variable map (var_nodes_) by version.
  void InitSSAGraphNodes() const;
  // Update program descs.
  void RenameVarInGraphDesc(const std::string& var,
                            const std::string& cache_var, size_t idx) const;
  // Update IR nodes.
  void RenameVarInGraphNode(const std::string& var,
                            const std::string& cache_var, size_t idx,
                            ir::Graph* graph) const;

  void SubGraphOptimize(OpDesc* op_desc) const;
  // Check whether a tensor can be reused.
  bool NodeCanReused(ir::Node* node) const;
  // Scan the sub-block and collect its output/input variables.
  std::unordered_set<std::string> GetSubBlockVars(
      const std::unordered_set<ir::Node*>&) const;
  // Check whether the op owns a sub-block.
  bool OpHasSubBlock(OpDesc* desc) const;

 private:
  // Reuse node pool, owned.
  mutable OrderedNodePairPool pool_;
  // Control flow graph.
  mutable std::unique_ptr<ControlFlowGraph> cfg_;
  // Variables to skip.
  mutable std::unordered_set<std::string> skip_set_;
  // Var nodes, keyed by name, one entry per SSA version.
  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};

class ControlFlowGraph {
 public:
  ControlFlowGraph() = default;
  // For an IR graph built by ParallelExecutor.
  explicit ControlFlowGraph(const ir::Graph& graph);

  void LiveVariableAnalysis();

  void RenameVarInCFGGraph(const std::string& old_node,
                           const std::string& new_node, int begin_idx);

  const std::set<std::string> LiveIn(ir::Node* op) const;
  const std::set<std::string> LiveOut(ir::Node* op) const;
  const std::set<std::string> Use(ir::Node* op) const;
  const std::vector<ir::Node*> Ops() const;
  std::vector<ir::Node*>& Ops();

  // For SSA-graph nodes.
  ir::Node* GetNodeFromVarName(const std::string& name, ir::Node* op) const;

 private:
  void BuildCFGGraph();
  void ConnectNodes();
  using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
  using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
  // Successor ops that consume this op's output variables.
  NodeListMap successors_;
  // Predecessor ops that produce this op's input variables.
  NodeListMap predecessors_;
  // Variables live before running the current op.
  VarSetMap live_in_;
  // Variables live after running the current op.
  VarSetMap live_out_;
  VarSetMap uses_;  // op inputs
  VarSetMap defs_;  // op outputs

  std::vector<ir::Node*> ops_;  // op sequence in topological order
};

} // namespace details
} // namespace framework
} // namespace paddle
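ControlFlowGraph's members (uses_, defs_, live_in_, live_out_, successors_) are precisely the state of the textbook backward liveness dataflow, so LiveVariableAnalysis presumably iterates the standard transfer functions to a fixed point. A minimal self-contained C++ sketch of that iteration follows; the three-op program, the index-based op IDs, and the variable names are illustrative stand-ins, not Paddle code.

#include <cstdio>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Tiny three-op program: 0: b = f(a); 1: c = g(b); 2: out = h(b, c).
  std::vector<std::set<int>> successors = {{1, 2}, {2}, {}};
  std::vector<std::set<std::string>> uses = {{"a"}, {"b"}, {"b", "c"}};
  std::vector<std::set<std::string>> defs = {{"b"}, {"c"}, {"out"}};

  const size_t n = successors.size();
  std::vector<std::set<std::string>> live_in(n), live_out(n);

  // Backward fixed point:
  //   live_out(op) = union of live_in(succ) over all successors
  //   live_in(op)  = uses(op) | (live_out(op) - defs(op))
  bool changed = true;
  while (changed) {
    changed = false;
    for (size_t i = n; i-- > 0;) {
      std::set<std::string> out;
      for (int s : successors[i])
        out.insert(live_in[s].begin(), live_in[s].end());
      std::set<std::string> in = uses[i];
      for (const auto& v : out)
        if (defs[i].count(v) == 0) in.insert(v);
      if (in != live_in[i] || out != live_out[i]) {
        live_in[i] = std::move(in);
        live_out[i] = std::move(out);
        changed = true;
      }
    }
  }

  // A variable absent from live_out(op) is dead after op: its memory can be
  // reused (AnalysisVarPass) or freed early (the early delete pass).
  for (size_t i = 0; i < n; ++i) {
    std::printf("op%zu live_out:", i);
    for (const auto& v : live_out[i]) std::printf(" %s", v.c_str());
    std::printf("\n");
  }
  return 0;
}

At the fixed point, any variable missing from LiveOut(op) is dead after op, which is exactly the signal the reuse and early-delete passes consume.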
File diff suppressed because it is too large.
paddle/fluid/framework/details/early_delete_op_handle.h
@ -0,0 +1,140 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
namespace details {

class EarlyDeleteOpHandle : public OpHandleBase {
 public:
  EarlyDeleteOpHandle(ir::Node* node, const Scope* scope,
                      const platform::Place& place,
                      const std::vector<std::string>& names,
                      GarbageCollector* gc)
      : OpHandleBase(node),
        scope_(scope),
        place_(place),
        names_(names),
        gc_(gc) {
#ifdef PADDLE_WITH_CUDA
    if (IsStreamGarbageCollector()) {
      auto gpu_place = boost::get<platform::CUDAPlace>(place);
      PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
      PADDLE_ENFORCE(
          cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
      // Cache the device context: the destructor and ClearGPUTensors need
      // its place and compute stream.
      dev_ctx_ = static_cast<platform::CUDADeviceContext*>(
          platform::DeviceContextPool::Instance().Get(place));
    }
#endif
  }
  ~EarlyDeleteOpHandle() {
#ifdef PADDLE_WITH_CUDA
    if (IsStreamGarbageCollector()) {
      auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
      PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
      PADDLE_ENFORCE(cudaEventDestroy(event_));
    }
#endif
  }

  std::string Name() const override { return "early_delete"; }

 protected:
  void RunImpl() override {
    std::vector<std::shared_ptr<memory::Allocation>> tensors;
    auto* local_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope*>();
    for (auto& var_name : names_) {
      auto* var = local_scope->FindVar(var_name);
      PADDLE_ENFORCE(var != nullptr,
                     string::Sprintf("Local scope does not have var %s",
                                     var_name));
      // Move each variable's memory holder out of the tensor; the allocation's
      // lifetime passes to the garbage collector.
      if (var->IsType<LoDTensor>()) {
        tensors.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<SelectedRows>()) {
        tensors.emplace_back(var->GetMutable<SelectedRows>()
                                 ->mutable_value()
                                 ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        LoDTensorArray* tensor_array = var->GetMutable<LoDTensorArray>();
        for (auto& tensor : *tensor_array) {
          tensors.emplace_back(tensor.MoveMemoryHolder());
        }
      }
    }
    if (!tensors.empty()) {
      ClearTensors(tensors);
    }
  }

 private:
  void ClearTensors(
      const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
    if (platform::is_cpu_place(place_)) {
      ClearCPUTensors(tensors);
    } else {
      ClearGPUTensors(tensors);
    }
  }

  void ClearCPUTensors(
      const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
    auto* gc = dynamic_cast<CPUGarbageCollector*>(gc_);
    if (gc != nullptr) {
      gc->Add(tensors);
    }
  }

  void ClearGPUTensors(
      const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
#ifdef PADDLE_WITH_CUDA
    auto* gc = dynamic_cast<StreamGarbageCollector*>(gc_);
    if (gc != nullptr) {
      // Make the collector's stream wait until the compute stream reaches
      // this point before the memory is actually released.
      auto compute_stream = dev_ctx_->stream();
      auto callback_stream = gc->stream();
      auto callback_func = [=]() {
        PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
        PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
      };
      gc_->Add(tensors, callback_func);
    } else {
      gc_->Add(tensors);
    }
#endif
  }

  bool IsStreamGarbageCollector() const {
#ifdef PADDLE_WITH_CUDA
    return dynamic_cast<const StreamGarbageCollector*>(gc_) != nullptr;
#else
    return false;
#endif
  }

  const Scope* scope_;
  const platform::Place place_;
  std::vector<std::string> names_;
  GarbageCollector* gc_;
#ifdef PADDLE_WITH_CUDA
  platform::CUDADeviceContext* dev_ctx_;
  cudaEvent_t event_;
#endif
};

} // namespace details
} // namespace framework
} // namespace paddle
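Note that RunImpl above never frees memory directly: it moves each variable's memory holder (a std::shared_ptr) out of the tensor and hands it to the garbage collector, which drops the last reference only once it is safe (for the stream collector, after the compute stream passes the recorded event). A hedged, self-contained sketch of that ownership hand-off, using toy Buffer/Tensor/Collector types in place of Paddle's:

#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

// Illustrative stand-ins: Tensor owns its buffer through a shared_ptr, like
// paddle's Tensor::MoveMemoryHolder(); Collector models a GarbageCollector.
struct Buffer {
  ~Buffer() { std::puts("buffer actually freed"); }
};

struct Tensor {
  std::shared_ptr<Buffer> holder = std::make_shared<Buffer>();
  std::shared_ptr<Buffer> MoveMemoryHolder() { return std::move(holder); }
};

struct Collector {
  std::vector<std::shared_ptr<Buffer>> pending;
  void Add(std::shared_ptr<Buffer> b) { pending.push_back(std::move(b)); }
  void Drain() { pending.clear(); }  // last refs die here, buffers freed
};

int main() {
  Collector gc;
  Tensor t;
  // "Early delete": the tensor loses its buffer immediately...
  gc.Add(t.MoveMemoryHolder());
  std::puts("tensor released its holder; buffer still alive in gc");
  // ...but the allocation survives until the collector decides to drain,
  // e.g. after the compute stream has finished using it.
  gc.Drain();
  return 0;
}

Running this prints "tensor released its holder; buffer still alive in gc" before "buffer actually freed": the tensor slot is reusable at once while the allocation outlives it until the collector drains.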
paddle/fluid/framework/details/memory_early_delete_pass.cc
@ -0,0 +1,117 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) {
  // BFS from var_in through its pending ops until a computation op on the
  // same place is found.
  std::queue<VarHandleBase*> queue;
  queue.push(var_in);
  do {
    auto* var = queue.front();
    queue.pop();
    for (auto* op : var->PendingOps()) {
      auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
      if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
        return compute_op;
      }
      for (auto* out_var : op->Outputs()) {
        queue.push(out_var);
      }
    }
  } while (!queue.empty());
  return nullptr;
}

std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto& graph_pool = Get<GraphNodePool>(kGraphNodePool);
  auto& gcs = Get<GarbageCollectorMap>(kGarbageCollector);

  std::unordered_map<std::string, std::unordered_set<OpDesc*>> unlived_vars;
  unlived_vars.reserve(graph_pool.size());
  for (auto& pair : graph_pool) {
    unlived_vars.insert(std::make_pair(pair.first, pair.second));
  }

  auto compare_and_insert_early_delete_op =
      [&](OpHandleBase* op, const std::vector<VarHandleBase*>& vars) {
        if (unlived_vars.empty()) return;
        // An unlived var can be deleted once its last user op has finished.
        auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
        const auto& places = Get<std::vector<platform::Place>>(kAllPlaces);
        for (auto& var : vars) {
          auto* var_handle = dynamic_cast<VarHandle*>(var);
          if (var_handle == nullptr || !var_handle->Node()->IsVar() ||
              var_handle->Node()->IsCtrlVar())
            continue;
          auto var_name = var->Node()->Name();
          auto& var_place = var_handle->place_;
          if (unlived_vars.count(var_name) == 0) continue;
          if (!unlived_vars[var_name].empty()) {
            // Still has pending users: tick off the current one, if any.
            if (compute_op != nullptr &&
                unlived_vars[var_name].count(compute_op->Node()->Op()) != 0) {
              unlived_vars[var_name].erase(compute_op->Node()->Op());
            }
            continue;
          }

          // Shamelessly copied from the reference count pass.
          if (compute_op == nullptr) {
            // Use the next computation op's scope.
            compute_op = FindNextComputationOpHandle(var_handle);
            if (compute_op == nullptr) continue;
          }
          auto* early_delete_node = graph->CreateEmptyNode(
              "early_delete", ir::Node::Type::kOperation);
          GarbageCollector* gc =
              gcs.at(places[compute_op->GetScopeIdx()]).get();
          auto* early_delete_handle =
              new EarlyDeleteOpHandle(early_delete_node, compute_op->GetScope(),
                                      var_place, {var_name}, gc);
          if (compute_op->Outputs().empty()) {
            auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
            compute_op->AddOutput(dep_var);
            graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
          }
          early_delete_handle->AddInput(compute_op->Outputs().front());
          VLOG(5) << "Add early delete op " << var_name << " to Operator "
                  << compute_op->Name();
        }
      };

  auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
  for (auto& op : all_ops) {
    compare_and_insert_early_delete_op(op, op->Inputs());
    compare_and_insert_early_delete_op(op, op->Outputs());
  }
  return graph;
}

} // namespace details
} // namespace framework
} // namespace paddle

REGISTER_PASS(memory_early_delete_pass,
              paddle::framework::details::MemoryEarlyDeletePass)
    .RequireGraphAttr(paddle::framework::details::kGraphNodePool)
    .RequireGraphAttr(paddle::framework::details::kGarbageCollector);
paddle/fluid/framework/details/memory_early_delete_pass.h
@ -0,0 +1,32 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include "paddle/fluid/framework/details/early_delete_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

class MemoryEarlyDeletePass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/details/memory_reuse_types.cc
@ -0,0 +1,155 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

namespace paddle {
namespace framework {
namespace details {

size_t NodeSizeInBytes(ir::Node* n) {
  auto* desc = FindVarDescInBlock(n);
  auto shape = desc->GetShape();
  size_t type_size = SizeOfType(desc->GetDataType());
  int size = 1;
  for (auto& s : shape) {
    size *= s;
  }
  // A -1 (runtime-determined) batch dimension makes the product negative;
  // taking the absolute value sizes the node as if that dimension were 1.
  return type_size * std::abs(size);
}

std::string DebugStringImpl(VarDesc* var) {
  std::stringstream ss;
  ss << var->Name();
  ss << "[";
  try {
    auto shape = var->GetShape();
    for (size_t i = 0; i < shape.size(); ++i) {
      if (i != shape.size() - 1) {
        ss << shape[i] << ",";
      } else {
        ss << shape[i];
      }
    }
    ss << "]";
  } catch (...) {
    ss << "Var has no VarDesc! Name: " << var->Name();
  }
  return ss.str();
}

std::string DebugString(ir::Node* var) {
  return DebugStringImpl(FindVarDescInBlock(var));
}

// NOTE(dzh): Based on the ir::Node alone, a large node that has been reused
// by a smaller node would carry the smaller size the next time it appears in
// the pool. So recover the original shape from the BlockDesc instead.
VarDesc* FindVarDescInBlock(ir::Node* n) {
  PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
  BlockDesc* block = n->inputs[0]->Op()->Block();
  PADDLE_ENFORCE(block->HasVar(n->Name()),
                 string::Sprintf("Block does not have var %s", n->Name()));
  return block->FindVar(n->Name());
}

struct NodeComparator {
  bool operator()(ir::Node* lhs, ir::Node* rhs) const {
    auto* lhs_desc = FindVarDescInBlock(lhs);
    auto* rhs_desc = FindVarDescInBlock(rhs);
    auto lhs_shape = lhs_desc->GetShape();
    auto rhs_shape = rhs_desc->GetShape();
    // Only nodes in the same group (both runtime-sized or both static) are
    // comparable; across groups the comparison is always false.
    if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
        (lhs_shape[0] != -1 && rhs_shape[0] != -1)) {
      return NodeSizeInBytes(lhs) <= NodeSizeInBytes(rhs);
    } else {
      return false;
    }
  }
};

void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
  PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
  PADDLE_ENFORCE(op->IsOp());
  if (mark_table_.count(var->Name()) != 0) {
    mark_table_[var->Name()]->second.insert(op);
    return;
  }

  auto* var_desc = FindVarDescInBlock(var);
  auto var_shape = var_desc->GetShape();
  int batch_size = static_cast<int>(var_shape[0]);

  // Find the insertion point: runtime-sized (batch_size == -1) nodes come
  // first, and each group is kept in ascending byte-size order.
  NodeComparator compare_node;
  Iter it = nodes_.begin();
  while (it != nodes_.end()) {
    auto* cache_desc = FindVarDescInBlock(it->first);
    int cache_batch_size = cache_desc->GetShape()[0];
    if ((cache_batch_size == -1 && batch_size == -1) ||
        (cache_batch_size != -1 && batch_size != -1)) {
      if (compare_node(it->first, var)) {
        ++it;
      } else {
        break;
      }
    } else if (cache_batch_size == -1 && batch_size != -1) {
      ++it;
    } else if (cache_batch_size != -1 && batch_size == -1) {
      break;
    }
  }

  it = nodes_.insert(it,
                     std::make_pair(var, std::unordered_set<ir::Node*>{op}));
  mark_table_[var->Name()] = it;
}

int OrderedNodePairPool::GetIndex(ir::Node* var) {
  return std::distance(nodes_.begin(), mark_table_[var->Name()]);
}

ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
  // Return the first, i.e. smallest, cached node that can hold var.
  ir::Node* found_node = nullptr;
  NodeComparator compare_node;
  for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
    if (compare_node(var, it->first)) {
      found_node = it->first;
      break;
    }
  }
  return found_node;
}

void OrderedNodePairPool::Erase(ir::Node* var) {
  PADDLE_ENFORCE(mark_table_.count(var->Name()));
  nodes_.erase(mark_table_[var->Name()]);
  mark_table_.erase(var->Name());
}

std::string OrderedNodePairPool::ToString() const {
  std::stringstream ss;
  for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
    ss << DebugString(it->first) << " ";
  }
  return ss.str();
}

} // namespace details
} // namespace framework
} // namespace paddle
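The insertion loop above maintains one invariant: nodes whose batch dimension is -1 (sized at runtime) come before all statically shaped nodes, and within each group the list ascends by size. A self-contained toy model of just that ordering rule follows; Var, Insert, and the sample shapes are illustrative, not Paddle types, and a plain element count stands in for the real byte size.

#include <cstdio>
#include <cstdlib>
#include <list>
#include <vector>

// Illustrative stand-in for a var node: only its shape matters here.
struct Var {
  const char* name;
  std::vector<long long> shape;  // shape[0] == -1 means runtime batch size
  long long Size() const {       // element count; type-size factor omitted
    long long n = 1;
    for (long long d : shape) n *= d;
    return std::llabs(n);  // one -1 dim flips the sign; abs recovers the size
  }
};

// Insert keeping: (-1)-batch vars first, each group ascending by Size().
void Insert(std::list<Var>& pool, const Var& v) {
  auto it = pool.begin();
  for (; it != pool.end(); ++it) {
    const bool cache_dyn = (it->shape[0] == -1);
    const bool var_dyn = (v.shape[0] == -1);
    if (cache_dyn == var_dyn) {
      if (it->Size() > v.Size()) break;  // same group: ascending by size
    } else if (var_dyn) {
      break;  // dynamic-batch vars precede all static-batch vars
    }
    // else: cache is dynamic, v is static -> keep scanning past the group
  }
  pool.insert(it, v);
}

int main() {
  std::list<Var> pool;
  for (const Var& v : {Var{"a", {1, 2}}, Var{"b", {-1, 20}},
                       Var{"c", {-1, 10}}, Var{"d", {10, 20}},
                       Var{"e", {5, 2}}})
    Insert(pool, v);
  for (const Var& v : pool) std::printf("%s ", v.name);  // c b a e d
  std::printf("\n");
  return 0;
}

With these sample inserts it prints "c b a e d": the two runtime-batch vars first (sizes 10, then 20), then the static ones ascending (2, 10, 200).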
paddle/fluid/framework/details/memory_reuse_types.h
@ -0,0 +1,87 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kFetchedVars[] = "fetched_vars";
constexpr char kGraphNodePool[] = "graph_node_pool";

// NOTE(dzh): A variable paired with the operators that use it, for the early
// delete pass. The analysis var pass is built on ir::Node, which may be
// released or modified between passes, so OpDesc* is used to mark the ops.
using GraphNodePool = std::vector<
    std::pair<std::string /*var node*/, std::unordered_set<OpDesc*> /* ops */>>;

// NOTE(dzh): Nodes are kept in ascending order of byte size. In fluid, -1
// means the batch size is determined at runtime; nodes whose batch size is
// -1 always rank ahead of nodes with a static batch size. For example:
//   node0[-1, 1], node1[-1, 1, 1], node2[1, 1], node3[1, 1024], ...
// Erase/Has are O(1) through mark_table_; Insert scans the list to keep it
// ordered.
class OrderedNodePairPool {
 public:
  using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>;
  using Iter = typename std::list<NodePair>::iterator;
  using ConstIter = typename std::list<NodePair>::const_iterator;

  void Insert(ir::Node* var, ir::Node* op);

  void Erase(ir::Node* var);

  bool Has(ir::Node* var) { return mark_table_.count(var->Name()); }

  ir::Node* NodeMatch(ir::Node* var) const;
  // mark_table_ stores non-const iterators, so this cannot be const.
  int GetIndex(ir::Node* var);
  // Dump all nodes in the pool to a string.
  std::string ToString() const;

  Iter begin() { return nodes_.begin(); }
  Iter end() { return nodes_.end(); }
  ConstIter begin() const { return nodes_.begin(); }
  ConstIter end() const { return nodes_.end(); }
  size_t size() const { return nodes_.size(); }

 private:
  // For O(1) lookup by variable name.
  std::unordered_map<std::string, Iter> mark_table_;
  // Node pairs: var -> the ops that depend on the var.
  std::list<NodePair> nodes_;
};

// Node memory size in bytes.
size_t NodeSizeInBytes(ir::Node* n);

std::string DebugString(ir::Node* var);

VarDesc* FindVarDescInBlock(ir::Node* n);

} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/details/memory_reuse_types_test.cc
@ -0,0 +1,99 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"

namespace paddle {
namespace framework {
namespace details {

TEST(OrderedNodePairPool, Normal) {
  OrderedNodePairPool pool;
  std::vector<std::unique_ptr<ir::Node>> nodes;

  // clang-format off
  std::vector<std::vector<int64_t>> shapes = {{-1, 10},
                                              {-1, 20},
                                              {1, 2},
                                              {5, 2},
                                              {10, 20},
                                              {-1, 2, 5},
                                              {-1, 1, 5},
                                              {-1, 1}};
  // clang-format on
  const int COUNT = static_cast<int>(shapes.size());
  ProgramDesc prog;
  BlockDesc* block_desc = prog.MutableBlock(0);
  auto* op_desc = block_desc->AppendOp();
  op_desc->SetType("dummy");
  std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);

  for (int i = 0; i < COUNT; ++i) {
    auto desc = block_desc->Var(std::to_string(i));
    desc->SetShape(shapes[i]);
    std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
    node->inputs.emplace_back(op.get());
    nodes.emplace_back(std::move(node));
  }

  for (auto& node : nodes) {
    pool.Insert(node.get(), op.get());
  }

  // Assert the pool's ordering and interface.
  std::cout << pool.ToString() << std::endl;
  pool.Erase(nodes.front().get());
  std::cout << pool.ToString() << std::endl;

  ASSERT_EQ(pool.size(), static_cast<size_t>(COUNT - 1));
  ASSERT_EQ(pool.GetIndex(nodes.back().get()), 0);

  {
    auto v1 = block_desc->Var("11");
    v1->SetShape({-1, 256, 56, 56});
    std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v1);
    node1->inputs.emplace_back(op.get());
    auto* cache = pool.NodeMatch(node1.get());
    ASSERT_EQ(cache, nullptr);  // no cached node is large enough
  }
  {
    auto v2 = block_desc->Var("12");
    v2->SetShape({-1, 2, 5});
    std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v2);
    node1->inputs.emplace_back(op.get());
    auto* cache = pool.NodeMatch(node1.get());
    ASSERT_EQ(pool.GetIndex(cache), 2);  // matches shapes[5] = [-1, 2, 5]
  }
  {
    auto v3 = block_desc->Var("13");
    v3->SetShape({2, 5});
    std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v3);
    node1->inputs.emplace_back(op.get());
    auto* cache = pool.NodeMatch(node1.get());
    ASSERT_EQ(pool.GetIndex(cache), 5);  // matches shapes[3] = [5, 2]
  }
}

} // namespace details
} // namespace framework
} // namespace paddle
Some files were not shown because too many files have changed in this diff.