Merge pull request #16409 from sneaxiy/feature/advance_gc

Enhance gc to support deleting tensor buffer in advance
move-code
Zeng Jinle 6 years ago committed by GitHub
commit c7c6eeb44e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -63,7 +63,7 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory) cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader) cc_test(reader_test SRCS reader_test.cc DEPS reader)
@ -164,6 +164,8 @@ else()
set(NGRAPH_EXE_DEPS) set(NGRAPH_EXE_DEPS)
endif() endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS}) lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
@ -174,7 +176,7 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
target_link_libraries(executor garbage_collector while_op_helper) target_link_libraries(executor while_op_helper executor_gc_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
@ -194,6 +196,7 @@ cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_con
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc) proto_desc)
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper) cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper)
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)

@ -22,14 +22,9 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
@ -206,8 +201,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
} }
} }
op_vars_map = ShrinkGCVars(op_vars_map, vars, places, double memory_fraction = framework::GetEagerDeletionMemoryFraction();
FLAGS_memory_fraction_of_eager_deletion);
op_vars_map = ShrinkGCVars(op_vars_map, vars, places, memory_fraction);
for (auto &pair : op_vars_map) { for (auto &pair : op_vars_map) {
auto *op = pair.first; auto *op = pair.first;
@ -239,8 +235,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
eager_deletion_op->AddOutput(dummy_leaf); eager_deletion_op->AddOutput(dummy_leaf);
} }
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
<< FLAGS_memory_fraction_of_eager_deletion;
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)"; VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
auto while_op_eager_deletion_pass = auto while_op_eager_deletion_pass =

@ -1,140 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
class EarlyDeleteOpHandle : public OpHandleBase {
public:
EarlyDeleteOpHandle(ir::Node* node, const Scope* scope,
const platform::Place& place,
const std::vector<std::string>& names,
GarbageCollector* gc)
: OpHandleBase(node),
scope_(scope),
place_(place),
names_(names),
gc_(gc) {
#ifdef PADDLE_WITH_CUDA
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(place);
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
#endif
}
~EarlyDeleteOpHandle() {
#ifdef PADDLE_WITH_CUDA
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
#endif
}
std::string Name() const override { return "early_delete"; }
protected:
void RunImpl() override {
std::vector<std::shared_ptr<memory::Allocation>> tensors;
auto* local_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope*>();
for (auto& var_name : names_) {
auto* var = local_scope->FindVar(var_name);
PADDLE_ENFORCE(var != nullptr,
string::Sprintf("Local Scope not has var %s", var_name));
if (var->IsType<LoDTensor>()) {
tensors.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
tensors.emplace_back(var->GetMutable<SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
LoDTensorArray* tensor_array = var->GetMutable<LoDTensorArray>();
for (auto& tensor : *tensor_array) {
tensors.emplace_back(tensor.MoveMemoryHolder());
}
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
if (platform::is_cpu_place(place_)) {
ClearCPUTensors(tensors);
} else {
ClearGPUTensors(tensors);
}
}
void ClearCPUTensors(
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
auto* gc = dynamic_cast<CPUGarbageCollector*>(gc_);
if (gc != nullptr) {
gc->Add(tensors);
}
}
void ClearGPUTensors(
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
#ifdef PADDLE_WITH_CUDA
auto* gc = dynamic_cast<StreamGarbageCollector*>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector*>(gc_) != nullptr;
#endif
}
const Scope* scope_;
const platform::Place place_;
std::vector<std::string> names_;
GarbageCollector* gc_;
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext* dev_ctx_;
cudaEvent_t event_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -21,6 +21,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h" #include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h" #include "paddle/fluid/framework/inplace_op_inference.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
@ -36,27 +37,86 @@ enum OpInfoFillType {
kGradOpDescMaker = 2, kGradOpDescMaker = 2,
kVarTypeInference = 3, kVarTypeInference = 3,
kShapeInference = 4, kShapeInference = 4,
kInplaceOpInference = 5 kInplaceOpInference = 5,
kNoNeedBufferVarsInference = 6,
kUnknown = -1
};
namespace internal {
template <typename T, OpInfoFillType kType>
struct TypePair {
using Type = T;
static constexpr OpInfoFillType kFillType = kType;
};
using OpRegistryClasses = std::tuple< // NOLINT
TypePair<OperatorBase, kOperator>, // NOLINT
TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>, // NOLINT
TypePair<GradOpDescMakerBase, kGradOpDescMaker>, // NOLINT
TypePair<VarTypeInference, kVarTypeInference>, // NOLINT
TypePair<InferShapeBase, kShapeInference>, // NOLINT
TypePair<InplaceOpInference, kInplaceOpInference>, // NOLINT
TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference> // NOLINT
>;
static constexpr int kOpRegistryClassNumber =
std::tuple_size<OpRegistryClasses>::value;
template <typename T, int kPos, bool kIsBounded /* = true*/>
struct IsMatchedBaseTypeImpl {
using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
static constexpr bool kValue =
std::is_base_of<typename PairType::Type, T>::value;
};
template <typename T, int kPos>
struct IsMatchedBaseTypeImpl<T, kPos, false> {
static constexpr bool kValue = false;
};
template <typename T, int kPos>
static inline constexpr bool IsMatchedBaseType() {
return IsMatchedBaseTypeImpl<
T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue;
}
template <typename T, int kStart, int kEnd, bool kIsEnd, bool kIsMatched>
struct OpInfoFillTypeGetterImpl {};
// This case should not happen
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, true> {};
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, false> {
static constexpr OpInfoFillType kType = kUnknown;
};
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
static constexpr OpInfoFillType kType =
OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
IsMatchedBaseType<T, kStart + 1>()>::kType;
};
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
static constexpr OpInfoFillType kType = PairType::kFillType;
}; };
template <typename T>
using OpInfoFillTypeGetter =
OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber,
kOpRegistryClassNumber == 0,
IsMatchedBaseType<T, 0>()>;
} // namespace internal
template <typename T> template <typename T>
struct OpInfoFillTypeID { struct OpInfoFillTypeID {
static constexpr OpInfoFillType ID() { static constexpr OpInfoFillType ID() {
return std::is_base_of<OperatorBase, T>::value return internal::OpInfoFillTypeGetter<T>::kType;
? kOperator
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
? kOpProtoAndCheckerMaker
: (std::is_base_of<GradOpDescMakerBase, T>::value
? kGradOpDescMaker
: (std::is_base_of<VarTypeInference, T>::value
? kVarTypeInference
: (std::is_base_of<InferShapeBase, T>::value
? kShapeInference
: (std::is_base_of<
InplaceOpInference, T>::value
? kInplaceOpInference
: static_cast<OpInfoFillType>(
-1))))));
} }
}; };
@ -156,6 +216,18 @@ struct OpInfoFiller<T, kInplaceOpInference> {
} }
}; };
template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs) {
T infer(inputs, outputs, attrs);
return infer();
};
}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework

@ -193,6 +193,79 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
return shrink_func(computation_op); return shrink_func(computation_op);
} }
/**
* Shrink op dependencies according to no need buffer vars.
*
* If some ops do not need Tensor buffer of any input,
* just remove the dependency of this op, i.e, decrease reference count.
*
* For example, input Y of elementwise_add_grad op is only used to infer shape
* and lod of Y@GRAD, we do not need the buffer of input Y. Data buffer of
* input Y can be collected before elementwise_add_grad op runs.
*
* This method returns whether the dependency count decreases to 0, and
* shrinks op dependency if possible.
*/
static bool ShrinkNoNeedBufferVarOpDependency(
const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) {
std::vector<ComputationOpHandle *> skip_ops;
for (auto *op_handle : *op_handles) {
auto *op_base = op_handle->GetOp();
auto &inferer = op_base->Info().NoNeedBufferVarsInferer();
if (!inferer) {
continue;
}
std::unordered_set<std::string> no_need_buffer_vars =
inferer(op_base->Inputs(), op_base->Outputs(), op_base->Attrs());
// Check whether var_name occurs in other inputs or outputs of the op
// If it occurs, we cannot decrease the dependency number.
bool occurred_in_other_vars = false;
for (auto &in_pair : op_base->Inputs()) {
if (no_need_buffer_vars.count(in_pair.first) > 0) {
continue;
}
auto &args = in_pair.second;
auto iter = std::find(args.begin(), args.end(), var_name);
if (iter != args.end()) {
occurred_in_other_vars = true;
break;
}
}
if (occurred_in_other_vars) {
continue;
}
for (auto &out_pair : op_base->Outputs()) {
auto &args = out_pair.second;
auto iter = std::find(args.begin(), args.end(), var_name);
if (iter != args.end()) {
occurred_in_other_vars = true;
break;
}
}
if (!occurred_in_other_vars) {
VLOG(2) << "Shrink var " << var_name << " in op " << op_handle->Name();
skip_ops.emplace_back(op_handle);
}
}
if (skip_ops.size() == op_handles->size()) {
op_handles->clear();
return true;
} else {
for (auto *skip_op : skip_ops) {
op_handles->erase(skip_op);
}
return false;
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount); auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
@ -229,17 +302,43 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
continue; continue;
} }
auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second;
for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) {
bool ok; bool ok;
auto result = ExtractComputationOpFromLastLivedVar( auto result =
name_var_pair.second.back(), i, shrink_func, &ok); ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok);
// Seldomly, some vars may have no pending or preceding computation ops
// Just break;
if (!ok) break;
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
size_t original_op_deps = result.size();
// If all ops do not need buffer of var_name, calculate reference count
// of the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
}
size_t final_op_deps = result.size();
if (final_op_deps < original_op_deps) {
VLOG(5) << "Shrink op deps from " << original_op_deps << " to "
<< final_op_deps;
}
if (ok) {
auto &var_name = name_var_pair.first;
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty", PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name); var_name);
ref_cnts[i].emplace(var_name, result.size()); ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result)); last_live_ops_of_vars[i].emplace(var_name, std::move(result));
} }
// Seldomly, all preceding trying failed.
// Just skip this corner case
} }
} }

@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
@ -48,97 +49,23 @@ namespace {
int kProgramId = -1; int kProgramId = -1;
} // namespace } // namespace
static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
const BlockDesc& block, const std::vector<std::string>& skip_var_list) {
std::unordered_map<std::string, size_t> ref_cnts;
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
skip_var_list.end());
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
if (skip_vars.count(name)) continue;
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS &&
type != proto::VarType::LOD_TENSOR_ARRAY) {
continue;
}
++ref_cnts[name];
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
ExecutorPrepareContext::ExecutorPrepareContext( ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id, const framework::ProgramDesc& prog, size_t block_id)
const std::vector<std::string>& keep_vars, bool force_disable_gc) : prog_(prog), block_id_(block_id) {}
: prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) { void ExecutorPrepareContext::PrepareUnusedVars(
global_ref_cnts_ = const std::vector<std::string>& keep_vars, bool force_disable_gc) {
GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars); force_disable_gc_ = force_disable_gc;
if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
return;
} }
unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
} }
ExecutorPrepareContext::~ExecutorPrepareContext() { ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(5) << "destroy ExecutorPrepareContext";
} }
static void DeleteUnusedTensors(
const Scope& scope, const OperatorBase* op, GarbageCollector* gc,
std::unordered_map<std::string, size_t>* ref_cnts) {
std::deque<std::shared_ptr<memory::Allocation>> garbages;
auto handler = [&](const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue;
if (--(it->second) != 0) {
continue;
}
auto* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
VLOG(2) << "Erase variable " << name;
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(
var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
garbages.emplace_back(var->GetMutable<SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *lod_tensor_arr) {
garbages.emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
framework::ToTypeName(var->Type()), name);
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
if (!garbages.empty()) {
gc->Add(std::move(garbages));
}
}
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() { void Executor::Close() {
@ -362,8 +289,8 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) { const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext( std::unique_ptr<ExecutorPrepareContext> ctx(
program, block_id, skip_ref_cnt_vars, force_disable_gc)); new ExecutorPrepareContext(program, block_id));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
@ -375,6 +302,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_); ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
} }
#endif #endif
ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
return ctx; return ctx;
} }
@ -389,19 +317,17 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
std::vector<std::shared_ptr<ExecutorPrepareContext>> result; std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
size_t idx = 0; size_t idx = 0;
for (auto& bid : block_ids) { for (auto& bid : block_ids) {
ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
force_disable_gc);
} else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
force_disable_gc);
}
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto* ctx = new ExecutorPrepareContext(program, bid);
auto& block = program.Block(bid); auto& block = program.Block(bid);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
if (skip_ref_cnt_vars.empty()) {
ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
} else {
ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
}
result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx)); result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
++idx; ++idx;
} }
@ -425,7 +351,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
// FIXME(zjl): recurrent_op is rather complex, we would // FIXME(zjl): recurrent_op is rather complex, we would
// disable gc forcely in recurrent_op // disable gc forcely in recurrent_op
if (!ctx->force_disable_gc_ && max_memory_size >= 0) { if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
if (IsFastEagerDeletionModeEnabled()) { if (IsFastEagerDeletionModeEnabled()) {
@ -453,8 +378,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), gc.get(), DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
&(ctx->runtime_ref_cnts_));
} }
} }

@ -30,22 +30,20 @@ namespace paddle {
namespace framework { namespace framework {
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id, ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; } void PrepareUnusedVars(const std::vector<std::string>& keep_vars,
bool force_disable_gc = false);
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; const size_t block_id_;
bool force_disable_gc_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, size_t> global_ref_cnts_; std::unordered_map<OperatorBase*, std::vector<std::string>> unused_vars_;
std::unordered_map<std::string, size_t> runtime_ref_cnts_; bool force_disable_gc_{false};
}; };
class Executor { class Executor {

@ -0,0 +1,189 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/executor_gc_helper.h"
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
struct OpInOutInfo {
public:
void Build(const OperatorBase *op) {
is_built_ = true;
auto &inferer = op->Info().NoNeedBufferVarsInferer();
if (inferer) {
no_need_buffer_ins_ = inferer(op->Inputs(), op->Outputs(), op->Attrs());
if (no_need_buffer_ins_.empty()) return;
for (auto &in_name_pair : op->Inputs()) {
if (no_need_buffer_ins_.count(in_name_pair.first) != 0) {
continue;
}
for (auto &in_arg_name : in_name_pair.second) {
other_args_set_.insert(in_arg_name);
}
}
for (auto &out_name_pair : op->Outputs()) {
for (auto &out_arg_name : out_name_pair.second) {
other_args_set_.insert(out_arg_name);
}
}
}
}
bool IsBuilt() const { return is_built_; }
bool IsInArgBufferNeeded(const std::string &in_arg_name) const {
return no_need_buffer_ins_.empty() ||
other_args_set_.count(in_arg_name) != 0;
}
private:
// A set to record unused buffer input vars of op
std::unordered_set<std::string> no_need_buffer_ins_;
// A set to record other args of op (including in, out)
std::unordered_set<std::string> other_args_set_;
bool is_built_{false};
};
static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block,
const std::unordered_set<std::string> &skip_vars) {
if (skip_vars.count(name) != 0) {
return false;
}
auto *var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) {
return false;
}
auto type = var_desc->Proto()->type().type();
return type == proto::VarType::LOD_TENSOR ||
type == proto::VarType::SELECTED_ROWS ||
type == proto::VarType::LOD_TENSOR_ARRAY;
}
std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
const BlockDesc &block,
const std::vector<std::unique_ptr<OperatorBase>> &ops,
const std::vector<std::string> &skip_var_list) {
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
skip_var_list.end());
std::unordered_map<std::string, size_t> var_op_idx_map;
for (size_t i = 0; i < ops.size(); ++i) {
auto *op = ops[i].get();
OpInOutInfo info;
for (auto &name_pair : op->Inputs()) {
for (auto &name : name_pair.second) {
if (!VarCanBeDeleted(name, block, skip_vars)) {
continue;
}
// var can be gc-ed
if (!info.IsBuilt()) {
info.Build(op);
}
if (info.IsInArgBufferNeeded(name)) {
// Update the last living op of variable to current op
var_op_idx_map[name] = i;
} else {
VLOG(10) << "Skip reference count computing of variable "
<< name_pair.first << "(" << name << ") in Operator "
<< op->Type();
}
}
}
for (auto &name_pair : op->Outputs()) {
for (auto &name : name_pair.second) {
if (VarCanBeDeleted(name, block, skip_vars)) {
// Update the last living op of variable to current op
var_op_idx_map[name] = i;
}
}
}
}
std::unordered_map<OperatorBase *, std::vector<std::string>> result;
for (auto &name_op_idx_pair : var_op_idx_map) {
auto &name = name_op_idx_pair.first;
size_t op_idx = name_op_idx_pair.second;
result[ops[op_idx].get()].emplace_back(name);
}
return result;
}
void DeleteUnusedTensors(
const Scope &scope, OperatorBase *op,
const std::unordered_map<OperatorBase *, std::vector<std::string>>
&delete_vars_map,
GarbageCollector *gc) {
auto iter = delete_vars_map.find(op);
if (iter == delete_vars_map.end()) {
return;
}
auto &delete_vars = iter->second;
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &var_name : delete_vars) {
auto *var = scope.FindVar(var_name);
if (var == nullptr) {
continue;
}
VLOG(2) << "Erase variable " << var_name;
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
garbages.emplace_back(
var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto *lod_tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto &t : *lod_tensor_arr) {
garbages.emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
framework::ToTypeName(var->Type()), var_name);
}
}
if (!garbages.empty()) {
gc->Add(std::move(garbages));
}
}
} // namespace framework
} // namespace paddle

@ -0,0 +1,42 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
// Result map: op -> variable names that can be deleted after op runs
std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
const BlockDesc &block,
const std::vector<std::unique_ptr<OperatorBase>> &ops,
const std::vector<std::string> &skip_vars);
// Collect unused tensors after op runs
void DeleteUnusedTensors(
const Scope &scope, OperatorBase *op,
const std::unordered_map<OperatorBase *, std::vector<std::string>>
&delete_vars_map,
GarbageCollector *gc);
} // namespace framework
} // namespace paddle

@ -13,14 +13,36 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif #endif
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted.");
GarbageCollector::GarbageCollector(const platform::Place &place, GarbageCollector::GarbageCollector(const platform::Place &place,
size_t max_memory_size) size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) { : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
@ -85,5 +107,25 @@ void StreamGarbageCollector::ClearCallback(
callback_manager_->AddCallback(callback); callback_manager_->AddCallback(callback);
} }
#endif #endif
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
void SetEagerDeletionMode(double threshold, double fraction, bool fast_mode) {
FLAGS_eager_delete_tensor_gb = threshold;
FLAGS_memory_fraction_of_eager_deletion = fraction;
FLAGS_fast_eager_deletion_mode = fast_mode;
}
double GetEagerDeletionMemoryFraction() {
return FLAGS_memory_fraction_of_eager_deletion;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -18,6 +18,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <utility>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
@ -126,5 +128,12 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) {
} }
} }
int64_t GetEagerDeletionThreshold();
bool IsFastEagerDeletionModeEnabled();
void SetEagerDeletionMode(double threshold, double fraction, bool fast_mode);
double GetEagerDeletionMemoryFraction();
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -0,0 +1,60 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
namespace framework {
class NoNeedBufferVarsInference {
public:
NoNeedBufferVarsInference(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs)
: inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
virtual ~NoNeedBufferVarsInference() = default;
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
const AttributeMap &Attrs() const { return attrs_; }
virtual std::unordered_set<std::string> operator()() const = 0;
private:
const VariableNameMap &inputs_;
const VariableNameMap &outputs_;
const AttributeMap &attrs_;
};
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
class class_type : public ::paddle::framework::NoNeedBufferVarsInference { \
public: \
using ::paddle::framework::NoNeedBufferVarsInference:: \
NoNeedBufferVarsInference; \
\
std::unordered_set<std::string> operator()() const override { \
return {__VA_ARGS__}; \
} \
}
} // namespace framework
} // namespace paddle

@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
@ -39,6 +40,7 @@ struct OpInfo {
InferVarTypeFN infer_var_type_; InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_; InferShapeFN infer_shape_;
InferInplaceOpFN infer_inplace_; InferInplaceOpFN infer_inplace_;
InferNoNeedBufferVarsFN infer_no_need_buffer_vars_;
bool HasOpProtoAndChecker() const { bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;
@ -64,6 +66,10 @@ struct OpInfo {
} }
const OpAttrChecker* Checker() const { return checker_; } const OpAttrChecker* Checker() const { return checker_; }
const InferNoNeedBufferVarsFN& NoNeedBufferVarsInferer() const {
return infer_no_need_buffer_vars_;
}
}; };
class OpInfoMap { class OpInfoMap {

@ -18,6 +18,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
@ -326,7 +327,12 @@ OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& inputs, const VariableNameMap& inputs,
const VariableNameMap& outputs, const VariableNameMap& outputs,
const AttributeMap& attrs) const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { : type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
// NOTE(zjl): why op_info may be nullptr?
info_(OpInfoMap::Instance().GetNullable(type)) {
GenerateTemporaryNames(); GenerateTemporaryNames();
CheckAllInputOutputSet(); CheckAllInputOutputSet();
} }
@ -350,7 +356,7 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
} }
return ret_val; return ret_val;
} }
auto& info = OpInfoMap::Instance().Get(Type()); auto& info = Info();
// get all OpProto::Var for outputs // get all OpProto::Var for outputs
for (auto& o : info.Proto().outputs()) { for (auto& o : info.Proto().outputs()) {
@ -366,18 +372,16 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
} }
void OperatorBase::CheckAllInputOutputSet() const { void OperatorBase::CheckAllInputOutputSet() const {
auto& info_map = OpInfoMap::Instance(); if (info_ == nullptr || info_->proto_ == nullptr) return;
auto* op_info = info_map.GetNullable(Type());
if (op_info == nullptr || op_info->proto_ == nullptr) return;
for (auto& in : op_info->Proto().inputs()) { for (auto& in : info_->Proto().inputs()) {
if (!in.dispensable()) { if (!in.dispensable()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"Operator %s's input, %s, is not set", Type(), in.name()); "Operator %s's input, %s, is not set", Type(), in.name());
} }
} }
for (auto& out : op_info->Proto().outputs()) { for (auto& out : info_->Proto().outputs()) {
if (!out.dispensable()) { if (!out.dispensable()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"Operator %s's output, %s, is not set", Type(), "Operator %s's output, %s, is not set", Type(),
@ -997,7 +1001,27 @@ Scope* OperatorWithKernel::PrepareData(
std::vector<std::string>* transfered_inplace_vars, std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const { RuntimeContext* ctx) const {
Scope* new_scope = nullptr; Scope* new_scope = nullptr;
std::unordered_set<std::string> no_buffer_ins;
if (info_) {
auto& no_buffer_inferer = info_->NoNeedBufferVarsInferer();
// Some op may not register NoNeedBufferVarsInferer
if (no_buffer_inferer) {
no_buffer_ins = no_buffer_inferer(Inputs(), Outputs(), Attrs());
}
}
for (auto& var_name_item : Inputs()) { for (auto& var_name_item : Inputs()) {
// NOTE(zjl): STL does not guarantee fast std::unordered_set::count when set
// is empty. At least STL implemented on my mac does calculate hash code
// of search key even though the set is empty.
if (!no_buffer_ins.empty() &&
no_buffer_ins.count(var_name_item.first) > 0) {
VLOG(1) << "Skip scanning input " << var_name_item.first
<< " in Operator " << type_;
continue;
}
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first]; std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {

@ -160,6 +160,11 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; } const VariableNameMap& Outputs() const { return outputs_; }
const OpInfo& Info() const {
PADDLE_ENFORCE_NOT_NULL(info_, "OpInfo of %s is not found", type_);
return *info_;
}
bool HasInputs(const std::string& name) const; bool HasInputs(const std::string& name) const;
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const; std::string Input(const std::string& name) const;
@ -194,6 +199,10 @@ class OperatorBase {
// IG (Inputs Gradients) // IG (Inputs Gradients)
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
// OpInfo
const OpInfo* info_;
// Whether this operator executes in an Executor. // Whether this operator executes in an Executor.
bool run_by_executor_{true}; bool run_by_executor_{true};
@ -444,7 +453,7 @@ class OperatorWithKernel : public OperatorBase {
} }
virtual void InferShape(InferShapeContext* ctx) const { virtual void InferShape(InferShapeContext* ctx) const {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); Info().infer_shape_(ctx);
} }
void RuntimeInferShape(const Scope& scope, const platform::Place& place, void RuntimeInferShape(const Scope& scope, const platform::Place& place,

@ -29,15 +29,6 @@ DEFINE_bool(
"Delete local scope eagerly. It will reduce GPU memory usage but " "Delete local scope eagerly. It will reduce GPU memory usage but "
"slow down the destruction of variables.(around 1% performance harm)"); "slow down the destruction of variables.(around 1% performance harm)");
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
// When in inference scenario, the scopes will not be written by two threads in // When in inference scenario, the scopes will not be written by two threads in
// a mean time, but a scope may be read by multiple threads concurrently, and // a mean time, but a scope may be read by multiple threads concurrently, and
// the mutex will cause serious performance issue. // the mutex will cause serious performance issue.
@ -57,15 +48,6 @@ DEFINE_bool(fast_eager_deletion_mode, true,
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {

@ -32,9 +32,6 @@ extern "C" {
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold();
bool IsFastEagerDeletionModeEnabled();
class Scope; class Scope;
/** /**

@ -30,6 +30,7 @@ class InferShapeContext;
class InferVarTypeContext; class InferVarTypeContext;
class BlockDesc; class BlockDesc;
class Variable; class Variable;
class NoNeedBufferVarsInference;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// TODO(panyx0718): Replace vector with something like gtl::Vector. // TODO(panyx0718): Replace vector with something like gtl::Vector.
@ -61,5 +62,9 @@ using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>; using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(const OpDesc&)>; using InferInplaceOpFN = std::function<InplacePair(const OpDesc&)>;
using InferNoNeedBufferVarsFN = std::function<std::unordered_set<std::string>(
const VariableNameMap& /*inputs*/, const VariableNameMap& /*outputs*/,
const AttributeMap& /*attrs*/)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/add_position_encoding_op.h" #include "paddle/fluid/operators/add_position_encoding_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -39,13 +40,8 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Out must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Out@GRAD must not be null.");
auto out_dims = ctx->GetInputDim("Out");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), out_dims); ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
} }
} }
@ -75,6 +71,22 @@ class AddPositionEncodingOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class AddPositionEncodingGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("add_position_encoding_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
@ -83,7 +95,7 @@ namespace plt = paddle::platform;
REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp, REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp,
ops::AddPositionEncodingOpMaker, ops::AddPositionEncodingOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::AddPositionEncodingGradOpDescMaker);
REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad); REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/clip_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -76,12 +77,28 @@ class ClipOpGrad : public framework::OperatorWithKernel {
} }
}; };
class ClipGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("clip_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>, REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>); ops::ClipGradOpDescMaker);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad); REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>); clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>);

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -120,11 +121,7 @@ Examples:
class ConcatOpGrad : public framework::OperatorWithKernel { class ConcatOpGrad : public framework::OperatorWithKernel {
public: public:
ConcatOpGrad(const std::string &type, using framework::OperatorWithKernel::OperatorWithKernel;
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
auto in_x = "X"; auto in_x = "X";
@ -142,6 +139,33 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
} }
} }
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
"X");
class ConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("concat_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
}; };
} // namespace operators } // namespace operators
@ -149,9 +173,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker, REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
paddle::framework::DefaultGradOpDescMaker< ops::ConcatGradOpDescMaker);
false> /* set false to disable empty grad */); REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad); ops::ConcatOpGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>, concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,

@ -455,13 +455,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
return type; return type;
} }
class Conv2dGradMaker : public framework::SingleGradOpDescMaker { class Conv2DGradMaker : public framework::SingleGradOpDescMaker {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc(); auto* op = new framework::OpDesc();
op->SetType(GradOpType()); op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", Input("Input")); op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter")); op->SetInput("Filter", Input("Filter"));
op->SetInput("Bias", Input("Bias")); op->SetInput("Bias", Input("Bias"));
@ -470,14 +470,33 @@ class Conv2dGradMaker : public framework::SingleGradOpDescMaker {
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter")); op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op); return std::unique_ptr<framework::OpDesc>(op);
} }
};
class Conv3DGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
virtual std::string GradOpType() const { std::unique_ptr<framework::OpDesc> Apply() const override {
return this->ForwardOpType() + "_grad"; auto* op = new framework::OpDesc();
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
if (ForwardOp().Inputs().count("ResidualData") != 0) {
op->SetInput("ResidualData", Input("ResidualData"));
}
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op);
} }
}; };
@ -486,17 +505,16 @@ class Conv2dGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker, REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType, ops::Conv2dGradMaker); ops::ConvOpInferVarType, ops::Conv2DGradMaker);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
// depthwise convolution op // depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker, REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType, ops::Conv2dGradMaker); ops::ConvOpInferVarType, ops::Conv2DGradMaker);
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker, REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
ops::ConvOpInferVarType, ops::ConvOpInferVarType, ops::Conv3DGradMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
// depthwise conv kernel // depthwise conv kernel

@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/crop_op.h" #include "paddle/fluid/operators/crop_op.h"
#include <boost/lexical_cast.hpp> #include <memory>
#include <string>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -178,12 +180,31 @@ class CropOpGrad : public framework::OperatorWithKernel {
} }
}; };
class CropGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("crop_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("Offsets") > 0) {
op->SetInput("Offsets", Input("Offsets"));
}
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker, REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::CropGradOpDescMaker);
REGISTER_OPERATOR(crop_grad, ops::CropOpGrad); REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>); crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>);

@ -238,6 +238,23 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
} }
}; };
class CrossEntropyGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cross_entropy_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
class CrossEntropyOp2 : public CrossEntropyOpBase { class CrossEntropyOp2 : public CrossEntropyOpBase {
public: public:
using CrossEntropyOpBase::CrossEntropyOpBase; using CrossEntropyOpBase::CrossEntropyOpBase;
@ -354,7 +371,7 @@ using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase, REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType, ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); ops::CrossEntropyGradOpDescMaker);
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>, REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
ops::CrossEntropyOpKernel<CPUCtx, double>); ops::CrossEntropyOpKernel<CPUCtx, double>);

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
@ -170,11 +171,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null."); "Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("last_h"),
"Input(last_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("last_c"),
"Input(last_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"), PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(last_c) of LSTM should not be null."); "Input(last_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"), PADDLE_ENFORCE(ctx->HasInput("InitH"),
@ -197,6 +193,35 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
} }
}; };
class CudnnLSTMGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cudnn_lstm_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("InitH", Input("InitH"));
op->SetInput("InitC", Input("InitC"));
op->SetInput("W", Input("W"));
if (ForwardOp().Inputs().count("Cache") > 0) {
op->SetInput("Cache", Input("Cache"));
}
op->SetInput("Out", Output("Out"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput(framework::GradVarName("last_c"), OutputGrad("last_c"));
op->SetInput(framework::GradVarName("last_h"), OutputGrad("last_h"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), InputGrad("W"));
op->SetOutput(framework::GradVarName("InitH"), InputGrad("InitH"));
op->SetOutput(framework::GradVarName("InitC"), InputGrad("InitC"));
op->SetAttrMap(Attrs());
return op;
}
};
template <typename T> template <typename T>
class NotImpleKernel : public framework::OpKernel<T> { class NotImpleKernel : public framework::OpKernel<T> {
public: public:
@ -211,7 +236,7 @@ class NotImpleKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker, REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::CudnnLSTMGradOpDescMaker);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>); REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save