Replace TemporaryAllocator by CUDADeviceContextAllocator (#18989)
TemporaryAllocator is a singleton used for allocating memory for cuDNN. Since it is a singleton, we can remove it to improve memory performance. We replace TemporaryAllocator with CUDADeviceContextAllocator and CUDADeviceContextAllocation, which use a stream callback to free the memory allocated for a stream, avoiding the singleton. Also added data_feed_proto to operator to fix CI in CPU compilation.
expand_as_op_1
parent
0daa5c9772
commit
12542320c5
@ -1,12 +1,29 @@
|
||||
# Sub-directories built as part of the memory module.
add_subdirectory(detail)
|
||||
add_subdirectory(allocation)
|
||||
# NOTE(review): `malloc` is declared here and again below with
# ${MKLDNN_CTX_DEPS} appended. This listing looks like a diff view that shows
# both the old and the new declaration - confirm against the real file.
cc_library(malloc SRCS malloc.cc DEPS place enforce allocator_facade profiler)
|
||||
|
||||
# MKLDNN_CTX_DEPS is `mkldnn` when WITH_MKLDNN is set and empty otherwise.
if (WITH_MKLDNN)
|
||||
set(MKLDNN_CTX_DEPS mkldnn)
|
||||
else ()
|
||||
set(MKLDNN_CTX_DEPS)
|
||||
endif()
|
||||
|
||||
# malloc library; MKLDNN_CTX_DEPS is appended to its dependency list.
cc_library(malloc SRCS malloc.cc DEPS
|
||||
place enforce allocator_facade profiler ${MKLDNN_CTX_DEPS})
|
||||
cc_library(memcpy SRCS memcpy.cc DEPS place)
|
||||
|
||||
# `memory` has no sources of its own here; it only depends on malloc and memcpy.
cc_library(memory
|
||||
DEPS
|
||||
malloc
|
||||
memcpy)
|
||||
|
||||
# GPU builds only: malloc additionally depends on and links against
# cuda_device_context_allocator_pool, and the CUDA malloc test is registered.
if (WITH_GPU)
|
||||
add_dependencies(malloc cuda_device_context_allocator_pool)
|
||||
target_link_libraries(malloc cuda_device_context_allocator_pool)
|
||||
nv_test(malloc_test
|
||||
SRCS malloc_test.cu
|
||||
DEPS device_context malloc)
|
||||
endif()
|
||||
|
||||
# The pinned-memory test below is commented out and therefore not built.
#if (WITH_GPU)
|
||||
# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory)
|
||||
#endif()
|
||||
|
@ -0,0 +1,47 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
|
||||
#include <utility>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
namespace allocation {
|
||||
|
||||
// Wraps `allocation`: the base Allocation is given the wrapped allocation's
// ptr(), size() and place(), and ownership of the wrapped allocation is moved
// into underlying_allocation_.
CUDADeviceContextAllocation::CUDADeviceContextAllocation(
|
||||
AllocationPtr allocation)
|
||||
// The base class is initialised first, so `allocation` is still valid when
// its accessors are read; it is only moved from in the member initialiser.
: Allocation(allocation->ptr(), allocation->size(), allocation->place()),
|
||||
underlying_allocation_(std::move(allocation)) {}
|
||||
|
||||
// Does not free the wrapped allocation directly. Ownership is detached from
// underlying_allocation_ and handed to a callback registered through
// dev_ctx_->AddStreamCallback, which runs AllocationDeleter on it.
//
// NOTE(review): the enforce below runs inside a destructor. If
// PADDLE_ENFORCE_NOT_NULL reports failure by throwing (presumably it does -
// confirm), a missing device context would terminate the program here.
CUDADeviceContextAllocation::~CUDADeviceContextAllocation() {
  PADDLE_ENFORCE_NOT_NULL(
      dev_ctx_, "Didn't set device context for CUDADeviceContextAllocation");

  // After release() the smart-pointer member no longer owns the allocation;
  // the callback below is now the only thing that will delete it.
  auto *detached = underlying_allocation_.release();
  VLOG(4) << "Adding callback to delete CUDADeviceContextAllocation at "
          << detached;

  // The pointer is captured by value so the callback does not refer back to
  // this object, which is being destroyed.
  dev_ctx_->AddStreamCallback([detached] {
    VLOG(4) << "Delete CUDADeviceContextAllocation at " << detached;
    AllocationDeleter()(detached);
  });
}
|
||||
|
||||
// Stores the device context that the destructor later uses to register its
// stream callback. The pointer is kept as-is (not copied or owned), so the
// context must still be alive when this allocation is destroyed.
void CUDADeviceContextAllocation::SetCUDADeviceContext(
|
||||
const platform::CUDADeviceContext *dev_ctx) {
|
||||
dev_ctx_ = dev_ctx;
|
||||
}
|
||||
|
||||
} // namespace allocation
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/memory/allocation/allocator.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
namespace allocation {
|
||||
|
||||
class CUDADeviceContextAllocation : public Allocation {
|
||||
public:
|
||||
explicit CUDADeviceContextAllocation(AllocationPtr allocation);
|
||||
~CUDADeviceContextAllocation();
|
||||
void SetCUDADeviceContext(const platform::CUDADeviceContext *dev_ctx);
|
||||
|
||||
private:
|
||||
AllocationPtr underlying_allocation_;
|
||||
const platform::CUDADeviceContext *dev_ctx_{nullptr};
|
||||
};
|
||||
|
||||
} // namespace allocation
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
|
||||
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
|
||||
#include "paddle/fluid/platform/cuda_device_guard.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
namespace allocation {
|
||||
|
||||
// Stores the place and stream, then creates the CUDA event that AllocateImpl
// records on `default_stream`.
CUDADeviceContextAllocator::CUDADeviceContextAllocator(
|
||||
const platform::CUDAPlace place, cudaStream_t default_stream)
|
||||
: place_(place), default_stream_(default_stream) {
|
||||
// Guard scoped to place_.device for the event creation below (presumably it
// makes that device current and restores the previous one on exit - confirm).
platform::CUDADeviceGuard guard(place_.device);
|
||||
// The event is created with timing disabled.
PADDLE_ENFORCE_CUDA_SUCCESS(
|
||||
cudaEventCreate(&event_, cudaEventDisableTiming),
|
||||
"Create event failed in CUDADeviceContextAllocator");
|
||||
}
|
||||
|
||||
// Destroys the CUDA event created by the constructor. Nothing is done when
// event_ is still null.
//
// NOTE(review): the enforce below runs inside a destructor. If
// PADDLE_ENFORCE_CUDA_SUCCESS reports failure by throwing (presumably it
// does - confirm), a failed cudaEventDestroy would terminate the program.
CUDADeviceContextAllocator::~CUDADeviceContextAllocator() {
  if (event_) {
    // Guard scoped to place_.device, matching the one used when the event was
    // created.
    platform::CUDADeviceGuard guard(place_.device);
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaEventDestroy(event_),
        "Destroy event failed in CUDADeviceContextAllocator destructor");
  }
}
|
||||
|
||||
// Allocates `size` bytes on place_ through memory::Alloc and returns them
// wrapped in a CUDADeviceContextAllocation. The caller becomes the owner of
// the returned pointer; FreeImpl deletes it.
//
// Fails (via the enforce macros) when default_stream_ is null or when either
// CUDA event call does not succeed.
Allocation *CUDADeviceContextAllocator::AllocateImpl(size_t size) {
  PADDLE_ENFORCE_NOT_NULL(
      default_stream_,
      "Didn't set default stream for CUDADeviceContextAllocator");
  platform::CUDADeviceGuard guard(place_.device);

  // Record the event on the stream and make the stream wait on it *before*
  // allocating. If either enforce fails at this point there is no raw
  // allocation in flight yet, so nothing can leak. A smart pointer is
  // deliberately not used instead: destroying a CUDADeviceContextAllocation
  // whose device context has not been set yet trips an enforce in its
  // destructor.
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaEventRecord(event_, default_stream_),
      "Failed to record event in CUDADeviceContextAllocator");
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaStreamWaitEvent(default_stream_, event_, 0),
      "Failed to wait event in CUDADeviceContextAllocator");

  return new CUDADeviceContextAllocation(memory::Alloc(place_, size));
}
|
||||
|
||||
// Releases an allocation produced by AllocateImpl by deleting it. For a
// CUDADeviceContextAllocation the actual release of the wrapped memory is
// performed by that class's destructor.
void CUDADeviceContextAllocator::FreeImpl(Allocation *allocation) {
|
||||
delete allocation;
|
||||
}
|
||||
|
||||
} // namespace allocation
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "paddle/fluid/memory/allocation/allocator.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
namespace allocation {
|
||||
|
||||
class CUDADeviceContextAllocator : public Allocator {
|
||||
public:
|
||||
explicit CUDADeviceContextAllocator(platform::CUDAPlace place,
|
||||
cudaStream_t default_stream);
|
||||
~CUDADeviceContextAllocator();
|
||||
|
||||
protected:
|
||||
Allocation *AllocateImpl(size_t size) override;
|
||||
void FreeImpl(Allocation *allocation) override;
|
||||
|
||||
private:
|
||||
platform::CUDAPlace place_;
|
||||
cudaEvent_t event_{nullptr};
|
||||
cudaStream_t default_stream_{nullptr};
|
||||
};
|
||||
|
||||
} // namespace allocation
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
namespace allocation {
|
||||
|
||||
// Returns the process-wide pool. It is a function-local static, so it is
// constructed the first time this function is called.
CUDADeviceContextAllocatorPool &CUDADeviceContextAllocatorPool::Instance() {
|
||||
static CUDADeviceContextAllocatorPool pool;
|
||||
return pool;
|
||||
}
|
||||
|
||||
// Allocates `size` bytes with the allocator registered for the CUDA place of
// `dev_ctx` and attaches `dev_ctx` to the resulting allocation.
//
// Fails (via PADDLE_ENFORCE_EQ) when no allocator is registered for that
// place. `dev_ctx` is stored by address, so it has to outlive the returned
// allocation.
AllocationPtr CUDADeviceContextAllocatorPool::Alloc(
    const platform::CUDADeviceContext &dev_ctx, size_t size) {
  const auto cuda_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
  auto found = allocators_.find(cuda_place);
  PADDLE_ENFORCE_EQ(found != allocators_.end(), true,
                    "CUDADeviceContextAllocatorPool initialization error");

  AllocationPtr result = found->second->Allocate(size);

  // Every allocator in this pool is a CUDADeviceContextAllocator, whose
  // AllocateImpl creates CUDADeviceContextAllocation objects, hence the
  // unchecked downcast.
  auto *ctx_allocation =
      static_cast<CUDADeviceContextAllocation *>(result.get());
  ctx_allocation->SetCUDADeviceContext(&dev_ctx);
  return result;
}
|
||||
|
||||
// Registers one CUDADeviceContextAllocator for every device id returned by
// platform::GetSelectedDevices(). Each allocator is bound to the stream of
// the device context that DeviceContextPool holds for that place.
CUDADeviceContextAllocatorPool::CUDADeviceContextAllocatorPool() {
  for (int device_id : platform::GetSelectedDevices()) {
    platform::CUDAPlace place(device_id);
    auto stream =
        platform::DeviceContextPool::Instance().GetByPlace(place)->stream();
    std::shared_ptr<CUDADeviceContextAllocator> allocator(
        new CUDADeviceContextAllocator(place, stream));
    // Like insert(), emplace() leaves an already-present entry untouched.
    allocators_.emplace(place, allocator);
  }
}
|
||||
|
||||
} // namespace allocation
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "paddle/fluid/memory/allocation/allocator.h"
|
||||
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
namespace allocation {
|
||||
|
||||
// Singleton that maps each CUDA place to a CUDADeviceContextAllocator and
// hands out allocations tied to a given CUDA device context.
class CUDADeviceContextAllocatorPool {
|
||||
public:
|
||||
// Accessor for the single pool instance.
static CUDADeviceContextAllocatorPool &Instance();
|
||||
|
||||
// Allocates `size` bytes using the allocator registered for the place of
// `dev_ctx`.
AllocationPtr Alloc(const platform::CUDADeviceContext &dev_ctx, size_t size);
|
||||
|
||||
private:
|
||||
// Private: instances are only created through Instance().
CUDADeviceContextAllocatorPool();
|
||||
// One allocator per CUDA place, filled in by the constructor.
std::map<platform::CUDAPlace, std::shared_ptr<CUDADeviceContextAllocator>>
|
||||
allocators_;
|
||||
};
|
||||
|
||||
} // namespace allocation
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
@ -0,0 +1,137 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/memory/malloc.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace memory {
|
||||
|
||||
// Number of CUDADeviceContext objects each multi-stream test creates.
const int NUM_STREAMS = 8;
|
||||
// Number of floats in every test buffer.
const int N = 2;
|
||||
// Absolute tolerance used by CheckKernelOutput when comparing results.
const float DELTA = 1e-1;
|
||||
|
||||
// Owning container for the device contexts created by the tests.
using CudaDevCtxVec = std::vector<std::unique_ptr<platform::CUDADeviceContext>>;
|
||||
|
||||
// Fills x[0..n) with 3.14159 * index. Each thread starts at its global index
// and advances by the total number of launched threads (grid-stride loop), so
// any launch configuration covers the whole range.
__global__ void kernel(float *x, int n) {
  const unsigned int stride = blockDim.x * gridDim.x;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  while (idx < n) {
    x[idx] = 3.14159 * idx;
    idx += stride;
  }
}
|
||||
|
||||
void CheckKernelOutput(float *x, int n) {
|
||||
auto host_x = std::unique_ptr<float[]>(new float[n]);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float),
|
||||
cudaMemcpyDeviceToHost));
|
||||
EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
|
||||
EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);
|
||||
}
|
||||
}
|
||||
|
||||
void MultiStreamCompute(float **data, float **second_data,
|
||||
const platform::CUDADeviceContext &ctx) {
|
||||
// multi-streams
|
||||
AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
|
||||
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
|
||||
*data = reinterpret_cast<float *>(allocation_ptr->ptr());
|
||||
kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
|
||||
|
||||
// allocate and compute on same stream again
|
||||
allocation_ptr = Alloc(ctx, N * sizeof(float));
|
||||
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
|
||||
*second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
|
||||
kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
|
||||
}
|
||||
|
||||
// Runs MultiStreamCompute sequentially on NUM_STREAMS freshly created device
// contexts, after first launching a kernel on the default stream, and then
// checks every buffer once the device has been synchronised.
TEST(Malloc, CUDADeviceContextMultiStream) {
  auto place = platform::CUDAPlace(0);
  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));

  const size_t bytes = N * sizeof(float);
  AllocationPtr main_stream_alloc_ptr = Alloc(place, bytes);
  EXPECT_GE(main_stream_alloc_ptr->size(), bytes);
  float *main_stream_data =
      reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());

  float *data[NUM_STREAMS];
  float *second_data[NUM_STREAMS];
  CudaDevCtxVec dev_ctx;

  // default stream
  kernel<<<1, 64>>>(main_stream_data, N);
  // Give the buffer back right after the launch.
  main_stream_alloc_ptr.reset();

  for (int i = 0; i < NUM_STREAMS; ++i) {
    dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
        new platform::CUDADeviceContext(place)));
    MultiStreamCompute(data + i, second_data + i, *dev_ctx.back());
  }

  EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
  for (int i = 0; i < NUM_STREAMS; ++i) {
    CheckKernelOutput(data[i], N);
    CheckKernelOutput(second_data[i], N);
  }
}
|
||||
|
||||
// Same scenario as the single-threaded multi-stream test, except that each
// device context is driven by MultiStreamCompute on its own std::thread. All
// threads are joined and the device synchronised before results are checked.
TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
  auto place = platform::CUDAPlace(0);
  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));

  const size_t bytes = N * sizeof(float);
  AllocationPtr main_stream_alloc_ptr = Alloc(place, bytes);
  EXPECT_GE(main_stream_alloc_ptr->size(), bytes);
  float *main_stream_data =
      reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());

  float *data[NUM_STREAMS];
  float *second_data[NUM_STREAMS];
  CudaDevCtxVec dev_ctx;
  std::vector<std::thread> threads;

  // default stream
  kernel<<<1, 64>>>(main_stream_data, N);
  // Give the buffer back right after the launch.
  main_stream_alloc_ptr.reset();

  for (int i = 0; i < NUM_STREAMS; ++i) {
    dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
        new platform::CUDADeviceContext(place)));
    // std::cref passes the context to the thread by reference, not by copy.
    threads.emplace_back(MultiStreamCompute, data + i, second_data + i,
                         std::cref(*dev_ctx.back()));
  }

  for (auto &worker : threads) {
    worker.join();
  }

  EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
  for (int i = 0; i < NUM_STREAMS; ++i) {
    CheckKernelOutput(data[i], N);
    CheckKernelOutput(second_data[i], N);
  }
}
|
||||
|
||||
// Requests a zero-byte allocation on CUDA device 0 and reads its size.
// NOTE(review): if size() returns an unsigned type, the comparison against 0
// can never fail, so the test would effectively only check that Alloc
// returns a usable handle - confirm the intended assertion.
TEST(Malloc, AllocZero) {
  auto place = platform::CUDAPlace(0);
  AllocationPtr zero_sized = Alloc(place, 0);
  EXPECT_GE(zero_sized->size(), 0);
}
|
||||
} // namespace memory
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue