commit 966a6ce6db
@@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    // Scatter each input value to its recorded position, one
    // (batch, channel) feature map at a time.
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize,
                         "Invalid index in unpooling operation.");
          output_data[index] = input_data[i];
        }
        input_data += input_feasize;
        indices_data += input_feasize;
        output_data += output_feasize;
      }
    }
  }
};

template <typename T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const int* indices_data = indices.data<int>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // Gather: each input position receives the gradient of the output
    // position it was scattered to in the forward pass.
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize,
                         "Invalid index in unpooling operation.");
          input_grad_data[i] = output_grad_data[index];
        }
        input_grad_data += input_feasize;
        indices_data += input_feasize;
        output_grad_data += output_feasize;
      }
    }
  }
};

template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
}  // namespace math
}  // namespace operators
}  // namespace paddle
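For reference, the scatter that the CPU functor performs on each (batch, channel) feature map can be sketched in NumPy. This is an illustrative sketch, not part of the commit; unpool_channel is a hypothetical helper, and it assumes indices are flat offsets within a single output feature map, as above:

    import numpy as np

    def unpool_channel(pooled, indices, out_h, out_w):
        # Scatter each pooled value to its recorded flat offset in the
        # (out_h, out_w) map; positions never written stay zero.
        out = np.zeros(out_h * out_w, dtype=pooled.dtype)
        idx = indices.ravel().astype(int)
        assert (idx < out_h * out_w).all()  # mirrors the PADDLE_ENFORCE check
        out[idx] = pooled.ravel()
        return out.reshape(out_h, out_w)

The backward pass is the mirror image, a gather: input_grad = output_grad.reshape(-1)[idx].reshape(pooled.shape).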
@@ -0,0 +1,134 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
                                  const int* indices_data,
                                  const int input_height, const int input_width,
                                  const int channels, T* output_data,
                                  const int output_height,
                                  const int output_width) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  // Grid-stride loop over all input elements.
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;        // batch of element i
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;  // channel of element i
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}

template <typename T>
__global__ void KernelUnpool2dMaxGrad(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_height, const int input_width, const int channels,
    const T* output_data, const T* output_grad, const int output_height,
    const int output_width, T* input_grad) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_height, output_width);
  }
};

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data, output_height,
                              output_width, input_grad_data);
  }
};

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
}  // namespace math
}  // namespace operators
}  // namespace paddle
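The two kernels above share the same addressing scheme: a grid-stride loop over the flattened NCHW input, with the batch and channel of each element recovered by integer division so the within-channel index can be rebased onto the output tensor. A plain-Python rendering of that arithmetic (illustrative only, not part of the commit):

    def rebase_offset(i, channels, in_h, in_w, out_h, out_w):
        # Decompose flat NCHW offset i into (batch, channel), then compute
        # where that channel's output feature map starts.
        in_c_stride = in_h * in_w
        in_n_stride = channels * in_c_stride
        bidx = i // in_n_stride                   # batch index
        cidx = (i % in_n_stride) // in_c_stride   # channel index
        out_n_stride = channels * out_h * out_w
        return bidx * out_n_stride + cidx * out_h * out_w

Adding indices_data[i] to the returned base gives the absolute output offset, exactly as out_offset + out_index does in the kernels.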
@@ -0,0 +1,40 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
namespace math {

template <typename Place, typename T>
class Unpool2dMaxFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output);
};

template <typename Place, typename T>
class Unpool2dMaxGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad);
};
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/unpool_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(unpool,
                       ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
                       ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
    unpool_grad, ops::UnpoolGradKernel<paddle::platform::GPUPlace, float>,
    ops::UnpoolGradKernel<paddle::platform::GPUPlace, double>);
@@ -0,0 +1,71 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
    const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
    auto* out = context.Output<framework::Tensor>("Out");
    std::string unpooling_type = context.Attr<std::string>("unpooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    T* output_data = out->mutable_data<T>(context.GetPlace());
    if (output_data) {
      // Positions that receive no scattered value must stay zero.
      math::SetConstant<Place, T> set_zero;
      set_zero(context.device_context(), out, static_cast<T>(0));
    }
    math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
    unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
  }
};

template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
    const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
    const framework::Tensor* out = context.Input<framework::Tensor>("Out");
    const framework::Tensor* out_grad =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor* in_x_grad =
        context.Output<framework::Tensor>(framework::GradVarName("X"));
    std::string unpooling_type = context.Attr<std::string>("unpooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

    auto& device_ctx = context.device_context();
    math::SetConstant<Place, T> zero;
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      zero(device_ctx, in_x_grad, static_cast<T>(0));
    }
    math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
    unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out,
                          *out_grad, in_x_grad);
  }
};
}  // namespace operators
}  // namespace paddle
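Both kernels assume Out already carries the unpooled shape; for max unpooling that is the inverse of the usual pooling size formula, unpooled = (pooled - 1) * stride - 2 * pad + ksize. A quick numeric check with the parameters of the test below (5x5 input, ksize 3, stride 2, no padding), illustrative only:

    def pool_out_size(in_size, ksize, stride, pad):
        return (in_size - ksize + 2 * pad) // stride + 1

    def unpool_out_size(in_size, ksize, stride, pad):
        return (in_size - 1) * stride - 2 * pad + ksize

    assert pool_out_size(5, 3, 2, 0) == 2
    assert unpool_out_size(2, 3, 2, 0) == 5  # round-trips to the original size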
@@ -0,0 +1,83 @@
import unittest
import numpy as np
from op_test import OpTest


def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
    s0, s1, s2, s3 = input.shape
    out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
    # Width must be derived from s3 (the input width), not s2.
    out_wsize = (s3 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
    out = np.zeros((s0, s1, out_hsize, out_wsize))
    for nidx in xrange(s0):
        for cidx in xrange(s1):
            for h in xrange(s2):
                for w in xrange(s3):
                    index = indices[nidx, cidx, h, w]
                    hidx = (index - index % out_wsize) / out_wsize
                    widx = index % out_wsize
                    out[nidx, cidx, int(hidx), int(widx)] = \
                        input[nidx, cidx, h, w]

    return out


class TestUnpoolOp(OpTest):
    def setUp(self):
        self.op_type = "unpool"
        self.init_test_case()
        pre_input = np.random.random(self.shape).astype("float32")
        nsize, csize, hsize, wsize = pre_input.shape
        hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \
            self.strides[0] + 1
        wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \
            self.strides[1] + 1
        input = np.zeros((nsize, csize, hsize_out, wsize_out))
        indices = np.zeros((nsize, csize, hsize_out, wsize_out))
        for i in xrange(hsize_out):
            for j in xrange(wsize_out):
                r_start = np.max((i * self.strides[0] - self.paddings[0], 0))
                r_end = np.min((i * self.strides[0] + self.ksize[0] - \
                    self.paddings[0], hsize))
                c_start = np.max((j * self.strides[1] - self.paddings[1], 0))
                c_end = np.min((j * self.strides[1] + self.ksize[1] - \
                    self.paddings[1], wsize))
                for nidx in xrange(nsize):
                    for cidx in xrange(csize):
                        x_masked = pre_input[nidx, cidx, r_start:r_end, \
                            c_start:c_end]
                        input[nidx, cidx, i, j] = x_masked.max()
                        arg = x_masked.argmax()
                        # Flat offset of the max within the unpooled map.
                        indices[nidx, cidx, i, j] = \
                            (r_start + arg / self.ksize[1]) * wsize + \
                            c_start + arg % self.ksize[1]
        output = self.unpool2d_forward_naive(input, indices, self.ksize, \
            self.strides, self.paddings).astype("float32")
        self.inputs = {
            'X': input.astype('float32'),
            'Indices': indices.astype('int32')
        }
        self.attrs = {
            'strides': self.strides,
            'paddings': self.paddings,
            'ksize': self.ksize,
            'unpooling_type': self.unpooling_type,
        }
        self.outputs = {'Out': output.astype('float32')}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def init_test_case(self):
        self.unpool2d_forward_naive = unpool2dmax_forward_naive
        self.unpooling_type = "max"
        self.shape = [6, 4, 5, 5]
        self.ksize = [3, 3]
        self.strides = [2, 2]
        self.paddings = [0, 0]


if __name__ == '__main__':
    unittest.main()
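The test stores each max location as a flat offset within its output feature map, index = row * out_width + col, which is exactly what unpool2dmax_forward_naive decodes with the modulo and division above. A standalone decode, illustrative only:

    def decode_index(index, out_w):
        # Invert index = row * out_w + col for one feature map.
        return index // out_w, index % out_w

    assert decode_index(7, out_w=5) == (1, 2)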