fix code format and some bug

revert-4814-Add_sequence_project_op
chengduoZH 8 years ago
parent 6326c40d27
commit bee95fc891

@ -26,7 +26,6 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
framework::Tensor& mask, std::vector<int>& ksize, framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
const int output_channels = output.dims()[1]; const int output_channels = output.dims()[1];
@ -112,11 +111,11 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
input_grad_data[input_idx] += output_grad_data[output_idx]; input_grad_data[input_idx] += output_grad_data[output_idx];
} }
} }
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
} }
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
} }
} }
}; };
@ -152,6 +151,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int padding_width = paddings[2]; const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width; const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width; const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace()); T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace()); T* mask_data = mask.mutable_data<T>(context.GetPlace());
@ -170,17 +170,17 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
int wstart = pw * stride_width - padding_width; int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width); int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw; int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX); T ele = static_cast<T>(-FLT_MAX);
int index = -1; int index = -1;
for (int d = dstart; d < dend; ++d) { for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
if (ele < int input_idx = (d * input_height + h) * input_width + w;
input_data[(d * input_height + h) * input_width + w]) { if (ele < input_data[input_idx]) {
index = (d * input_height + h) * input_width + w; index = input_idx;
ele = ele = input_data[input_idx];
input_data[(d * input_height + h) * input_width + w];
} }
} }
} }

File diff suppressed because it is too large Load Diff

@ -23,7 +23,6 @@ namespace operators {
namespace math { namespace math {
////////////////////// //////////////////////
#define FLT_MAX __FLT_MAX__ #define FLT_MAX __FLT_MAX__
/////////////////////
template <typename Place, typename T> template <typename Place, typename T>
class MaxPool2dWithIndexFunctor { class MaxPool2dWithIndexFunctor {

@ -76,8 +76,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of MaxPoolWithIndexOpGrad should not be null."); "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput(framework::GradVarName("X")), ctx->HasOutput(framework::GradVarName("X")),
"X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null."); "X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
@ -97,28 +97,37 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"number of channels, H and W is the height and width of image."); "number of channels, H and W is the height and width of image.");
AddOutput("Out", AddOutput("Out",
"The output tensor of pooling operator." "The output tensor of pooling operator."
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of image.");
AddOutput("Mask", AddOutput("Mask",
"The Mask tensor of pooling operator." "The Mask tensor of pooling operator."
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W "
"is the height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "pooling size(height, width) of pooling operator."); "ksize",
"Pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>( AddAttr<bool>(
"globalPooling", "globalPooling",
"whether to use the globalPooling." "Whether to use the globalPooling."
"int constant equal to false or true" "Bool constant equal to false or true."
"default false" "Default false."
"If globalPooling = true, ksize is ignored and need not be specified.") "If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator." "Strides(height, width) of pooling operator."
"default {1,1}") "Default {1,1}.")
.SetDefault({1, 1}); .SetDefault({1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator." "Paddings(height, width) of pooling operator."
"default {0,0}") "Default {0,0}.")
.SetDefault({0, 0}); .SetDefault({0, 0}); // TODO(Add checker)
AddComment(R"DOC( AddComment(R"DOC(
The maxPooling2d with index operation calculates the output and the mask based on The maxPooling2d with index operation calculates the output and the mask based on
@ -140,30 +149,40 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"image."); "image.");
AddOutput("Out", AddOutput("Out",
"The output tensor of pooling operator." "The output tensor of pooling operator."
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image.");
AddOutput("Mask", AddOutput("Mask",
"The Mask tensor of pooling operator." "The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "pooling size(depth, height, width) of pooling operator."); "ksize",
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>( AddAttr<bool>(
"globalPooling", "globalPooling",
"whether to use the globalPooling." "Whether to use the globalPooling."
"int constant equal to false or true" "Bool constant equal to false or true."
"default false" "Default false."
"If globalPooling = true, ksize is ignored and need not be specified.") "If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
"strides(depth, height, width) of pooling operator." "Strides(depth, height, width) of pooling operator."
"default {1,1,1}") "Default {1,1,1}.")
.SetDefault({1, 1, 1}); .SetDefault({1, 1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"paddings(depth, height, width) of pooling operator." "Paddings(depth, height, width) of pooling operator."
"default {0,0,0}") "Default {0,0,0}.")
.SetDefault({0, 0, 0}); .SetDefault({0, 0, 0}); // TODO(Add checker)
AddComment(R"DOC( AddComment(R"DOC(
The maxpooling3d with index operation calculates the output and the mask based on The maxpooling3d with index operation calculates the output and the mask based on
the input and ksize, strides, paddings parameters. the input and ksize, strides, paddings parameters.

@ -32,11 +32,10 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
Tensor* out = context.Output<Tensor>("Out"); Tensor* out = context.Output<Tensor>("Out");
Tensor* mask = context.Output<Tensor>("Mask"); Tensor* mask = context.Output<Tensor>("Mask");
bool global_pooling = context.Attr<bool>("globalPooling");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
@ -63,7 +62,7 @@ template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel { class MaxPoolWithIndexGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Maks"); const Tensor* mask = context.Input<Tensor>("Mask");
const Tensor* out_grad = const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out")); context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X")); Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
@ -71,6 +70,11 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
}
}
if (in_x_grad) { if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());

@ -3,7 +3,11 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): def max_pool3D_forward_naive(x,
ksize,
strides,
paddings=[0, 0, 0],
global_pool=0):
N, C, D, H, W = x.shape N, C, D, H, W = x.shape
if global_pool == 1: if global_pool == 1:
@ -25,8 +29,19 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
# mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
return out for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :, :]
index = np.where(arr == np.max(arr))
sub_deep = index[0][0]
sub_row = index[1][0]
sub_col = index[2][0]
index = ((d_start + sub_deep) * H +
(h_start + sub_row)) * W + w_start + sub_col
mask[n, c, k, i, j] = index
return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
@ -47,19 +62,25 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
x_masked = x[:, :, r_start:r_end, c_start:c_end] x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
# mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3))
return out for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :]
index = np.where(arr == np.max(arr))
sub_row = index[0][0]
sub_col = index[1][0]
index = (r_start + sub_row) * W + c_start + sub_col
mask[n, c, i, j] = index
return out, mask
class TestMaxPoolWithIndex_Op(OpTest): class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.initTestCase()
self.op_type = "maxPool3dWithIndex"
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output = self.pool_forward_naive(input, self.ksize, self.strides, output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool) self.paddings, self.global_pool)
# mask = np.zeros(output.shape)
self.attrs = { self.attrs = {
'strides': self.strides, 'strides': self.strides,
@ -69,7 +90,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
} }
self.inputs = {'X': input} self.inputs = {'X': input}
self.outputs = {'Out': output} self.outputs = {'Out': output, "Mask": mask}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
@ -78,7 +99,8 @@ class TestMaxPoolWithIndex_Op(OpTest):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self): def initTestCase(self):
self.global_pool = 0 self.global_pool = False
self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7] self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3] self.ksize = [3, 3, 3]
@ -86,10 +108,9 @@ class TestMaxPoolWithIndex_Op(OpTest):
self.paddings = [1, 1, 1] self.paddings = [1, 1, 1]
""""
class TestCase1(TestMaxPoolWithIndex_Op): class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self): def initTestCase(self):
self.global_pool = 1 self.global_pool = True
self.op_type = "maxPool3dWithIndex" self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5] self.shape = [2, 3, 5, 5, 5]
@ -100,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op):
class TestCase2(TestMaxPoolWithIndex_Op): class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self): def initTestCase(self):
self.global_pool = 0 self.global_pool = False
self.op_type = "maxPool2dWithIndex" self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7] self.shape = [2, 3, 7, 7]
@ -111,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op):
class TestCase3(TestMaxPoolWithIndex_Op): class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self): def initTestCase(self):
self.global_pool = 1 self.global_pool = True
self.op_type = "maxPool2dWithIndex" self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5] self.shape = [2, 3, 5, 5]
@ -122,4 +143,3 @@ class TestCase3(TestMaxPoolWithIndex_Op):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
"""

Loading…
Cancel
Save