Replace TemporaryAllocator by CUDADeviceContextAllocator (#18989)

TemporaryAllocator is a singleton used for allocating memory for Cudnn. Since it is a singleton, we can delete it for better performance in memory.

We replace TemporaryAllocator by CUDADeviceContextAllocator and CUDADeviceContextAllocation, which uses stream callback to delete the memory allocated for the stream to avoid singleton.

Also added data_feed_proto to operator to fix CI in CPU compilation
expand_as_op_1
Huihuang Zheng 6 years ago committed by GitHub
parent 0daa5c9772
commit 12542320c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -389,7 +389,6 @@ function(cc_test_run TARGET_NAME)
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
@ -472,7 +471,6 @@ function(nv_test TARGET_NAME)
add_test(${TARGET_NAME} ${TARGET_NAME})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
endif()
endfunction(nv_test)
@ -725,7 +723,7 @@ function(py_test TARGET_NAME)
if(WITH_COVERAGE)
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
FLAGS_cpu_deterministic=true FLAGS_limit_of_tmp_allocation=4294967296 # 4G
FLAGS_cpu_deterministic=true
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS}
@ -733,7 +731,7 @@ function(py_test TARGET_NAME)
else()
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
FLAGS_cpu_deterministic=true FLAGS_limit_of_tmp_allocation=4294967296 # 4G
FLAGS_cpu_deterministic=true
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})

@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope
glog shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)

@ -18,6 +18,7 @@
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
@ -103,16 +104,15 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto *dev_ctx = nccl_ctxs->DevCtx(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
auto tmp_ious_data = allocator.Allocate(buf_size);
auto tmp_ious_data = memory::Alloc(*dev_ctx, buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel

@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
@ -360,9 +361,7 @@ class ExecutionContext {
template <typename T, typename DevContext>
Tensor AllocateTmpTensor(const framework::DDim& dim,
const DevContext& dev_ctx) const {
auto tmp_allocation_ptr = platform::DeviceTemporaryAllocator::Instance()
.Get<DevContext>(dev_ctx)
.Allocate(product(dim) * sizeof(T));
auto tmp_allocation_ptr = memory::Alloc(dev_ctx, product(dim) * sizeof(T));
auto& deleter = tmp_allocation_ptr.get_deleter();
auto* allocation_ptr = tmp_allocation_ptr.release();
auto shared_allocation = std::shared_ptr<memory::allocation::Allocation>(

@ -19,7 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/temporary_allocator.h"
namespace paddle {
namespace framework {

@ -1,12 +1,29 @@
add_subdirectory(detail)
add_subdirectory(allocation)
cc_library(malloc SRCS malloc.cc DEPS place enforce allocator_facade profiler)
if (WITH_MKLDNN)
set(MKLDNN_CTX_DEPS mkldnn)
else ()
set(MKLDNN_CTX_DEPS)
endif()
cc_library(malloc SRCS malloc.cc DEPS
place enforce allocator_facade profiler ${MKLDNN_CTX_DEPS})
cc_library(memcpy SRCS memcpy.cc DEPS place)
cc_library(memory
DEPS
malloc
memcpy)
if (WITH_GPU)
add_dependencies(malloc cuda_device_context_allocator_pool)
target_link_libraries(malloc cuda_device_context_allocator_pool)
nv_test(malloc_test
SRCS malloc_test.cu
DEPS device_context malloc)
endif()
#if (WITH_GPU)
# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory)
#endif()

@ -6,8 +6,20 @@ cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(naive_best_fit_allocator SRCS naive_best_fit_allocator.cc DEPS allocator buddy_allocator profiler)
cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS locked_allocator buffered_allocator cpu_allocator best_fit_allocator)
if (WITH_MKLDNN)
set(MKLDNN_CTX_DEPS mkldnn)
else ()
set(MKLDNN_CTX_DEPS)
endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(cuda_device_context_allocation SRCS cuda_device_context_allocation.cc
DEPS allocator enforce place ${MKLDNN_CTX_DEPS})
nv_library(cuda_device_context_allocator SRCS cuda_device_context_allocator.cc
DEPS allocator enforce place cuda_device_context_allocation ${MKLDNN_CTX_DEPS})
nv_library(cuda_device_context_allocator_pool SRCS cuda_device_context_allocator_pool.cc
DEPS allocator enforce place cuda_device_context_allocation cuda_device_context_allocator ${MKLDNN_CTX_DEPS})
endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)

@ -0,0 +1,47 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
#include <utility>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
CUDADeviceContextAllocation::CUDADeviceContextAllocation(
AllocationPtr allocation)
: Allocation(allocation->ptr(), allocation->size(), allocation->place()),
underlying_allocation_(std::move(allocation)) {}
CUDADeviceContextAllocation::~CUDADeviceContextAllocation() {
PADDLE_ENFORCE_NOT_NULL(
dev_ctx_, "Didn't set device context for CUDADeviceContextAllocation");
auto *p_allocation = underlying_allocation_.release();
VLOG(4) << "Adding callback to delete CUDADeviceContextAllocation at "
<< p_allocation;
dev_ctx_->AddStreamCallback([p_allocation] {
VLOG(4) << "Delete CUDADeviceContextAllocation at " << p_allocation;
AllocationDeleter()(p_allocation);
});
}
void CUDADeviceContextAllocation::SetCUDADeviceContext(
const platform::CUDADeviceContext *dev_ctx) {
dev_ctx_ = dev_ctx;
}
} // namespace allocation
} // namespace memory
} // namespace paddle

@ -0,0 +1,36 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDADeviceContextAllocation : public Allocation {
public:
explicit CUDADeviceContextAllocation(AllocationPtr allocation);
~CUDADeviceContextAllocation();
void SetCUDADeviceContext(const platform::CUDADeviceContext *dev_ctx);
private:
AllocationPtr underlying_allocation_;
const platform::CUDADeviceContext *dev_ctx_{nullptr};
};
} // namespace allocation
} // namespace memory
} // namespace paddle

@ -0,0 +1,66 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
CUDADeviceContextAllocator::CUDADeviceContextAllocator(
const platform::CUDAPlace place, cudaStream_t default_stream)
: place_(place), default_stream_(default_stream) {
platform::CUDADeviceGuard guard(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventCreate(&event_, cudaEventDisableTiming),
"Create event failed in CUDADeviceContextAllocator");
}
CUDADeviceContextAllocator::~CUDADeviceContextAllocator() {
if (event_) {
platform::CUDADeviceGuard guard(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventDestroy(event_),
"Destory event failed in CUDADeviceContextAllocator destroctor");
}
}
Allocation *CUDADeviceContextAllocator::AllocateImpl(size_t size) {
PADDLE_ENFORCE_NOT_NULL(
default_stream_,
"Didn't set default stream for CUDADeviceContextAllocator");
platform::CUDADeviceGuard guard(place_.device);
auto allocation =
new CUDADeviceContextAllocation(memory::Alloc(place_, size));
// Wait for the event on stream
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(event_, default_stream_),
"Failed to record event in CUDADeviceContextAllocator");
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(default_stream_, event_, 0),
"Failed to wait event in CUDADeviceContextAllocator");
return allocation;
}
void CUDADeviceContextAllocator::FreeImpl(Allocation *allocation) {
delete allocation;
}
} // namespace allocation
} // namespace memory
} // namespace paddle

@ -0,0 +1,45 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDADeviceContextAllocator : public Allocator {
public:
explicit CUDADeviceContextAllocator(platform::CUDAPlace place,
cudaStream_t default_stream);
~CUDADeviceContextAllocator();
protected:
Allocation *AllocateImpl(size_t size) override;
void FreeImpl(Allocation *allocation) override;
private:
platform::CUDAPlace place_;
cudaEvent_t event_{nullptr};
cudaStream_t default_stream_{nullptr};
};
} // namespace allocation
} // namespace memory
} // namespace paddle

@ -0,0 +1,59 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
CUDADeviceContextAllocatorPool &CUDADeviceContextAllocatorPool::Instance() {
static CUDADeviceContextAllocatorPool pool;
return pool;
}
AllocationPtr CUDADeviceContextAllocatorPool::Alloc(
const platform::CUDADeviceContext &dev_ctx, size_t size) {
auto iter =
allocators_.find(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()));
PADDLE_ENFORCE_EQ(iter != allocators_.end(), true,
"CUDADeviceContextAllocatorPool initialization error");
auto &allocator = iter->second;
AllocationPtr allocation = allocator->Allocate(size);
static_cast<CUDADeviceContextAllocation *>(allocation.get())
->SetCUDADeviceContext(&dev_ctx);
return allocation;
}
CUDADeviceContextAllocatorPool::CUDADeviceContextAllocatorPool() {
std::vector<int> devices = platform::GetSelectedDevices();
for (int i : devices) {
auto place = platform::CUDAPlace(i);
auto compute_stream =
platform::DeviceContextPool::Instance().GetByPlace(place)->stream();
auto allocator = std::shared_ptr<CUDADeviceContextAllocator>(
new CUDADeviceContextAllocator(place, compute_stream));
allocators_.insert(make_pair(place, allocator));
}
}
} // namespace allocation
} // namespace memory
} // namespace paddle

@ -0,0 +1,41 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDADeviceContextAllocatorPool {
public:
static CUDADeviceContextAllocatorPool &Instance();
AllocationPtr Alloc(const platform::CUDADeviceContext &dev_ctx, size_t size);
private:
CUDADeviceContextAllocatorPool();
std::map<platform::CUDAPlace, std::shared_ptr<CUDADeviceContextAllocator>>
allocators_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle

@ -17,17 +17,44 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
#endif
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
std::shared_ptr<Allocation> AllocShared(const platform::Place &place,
size_t size) {
return allocation::AllocatorFacade::Instance().AllocShared(place, size);
}
AllocationPtr Alloc(const platform::Place& place, size_t size) {
AllocationPtr Alloc(const platform::Place &place, size_t size) {
return allocation::AllocatorFacade::Instance().Alloc(place, size);
}
AllocationPtr Alloc(const platform::DeviceContext &dev_ctx, size_t size) {
auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
if (size == 0 || !platform::is_gpu_place(place)) {
return Alloc(place, size);
}
auto *default_dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
auto &desired_dev_ctx =
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
return Alloc(place, size);
} else {
return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
desired_dev_ctx, size);
}
#else
return Alloc(place, size);
#endif
}
} // namespace memory
} // namespace paddle

@ -18,7 +18,13 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // platform
namespace memory {
using allocation::Allocation;
using allocation::Allocator;
using allocation::AllocationPtr;
@ -28,5 +34,7 @@ extern std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
extern AllocationPtr Alloc(const platform::Place& place, size_t size);
extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size);
} // namespace memory
} // namespace paddle

@ -0,0 +1,137 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_runtime.h>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace memory {
const int NUM_STREAMS = 8;
const int N = 2;
const float DELTA = 1e-1;
using CudaDevCtxVec = std::vector<std::unique_ptr<platform::CUDADeviceContext>>;
__global__ void kernel(float *x, int n) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < n; i += blockDim.x * gridDim.x) {
x[i] = 3.14159 * i;
}
}
void CheckKernelOutput(float *x, int n) {
auto host_x = std::unique_ptr<float[]>(new float[n]);
for (int i = 0; i < n; ++i) {
EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float),
cudaMemcpyDeviceToHost));
EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);
}
}
void MultiStreamCompute(float **data, float **second_data,
const platform::CUDADeviceContext &ctx) {
// multi-streams
AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*data = reinterpret_cast<float *>(allocation_ptr->ptr());
kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
// allocate and compute on same stream again
allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
}
TEST(Malloc, CUDADeviceContextMultiStream) {
auto place = platform::CUDAPlace(0);
EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
float *main_stream_data =
reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());
float *data[NUM_STREAMS];
float *second_data[NUM_STREAMS];
CudaDevCtxVec dev_ctx;
// default stream
kernel<<<1, 64>>>(main_stream_data, N);
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
new platform::CUDADeviceContext(place)));
MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
}
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
}
}
TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
auto place = platform::CUDAPlace(0);
EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
float *main_stream_data =
reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());
float *data[NUM_STREAMS];
float *second_data[NUM_STREAMS];
CudaDevCtxVec dev_ctx;
std::vector<std::thread> threads;
// default stream
kernel<<<1, 64>>>(main_stream_data, N);
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
new platform::CUDADeviceContext(place)));
threads.push_back(std::thread(MultiStreamCompute, &data[i], &second_data[i],
std::cref(*dev_ctx[i])));
}
for (int i = 0; i < NUM_STREAMS; ++i) {
threads[i].join();
}
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
}
}
TEST(Malloc, AllocZero) {
auto place = platform::CUDAPlace(0);
AllocationPtr allocation_ptr = Alloc(place, 0);
EXPECT_GE(allocation_ptr->size(), 0);
}
} // namespace memory
} // namespace paddle

@ -28,6 +28,7 @@
#include <limits>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/deformable_psroi_pooling_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
@ -231,10 +232,8 @@ class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
@ -499,10 +498,8 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> {
}
}
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,

@ -11,7 +11,7 @@ limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
@ -174,10 +174,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(device_ctx);
int bytes = var_size * sizeof(float);
auto dev_var = allocator.Allocate(bytes);
auto dev_var = memory::Alloc(device_ctx, bytes);
float* dev_var_data = reinterpret_cast<float*>(dev_var->ptr());
auto cplace = platform::CPUPlace();
const auto gplace = boost::get<platform::CUDAPlace>(context.GetPlace());

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/operators/math/math_function.h"
@ -84,10 +85,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
int input_size = downsample_ratio * h;
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = sizeof(int) * anchors.size();
auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
auto anchors_ptr = memory::Alloc(dev_ctx, sizeof(int) * anchors.size());
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
const auto cplace = platform::CPUPlace();

@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
namespace paddle {
@ -112,9 +113,7 @@ class DGCOpKernel : public framework::OpKernel<T> {
framework::DDim{2 * k}, ctx.GetPlace());
int buf_size = paddle::communication::dgc::get_buffer_size(k);
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(
ctx.GetPlace(), dev_ctx.stream());
auto tmp_ious_data = allocator.Allocate(buf_size);
auto tmp_ious_data = memory::Alloc(dev_ctx, buf_size);
void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());
if (!paddle::communication::dgc::k_select(

@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
@ -184,9 +185,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cur_scale = allocator.Allocate(1 * sizeof(T));
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
@ -251,9 +250,7 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cur_scale = allocator.Allocate(1 * sizeof(T));
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),

@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
@ -142,9 +143,8 @@ void GPUGatherNd(const framework::ExecutionContext& context,
}
auto& dev_ctx = context.cuda_device_context();
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = input_dims_size * sizeof(int);
auto p_input_dims = allocator.Allocate(bytes);
auto p_input_dims = memory::Alloc(dev_ctx, bytes);
int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
ctx.stream());

@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
@ -264,8 +265,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
const T** dev_ins_data = nullptr;
if (!has_same_shape || in_num < 2 || in_num > 4) {
tmp_dev_ins_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_data.size() * sizeof(T*));
memory::Alloc(context, inputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_data.data()),
@ -292,8 +292,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
}
} else {
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_col.size() * sizeof(int));
memory::Alloc(context, inputs_col.size() * sizeof(int));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col.data()),
@ -356,8 +355,7 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
T** dev_out_gpu_data = nullptr;
if (!has_same_shape || o_num < 2 || o_num > 4) {
tmp_dev_outs_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_data.size() * sizeof(T*));
memory::Alloc(context, outputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_outs_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_data.data()),
@ -384,8 +382,9 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
}
} else {
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_cols.size() * sizeof(int));
memory::Alloc(context,
outputs_cols.size() * sizeof(int));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_cols.data()),

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/mean_iou_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
@ -116,9 +117,7 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
// Temporary memory
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto tmp_ious_data = allocator.Allocate(num_classes * sizeof(float));
auto tmp_ious_data = memory::Alloc(dev_ctx, num_classes * sizeof(float));
float* ious_data = static_cast<float*>(tmp_ious_data->ptr());
// Init out_wrong, out_correct and out_mean_iou

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
@ -272,10 +272,8 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
@ -322,9 +320,8 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto roi_ptr = allocator.Allocate(roi_batch_id_list.numel() * sizeof(int));
auto roi_ptr =
memory::Alloc(dev_ctx, roi_batch_id_list.numel() * sizeof(int));
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
int bytes = roi_batch_id_list.numel() * sizeof(int);
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save