fix test error

shanyi15-patch-2
Kexin Zhao 7 years ago
parent e4de5dc347
commit a13ec3432a

@ -134,8 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
@ -321,7 +321,7 @@ namespace plat = paddle::platform;
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>,
paddle::operators::CUDNNConvOpKernel < plat::float16);
paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);

@ -85,13 +85,14 @@ template <>
class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
typedef const float16 ScalingParamType;
// The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = static_cast<float16>(1.0);
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = static_cast<float16>(0.0);
static ScalingParamType v = 0.0;
return &v;
}
};

@ -79,7 +79,7 @@ class TestConv2dOp(OpTest):
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv2d_forward_naive(self.input, self.filter, self.groups,
output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype(self.dtype)
# numpy float16 is binded to paddle::platform::float16
@ -88,9 +88,12 @@ class TestConv2dOp(OpTest):
# uint16_t in paddle or np.uint16 in numpy, which are
# themselves binded together.
self.inputs = {
'Input': input.view(np.uint16)
if self.dtype == np.float16 else input,
'Filter': create_view(filter)
#'Input': (input.view(np.uint16)
# if self.dtype == np.float16 else input),
#'Filter': (filter.view(np.uint16)
# if self.dtype == np.float16 else filter)
'Input': OpTest.create_view(input),
'Filter': OpTest.create_view(filter)
}
self.attrs = {
'strides': self.stride,
@ -254,7 +257,7 @@ class TestFP16CUDNN(TestCUDNN):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-1)
self.check_output_with_place(place, atol=2e-2)
def test_check_grad(self):
pass

Loading…
Cancel
Save