Polish code

panyx0718-patch-1
Yu Yang 6 years ago
parent e25240c22a
commit 29f66c2408

@ -167,7 +167,7 @@ class CudnnHolder {
if (required_workspace_len > WorkspaceSize()) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_->ptr());
cudnn_func(WorkspacePtr());
}
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
@ -181,6 +181,14 @@ class CudnnHolder {
}
}
void* WorkspacePtr() const {
if (workspace_ == nullptr) {
return nullptr;
} else {
return workspace_->ptr();
}
}
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= WorkspaceSize()) {
return;

@ -99,7 +99,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
py_buffer->shape = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * tensor.dims().size()));
for (size_t i = 0; i < tensor.dims().size(); ++i) {
for (int i = 0; i < tensor.dims().size(); ++i) {
py_buffer->shape[i] = tensor.dims()[i];
}

Loading…
Cancel
Save