You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/garbage_collector.cc

132 lines
4.6 KiB

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
// Size threshold in GB for eager tensor deletion. Read by
// GetEagerDeletionThreshold(), which reports -1 when this value is negative,
// and overwritten by SetEagerDeletionMode().
//
// The help text is built from adjacent string literals, so the first literal
// must end with a space to keep the two sentences separated.
DEFINE_double(
    eager_delete_tensor_gb, -1.0,
    "Memory size threshold (GB) when the garbage collector clears tensors. "
    "Disabled when this value is less than 0");
// Boolean switch for fast eager deletion (default: true). Read back through
// IsFastEagerDeletionModeEnabled() and overwritten by SetEagerDeletionMode().
DEFINE_bool(
    fast_eager_deletion_mode, true,
    "Fast eager deletion mode. "
    "If enabled, memory would release immediately "
    "without waiting GPU kernel ends.");
// Fraction used by eager deletion (default: 1.0). Read back through
// GetEagerDeletionMemoryFraction() and overwritten by SetEagerDeletionMode().
DEFINE_double(
    memory_fraction_of_eager_deletion, 1.0,
    "Fraction of eager deletion. "
    "If less than 1.0, all variables in the program would be sorted "
    "according to its memory size, and only the "
    "FLAGS_memory_fraction_of_eager_deletion of the largest variables "
    "would be deleted.");
// Sets up the state shared by every collector flavour.
//
// place:           used to look up the device context in
//                  platform::DeviceContextPool::Instance().
// max_memory_size: stored in max_memory_size_; a value of 0 is raised to 1 so
//                  the stored limit is never zero.
GarbageCollector::GarbageCollector(const platform::Place &place,
                                   size_t max_memory_size)
    : max_memory_size_(max_memory_size > 0 ? max_memory_size
                                           : static_cast<size_t>(1)) {
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
  // Start with an empty garbage queue.
  garbages_.reset(new GarbageQueue());
}
// CPU collector: performs no setup of its own and simply forwards `place` and
// `max_memory_size` to the GarbageCollector base constructor.
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
                                         size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {
}
// Runs `callback` directly, before returning to the caller.
void CPUGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback();
}
#ifdef PADDLE_WITH_CUDA
// Forwards `place` and `max_memory_size` to the GarbageCollector base
// constructor; no additional state is initialised here.
UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector(
    const platform::CUDAPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {
}
// Runs `callback` directly; this function performs no stream or device
// synchronisation before doing so.
void UnsafeFastGPUGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback();
}
// Forwards `place` and `max_memory_size` to the GarbageCollector base
// constructor; the device context it relies on is obtained there.
DefaultStreamGarbageCollector::DefaultStreamGarbageCollector(
    const platform::CUDAPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {
}
// Delegates to WaitStreamCallback() on the device context, which is treated
// as a CUDADeviceContext.
void DefaultStreamGarbageCollector::Wait() const {
  auto *cuda_ctx = static_cast<platform::CUDADeviceContext *>(this->dev_ctx_);
  cuda_ctx->WaitStreamCallback();
}
// Hands `callback` to AddStreamCallback() on the device context (treated as a
// CUDADeviceContext) instead of invoking it here.
void DefaultStreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  auto *cuda_ctx = static_cast<platform::CUDADeviceContext *>(this->dev_ctx_);
  cuda_ctx->AddStreamCallback(callback);
}
// Creates a collector with its own CUDA stream.
//
// After the base-class construction, a new stream is created into stream_ and
// a StreamCallbackManager bound to that stream is installed in
// callback_manager_. The stream is synchronised and destroyed in the
// destructor.
StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {
// Guard constructed from place.device and scoped to this constructor;
// presumably makes that device current while the stream is created —
// TODO confirm against CUDADeviceGuard.
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
// The callback manager is handed the stream that was just created.
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
// Synchronises stream_ and then destroys it.
//
// NOTE(review): if PADDLE_ENFORCE reports failure by throwing, that would
// happen inside a destructor here — confirm this is intended.
StreamGarbageCollector::~StreamGarbageCollector() {
  // The guard is built from the device of the place recorded in the device
  // context, and lives until the end of this destructor body.
  platform::CUDADeviceGuard guard(
      boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace()).device);
  // Synchronise first, then release the stream.
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
// Accessor for the CUDA stream created in the constructor.
cudaStream_t StreamGarbageCollector::stream() const {
  return stream_;
}
// Delegates to Wait() on the stream callback manager.
void StreamGarbageCollector::Wait() const {
  callback_manager_->Wait();
}
// Registers `callback` with the stream callback manager instead of invoking
// it here.
void StreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback_manager_->AddCallback(callback);
}
#endif
// Converts FLAGS_eager_delete_tensor_gb into a byte count.
//
// Returns -1 when the flag is negative; otherwise the flag value multiplied
// by 2^30 and truncated to int64_t.
int64_t GetEagerDeletionThreshold() {
  if (FLAGS_eager_delete_tensor_gb < 0) {
    return -1;
  }
  const int64_t bytes_per_gb = static_cast<int64_t>(1) << 30;
  return static_cast<int64_t>(FLAGS_eager_delete_tensor_gb * bytes_per_gb);
}
// Returns the current value of FLAGS_fast_eager_deletion_mode.
bool IsFastEagerDeletionModeEnabled() {
  return FLAGS_fast_eager_deletion_mode;
}
// Overwrites the three eager-deletion flags in one call.
//
// threshold: new value for FLAGS_eager_delete_tensor_gb.
// fraction:  new value for FLAGS_memory_fraction_of_eager_deletion.
// fast_mode: new value for FLAGS_fast_eager_deletion_mode.
//
// The arguments are stored as given; no validation is performed here.
void SetEagerDeletionMode(double threshold, double fraction, bool fast_mode) {
  FLAGS_fast_eager_deletion_mode = fast_mode;
  FLAGS_memory_fraction_of_eager_deletion = fraction;
  FLAGS_eager_delete_tensor_gb = threshold;
}
// Returns the current value of FLAGS_memory_fraction_of_eager_deletion.
double GetEagerDeletionMemoryFraction() {
  const double fraction = FLAGS_memory_fraction_of_eager_deletion;
  return fraction;
}
} // namespace framework
} // namespace paddle