|
|
|
@ -17,7 +17,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
ThreadLocalD<std::vector<MemoryHandle *>> ConvBaseProjection::convMem_;
|
|
|
|
|
ThreadLocalD<std::vector<MemoryHandlePtr>> ConvBaseProjection::convMem_;
|
|
|
|
|
|
|
|
|
|
ConvBaseProjection::ConvBaseProjection(const ProjectionConfig &config,
|
|
|
|
|
ParameterPtr parameter,
|
|
|
|
@ -175,18 +175,18 @@ void ConvBaseProjection::reshape(int batchSize) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void *ConvBaseProjection::getSpaceBytes(size_t size) {
|
|
|
|
|
std::vector<MemoryHandle *> &convMem = *convMem_;
|
|
|
|
|
std::vector<MemoryHandlePtr> &convMem = *convMem_;
|
|
|
|
|
if (convMem.empty()) {
|
|
|
|
|
int numDevices = hl_get_device_count();
|
|
|
|
|
convMem.resize(numDevices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int devId = hl_get_device();
|
|
|
|
|
MemoryHandle **localMem = &(convMem[devId]);
|
|
|
|
|
if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
|
|
|
|
|
*localMem = new GpuMemoryHandle(size);
|
|
|
|
|
MemoryHandlePtr localMem = convMem[devId];
|
|
|
|
|
if (NULL == localMem || size > localMem->getAllocSize()) {
|
|
|
|
|
localMem = std::make_shared<GpuMemoryHandle>(size);
|
|
|
|
|
}
|
|
|
|
|
return (*localMem)->getBuf();
|
|
|
|
|
return localMem->getBuf();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ConvBaseProjection::~ConvBaseProjection() {
|
|
|
|
|