@ -14,16 +14,11 @@ limitations under the License. */
#pragma once
#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h"
namespace paddle {
namespace framework {
class InferShapeContextBase;
typedef std::function<void(InferShapeContextBase *)> InferShapeFn;
class InferShapeContextBase {
public:
virtual ~InferShapeContextBase() {}
@ -26,8 +26,6 @@ class TestInferShape(unittest.TestCase):
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
print(type(sum_op_desc))
print(type(block))
core.Operator.infer_shape(sum_op_desc, block)
self.assertEqual(out.shape(), shape)