// Unit test for the gradient-operator builder declared in
// paddle/framework/grad_op_builder.h.
#include "paddle/framework/grad_op_builder.h"
|
|
#include <gtest/gtest.h>
|
|
#include "paddle/framework/op_registry.h"
|
|
#include "paddle/framework/operator.h"
|
|
|
|
// Make the "add_two" operator available to this test; the test below creates
// it by name through OpRegistry::CreateOp.
// NOTE(review): presumably USE_OP forces the operator's registration to be
// linked into this binary -- confirm against op_registry.h.
USE_OP(add_two);
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
|
|
TEST(GradOpBuilder, AddTwo) {
|
|
std::shared_ptr<OperatorBase> add_op(
|
|
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
|
|
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
|
|
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
|
|
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
|
|
EXPECT_EQ(grad_add_op->Input("X"), "x");
|
|
EXPECT_EQ(grad_add_op->Input("Y"), "y");
|
|
EXPECT_EQ(grad_add_op->Input("Out"), "out");
|
|
EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
|
|
EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
|
|
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
|
|
}
|
|
|
|
} // namespace framework
|
|
} // namespace paddle
|