Topk share lod (#7373)

* add lod tensor ToAbsOffset test

* add share lod to topk op and softmax op
detection_output_fixbug
Qiao Longfei 7 years ago committed by GitHub
parent cedd9805f5
commit 91f80f792d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -115,5 +115,21 @@ TEST(LoD, AppendLoD) {
EXPECT_EQ(origin, expected);
}
TEST(LoD, ToAbsOffset) {
LoD relative_lod;
relative_lod.push_back(std::vector<size_t>({0, 2}));
relative_lod.push_back(std::vector<size_t>({0, 1, 3}));
relative_lod.push_back(std::vector<size_t>({0, 2, 4, 5}));
LoD abs_lod = paddle::framework::ToAbsOffset(relative_lod);
LoD expected;
expected.push_back(std::vector<size_t>({0, 5}));
expected.push_back(std::vector<size_t>({0, 2, 5}));
expected.push_back(std::vector<size_t>({0, 2, 4, 5}));
EXPECT_EQ(abs_lod, expected);
}
} // namespace framework
} // namespace paddle

@ -31,6 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};

@ -41,6 +41,8 @@ class TopkOp : public framework::OperatorWithKernel {
dims[dims.size() - 1] = k;
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
};

Loading…
Cancel
Save