fix reduce bug test=develop (#19971)

fix-python-transpose
wangchaochaohu 5 years ago committed by GitHub
parent 3ea2b661c0
commit 3409db950c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -197,6 +197,9 @@ class ReduceOp : public framework::OperatorWithKernel {
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
if (!keep_dim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dims[0] != 0) {

@ -397,5 +397,19 @@ class TestReduceAll(OpTest):
self.check_grad(['X'], 'Out')
class Test1DReduceWithAxes1(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random(1).astype("float64")}
self.attrs = {'dim': [0], 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save