|
|
|
@ -1,15 +1,11 @@
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
from gradient_checker import GradientChecker, create_op
|
|
|
|
|
from op_test_util import OpTestMeta
|
|
|
|
|
from paddle.v2.framework.op import Operator
|
|
|
|
|
from op_test import OpTest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConv2dOp(unittest.TestCase):
|
|
|
|
|
__metaclass__ = OpTestMeta
|
|
|
|
|
|
|
|
|
|
class TestConv2dOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.type = "conv2d"
|
|
|
|
|
self.op_type = "conv2d"
|
|
|
|
|
batch_size = 2
|
|
|
|
|
input_channels = 3
|
|
|
|
|
input_height = 5
|
|
|
|
@ -58,8 +54,11 @@ class TestConv2dOp(unittest.TestCase):
|
|
|
|
|
self.outputs = {'Output': output}
|
|
|
|
|
self.attrs = {'strides': [1, 1], 'paddings': [0, 0]}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConv2dGradOp(GradientChecker):
|
|
|
|
|
class TestConv2dGradOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
batch_size = 2
|
|
|
|
|
input_channels = 3
|
|
|
|
@ -79,21 +78,18 @@ class TestConv2dGradOp(GradientChecker):
|
|
|
|
|
(output_channels, input_channels, filter_height,
|
|
|
|
|
filter_width)).astype("float32")
|
|
|
|
|
|
|
|
|
|
self.op_type = 'conv2d'
|
|
|
|
|
self.inputs = {'Input': input, 'Filter': filter}
|
|
|
|
|
self.op = Operator(
|
|
|
|
|
"conv2d",
|
|
|
|
|
Input='Input',
|
|
|
|
|
Filter='Filter',
|
|
|
|
|
Output='Output',
|
|
|
|
|
strides=[1, 1],
|
|
|
|
|
paddings=[0, 0])
|
|
|
|
|
output = np.ndarray(
|
|
|
|
|
(batch_size, output_channels, output_height, output_width))
|
|
|
|
|
self.outputs = {'Output': output}
|
|
|
|
|
self.attrs = {'strides': [1, 1], 'paddings': [0, 0]}
|
|
|
|
|
|
|
|
|
|
def test_compare_grad(self):
|
|
|
|
|
self.compare_grad(self.op, self.inputs)
|
|
|
|
|
#def test_compare_grad(self):
|
|
|
|
|
# self.compare_grad(self.op, self.inputs)
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
|
self.check_grad(self.op, self.inputs,
|
|
|
|
|
set(['Input', 'Filter']), 'Output')
|
|
|
|
|
self.check_grad(set(['Input', 'Filter']), 'Output')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|