|
|
|
@ -60,40 +60,78 @@ class TreeConvOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("NodesVector"), "Input", "NodesVector",
|
|
|
|
|
"TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("EdgeSet"), "Input", "EdgeSet", "TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TreeConv");
|
|
|
|
|
|
|
|
|
|
auto edge_dims = ctx->GetInputDim("EdgeSet");
|
|
|
|
|
auto vector_dims = ctx->GetInputDim("NodesVector");
|
|
|
|
|
auto filter_dims = ctx->GetInputDim("Filter");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
|
|
|
|
|
PADDLE_ENFORCE_EQ(edge_dims[2], 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(EdgeSet) dim[2] should be 2. "
|
|
|
|
|
"But received Input(EdgeSet) dim[2] is %d.",
|
|
|
|
|
edge_dims[2]));
|
|
|
|
|
} else {
|
|
|
|
|
if (edge_dims[2] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
|
|
|
|
|
PADDLE_ENFORCE_EQ(edge_dims[2], 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(EdgeSet) dim[2] should be 2. "
|
|
|
|
|
"But received Input(EdgeSet) dim[2] is %d.",
|
|
|
|
|
edge_dims[2]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(edge_dims.size(), 3,
|
|
|
|
|
"The dimension of EdgeSet Tensor should be 3");
|
|
|
|
|
PADDLE_ENFORCE_EQ(vector_dims.size(), 3,
|
|
|
|
|
"The dimension of NodesVector Tensor should be 3");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of EdgeSet Tensor should be 3. "
|
|
|
|
|
"But received the dimension of EdgeSet Tensor is %d.",
|
|
|
|
|
edge_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
vector_dims.size(), 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of NodesVector Tensor should be 3. "
|
|
|
|
|
"But received the dimension of NodesVector Tensor is %d.",
|
|
|
|
|
vector_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
|
|
|
|
|
"The dimension of Filter Tensor should be 4");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of Filter Tensor should be 4. "
|
|
|
|
|
"But received the dimension of Filter Tensor is %d.",
|
|
|
|
|
filter_dims.size()));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3");
|
|
|
|
|
PADDLE_ENFORCE_EQ(filter_dims[1], 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Filter) dim[1] should be 3. "
|
|
|
|
|
"But received Input(Filter) dim[1] is %d.",
|
|
|
|
|
filter_dims[1]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
filter_dims[0], vector_dims[2],
|
|
|
|
|
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]. "
|
|
|
|
|
"But received Input(Filter) dim[0] = %d, Input(NodesVector) "
|
|
|
|
|
"dim[2] = %d.",
|
|
|
|
|
filter_dims[0], vector_dims[2]));
|
|
|
|
|
} else {
|
|
|
|
|
if (filter_dims[1] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(filter_dims[1], 3,
|
|
|
|
|
"Input(Filter) dim[1] should be 3");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Filter) dim[1] should be 3. "
|
|
|
|
|
"But received Input(Filter) dim[1] is %d.",
|
|
|
|
|
filter_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (filter_dims[0] != -1 && vector_dims[2] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
filter_dims[0], vector_dims[2],
|
|
|
|
|
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]. "
|
|
|
|
|
"But received Input(Filter) dim[0] = %d, Input(NodesVector) "
|
|
|
|
|
"dim[2] = %d.",
|
|
|
|
|
filter_dims[0], vector_dims[2]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto output_dims = framework::make_ddim(
|
|
|
|
@ -137,10 +175,21 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "grad_TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("EdgeSet"), "Input", "EdgeSet",
|
|
|
|
|
"grad_TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("NodesVector"), "Input", "NodesVector",
|
|
|
|
|
"grad_TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "grad_TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("NodesVector")),
|
|
|
|
|
"Output", framework::GradVarName("NodesVector"),
|
|
|
|
|
"grad_TreeConv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Filter")), "Output",
|
|
|
|
|
framework::GradVarName("Filter"), "grad_TreeConv");
|
|
|
|
|
|
|
|
|
|
auto vectors_dims = ctx->GetInputDim("NodesVector");
|
|
|
|
|
auto filter_dims = ctx->GetInputDim("Filter");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"the gradient of output(Out) must not be null");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
|
|
|
|
|
}
|
|
|
|
|