fix code format and some bug

revert-4814-Add_sequence_project_op
chengduoZH 8 years ago
parent 6326c40d27
commit bee95fc891

@ -26,7 +26,6 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
@ -112,13 +111,13 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
}
};
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
@ -152,6 +151,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
@ -170,17 +170,17 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele <
input_data[(d * input_height + h) * input_width + w]) {
index = (d * input_height + h) * input_width + w;
ele =
input_data[(d * input_height + h) * input_width + w];
int input_idx = (d * input_height + h) * input_width + w;
if (ele < input_data[input_idx]) {
index = input_idx;
ele = input_data[input_idx];
}
}
}

File diff suppressed because it is too large Load Diff

@ -23,7 +23,6 @@ namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__
/////////////////////
template <typename Place, typename T>
class MaxPool2dWithIndexFunctor {

@ -76,8 +76,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")),
"X(Input) of MaxPoolWithIndexOpGrad should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput(framework::GradVarName("X")),
"X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
@ -97,28 +97,37 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"number of channels, H and W is the height and width of image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW.");
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCHW.");
"The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W "
"is the height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize", "pooling size(height, width) of pooling operator.");
"ksize",
"Pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator."
"default {1,1}")
.SetDefault({1, 1});
"Strides(height, width) of pooling operator."
"Default {1,1}.")
.SetDefault({1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator."
"default {0,0}")
.SetDefault({0, 0});
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The maxPooling2d with index operation calculates the output and the mask based on
@ -140,30 +149,40 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW.");
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW.");
"The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize", "pooling size(depth, height, width) of pooling operator.");
"ksize",
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"strides(depth, height, width) of pooling operator."
"default {1,1,1}")
.SetDefault({1, 1, 1});
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>(
"paddings",
"paddings(depth, height, width) of pooling operator."
"default {0,0,0}")
.SetDefault({0, 0, 0});
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The maxpooling3d with index operation calculates the output and the mask based on
the input and ksize, strides, paddings parameters.

@ -32,11 +32,10 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
Tensor* out = context.Output<Tensor>("Out");
Tensor* mask = context.Output<Tensor>("Mask");
bool global_pooling = context.Attr<bool>("globalPooling");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) {
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
@ -63,7 +62,7 @@ template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Maks");
const Tensor* mask = context.Input<Tensor>("Mask");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
@ -71,6 +70,11 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
}
}
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());

@ -3,7 +3,11 @@ import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
def max_pool3D_forward_naive(x,
ksize,
strides,
paddings=[0, 0, 0],
global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
@ -25,8 +29,19 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
# mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
return out
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :, :]
index = np.where(arr == np.max(arr))
sub_deep = index[0][0]
sub_row = index[1][0]
sub_col = index[2][0]
index = ((d_start + sub_deep) * H +
(h_start + sub_row)) * W + w_start + sub_col
mask[n, c, k, i, j] = index
return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
@ -47,19 +62,25 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
# mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3))
return out
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :]
index = np.where(arr == np.max(arr))
sub_row = index[0][0]
sub_col = index[1][0]
index = (r_start + sub_row) * W + c_start + sub_col
mask[n, c, i, j] = index
return out, mask
class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "maxPool3dWithIndex"
input = np.random.random(self.shape).astype("float32")
output = self.pool_forward_naive(input, self.ksize, self.strides,
output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
# mask = np.zeros(output.shape)
self.attrs = {
'strides': self.strides,
@ -69,7 +90,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
}
self.inputs = {'X': input}
self.outputs = {'Out': output}
self.outputs = {'Out': output, "Mask": mask}
def test_check_output(self):
self.check_output()
@ -78,7 +99,8 @@ class TestMaxPoolWithIndex_Op(OpTest):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self):
self.global_pool = 0
self.global_pool = False
self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
@ -86,10 +108,9 @@ class TestMaxPoolWithIndex_Op(OpTest):
self.paddings = [1, 1, 1]
""""
class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 1
self.global_pool = True
self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
@ -100,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op):
class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 0
self.global_pool = False
self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
@ -111,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op):
class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 1
self.global_pool = True
self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
@ -122,4 +143,3 @@ class TestCase3(TestMaxPoolWithIndex_Op):
if __name__ == '__main__':
unittest.main()
"""

Loading…
Cancel
Save