Merge pull request #13443 from sneaxiy/feature/eager_delete_tensor

Fix eager deletion bug in executor
upload-readme
Zeng Jinle 7 years ago committed by GitHub
commit 1e44201cb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -337,6 +337,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
@ -357,11 +358,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) {
for (auto& input_name : input.second) {
auto it = ctx->ref_cnts_.find(input_name);
if (it == ctx->ref_cnts_.end()) continue;
auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->ref_cnts_.erase(input_name);
ctx->cur_ref_cnts_.erase(input_name);
} else {
--(it->second);
}
@ -370,11 +371,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->ref_cnts_.find(output_name);
if (it == ctx->ref_cnts_.end()) continue;
auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->ref_cnts_.erase(output_name);
ctx->cur_ref_cnts_.erase(output_name);
} else {
--(it->second);
}

@ -72,11 +72,14 @@ struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_;
std::unordered_map<std::string, int> cur_ref_cnts_;
};
class Executor {

Loading…
Cancel
Save