|
|
@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
// input check
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
"Input(X) of LoDResetOp should not be null.");
|
|
|
|
"Input(X) of LoDResetOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
"Output(Out) of LoDResetOp should not be null.");
|
|
|
|
"Output(Out) of LoDResetOp should not be null.");
|
|
|
|
// If target LoD is not set form Input(), then it must be set from Attr().
|
|
|
|
|
|
|
|
if (!ctx->HasInput("TargetLoD")) {
|
|
|
|
if (!ctx->HasInput("Y")) {
|
|
|
|
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
|
|
|
|
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
|
|
|
|
PADDLE_ENFORCE(level0.size() > 1,
|
|
|
|
PADDLE_ENFORCE_GT(level0.size(), 1,
|
|
|
|
"Target LoD is not found, should be set to be a valid one "
|
|
|
|
"If Input(Y) not provided, the target lod should be "
|
|
|
|
"through Input() or Attr().");
|
|
|
|
"specified by attribute `target_lod`.");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
|
|
|
|
AddInput("X",
|
|
|
|
AddInput("TargetLoD",
|
|
|
|
"(Tensor, LoDTensor) Input variable of LoDResetOp which "
|
|
|
|
"(Tensor, optional) The target level 0 LoD from Input().")
|
|
|
|
"could be a Tensor or LoDTensor, where the data of output "
|
|
|
|
|
|
|
|
"variable inherits from.");
|
|
|
|
|
|
|
|
AddInput("Y",
|
|
|
|
|
|
|
|
"(Tensor, LoDTensor, optional) If provided and Y is LoDTensor, "
|
|
|
|
|
|
|
|
"lod of Input(Y) would be considered as the target lod first, "
|
|
|
|
|
|
|
|
"otherwise data of Input(Y) would be considered as the "
|
|
|
|
|
|
|
|
"target lod.")
|
|
|
|
.AsDispensable();
|
|
|
|
.AsDispensable();
|
|
|
|
AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator.");
|
|
|
|
AddOutput("Out",
|
|
|
|
|
|
|
|
"(LoDTensor) Output variable of LoDResetOp which should be a "
|
|
|
|
|
|
|
|
"LoDTensor.");
|
|
|
|
AddAttr<std::vector<int>>("target_lod",
|
|
|
|
AddAttr<std::vector<int>>("target_lod",
|
|
|
|
"The target level 0 LoD from Attr().")
|
|
|
|
"The target level 0 LoD from Attr().")
|
|
|
|
.SetDefault(std::vector<int>{});
|
|
|
|
.SetDefault(std::vector<int>{});
|
|
|
|
AddComment(R"DOC(LoDReset operator
|
|
|
|
AddComment(R"DOC(LoDReset operator
|
|
|
|
|
|
|
|
|
|
|
|
Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
|
|
|
|
Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y`
|
|
|
|
Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
|
|
|
|
provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD
|
|
|
|
Currently the lod_reset operator only supports the reset of level 0 LoD.
|
|
|
|
first, otherwise `Y.data` would be considered as target LoD. If `Y` is not
|
|
|
|
At least one of Input(TargetLoD) and Attr(target_lod) must be set,
|
|
|
|
provided, target LoD should be specified by attribute `target_lod`.
|
|
|
|
and if both of them are set, Input(TargetLoD) will be chosen as the
|
|
|
|
If target LoD is specified by `Y.data` or `target_lod`, only one level LoD
|
|
|
|
target LoD.
|
|
|
|
is supported.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Example 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Given a 1-level LoDTensor input(X):
|
|
|
|
|
|
|
|
X.lod = [[ 0, 2, 5 6 ]]
|
|
|
|
|
|
|
|
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
|
|
|
|
|
|
|
|
X.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attr(target_lod): [0, 4, 6]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
then we get a 1-level LoDTensor:
|
|
|
|
|
|
|
|
Out.lod = [[ 0, 4, 6 ]]
|
|
|
|
|
|
|
|
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
|
|
|
|
|
|
|
|
Out.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Example 2:
|
|
|
|
|
|
|
|
|
|
|
|
An example:
|
|
|
|
Given a 1-level LoDTensor input(X):
|
|
|
|
Given a float LoDTensor X with shape (6, 1), its transpose form represents
|
|
|
|
X.lod = [[ 0, 2, 5 6 ]]
|
|
|
|
|
|
|
|
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
|
|
|
|
|
|
|
|
X.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
|
|
|
input(Y) is a Tensor:
|
|
|
|
|
|
|
|
Y.data = [[0, 2, 6]]
|
|
|
|
|
|
|
|
Y.dims = [1, 3]
|
|
|
|
|
|
|
|
|
|
|
|
with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
|
|
|
|
then we get a 1-level LoDTensor:
|
|
|
|
|
|
|
|
Out.lod = [[ 0, 2, 6 ]]
|
|
|
|
|
|
|
|
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
|
|
|
|
|
|
|
|
Out.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
[1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
|
|
|
|
Example 3:
|
|
|
|
|
|
|
|
|
|
|
|
If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
|
|
|
|
Given a 1-level LoDTensor input(X):
|
|
|
|
the sequences that the LoDTensor Output(Out) contains becomes:
|
|
|
|
X.lod = [[ 0, 2, 5 6 ]]
|
|
|
|
|
|
|
|
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
|
|
|
|
|
|
|
|
X.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
[1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
|
|
|
|
input(Y) is a 2-level LoDTensor:
|
|
|
|
|
|
|
|
Y.lod = [[0, 2, 4], [0, 2, 5, 6]]
|
|
|
|
|
|
|
|
Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
|
|
|
|
|
|
|
|
Y.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
then we get a 2-level LoDTensor:
|
|
|
|
|
|
|
|
Out.lod = [[0, 2, 4], [0, 2, 5, 6]]
|
|
|
|
|
|
|
|
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
|
|
|
|
|
|
|
|
Out.dims = [6, 1]
|
|
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
)DOC");
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
|
|
|
"Input(X) of LoDResetGradOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
"Input(Out@Grad) of LoDResetGradOp should not be null.");
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
|
|
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
|
|
|
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ x_grad_name);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
|
|
|
|
REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
|
|
|
|
ops::LoDResetGradOp);
|
|
|
|
ops::LoDResetGradOp);
|
|
|
|
REGISTER_OP_CPU_KERNEL(lod_reset,
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
ops::LoDResetKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
|
|
|
|
|
|
|
|
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
|
|
|
|
|
|
|
|
ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
|
|
|
|
|
|
|
|
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
|
|
|
|
|
|
|
|
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>);
|
|
|
|