pass unit test

enforce_failed
Yibing Liu 8 years ago
parent 12eaa22ad2
commit 899c7d6b35

@ -38,6 +38,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
size_t in_size = framework::product(in->dims());
PADDLE_ENFORCE_EQ(shape_size, in_size,
"The size of Input(X) mismatches with Attr(shape).");
ctx.Output<framework::Tensor>("Out")->Resize(in->dims());
}
};
@ -51,7 +52,7 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("shape", "Target shape of reshape operator.");
AddComment(R"DOC(Reshape operator
The input tensor will be reshaped with Attr(shape).
Reshape Input(X) into the shape specified by Attr(shape).
)DOC");
}
};

@ -23,13 +23,13 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class ReshapeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<Tensor>("Out");
auto* in = ctx.Input<Tensor>("X");
out->mutable_data<T>(in->place());
out->mutable_data<T>(ctx.GetPlace());
auto shape = ctx.Attr<std::vector<int>>("shape");
std::vector<int64_t> tmp;
@ -42,7 +42,7 @@ class ReshapeKernel : public framework::OpKernel {
}
};
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class ReshapeGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
@ -51,7 +51,6 @@ class ReshapeGradKernel : public framework::OpKernel {
d_x->mutable_data<T>(ctx.GetPlace());
auto in_dims = d_x->dims();
d_x->CopyFrom<T>(*d_out, ctx.GetPlace());
d_x->Resize(in_dims);
}

@ -1,6 +1,6 @@
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from gradient_checker import GradientChecker, Operator
from op_test_util import OpTestMeta
@ -9,19 +9,16 @@ class TestReshapeOp(unittest.TestCase):
def setUp(self):
self.type = "reshape"
self.inputs = {'X': np.random.random((2, 4)).astype("float32"), }
print self.inputs
self.attrs = {'shape': [4, 2]}
self.inputs = {'X': np.random.random((37, 51)).astype("float32"), }
self.attrs = {'shape': [51, 37]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
print self.outputs
class ReshapeGradOpTest(GradientChecker):
def test_normal(self):
op = create_op("reshape")
inputs = {"X": np.random.random((2, 4)).astype("float32")}
attrs = {'shape': [4, 2]}
self.check_grad(op, inputs, attrs, set("X"), "Out")
op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
inputs = {"X": np.random.random((10, 20)).astype("float32")}
self.check_grad(op, inputs, set("X"), "Out")
if __name__ == '__main__':

Loading…
Cancel
Save