|
|
@ -18,12 +18,12 @@
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
class ShrinkStateOp : public ArrayOp {
|
|
|
|
class ShrinkRNNMemoryOp : public ArrayOp {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
ShrinkStateOp(const std::string &type,
|
|
|
|
ShrinkRNNMemoryOp(const std::string &type,
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
: ArrayOp(type, inputs, outputs, attrs) {}
|
|
|
|
: ArrayOp(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp {
|
|
|
|
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
|
|
|
|
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
|
|
|
|
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
|
|
|
|
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
|
|
|
|
|
|
|
|
|
|
|
|
int dst_num_rows = 0;
|
|
|
|
auto &rank_items = rank_table.items();
|
|
|
|
|
|
|
|
int dst_num_rows =
|
|
|
|
{
|
|
|
|
std::lower_bound(rank_items.begin(), rank_items.end(), offset,
|
|
|
|
auto &rank_items = rank_table.items();
|
|
|
|
[](const framework::LoDRankTable::TableItem &a,
|
|
|
|
for (auto &rank_item : rank_items) {
|
|
|
|
size_t b) { return a.length > b; }) -
|
|
|
|
if (offset < rank_item.length) {
|
|
|
|
rank_items.begin();
|
|
|
|
++dst_num_rows;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto *out_var = scope.FindVar(Output("Out"));
|
|
|
|
auto *out_var = scope.FindVar(Output("Out"));
|
|
|
|
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
|
|
|
|
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
|
|
|
@ -58,10 +52,10 @@ class ShrinkStateOp : public ArrayOp {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
ShrinkStateOpProtoMaker(framework::OpProto *proto,
|
|
|
|
ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("X", "");
|
|
|
|
AddInput("X", "");
|
|
|
|
AddInput("RankTable", "");
|
|
|
|
AddInput("RankTable", "");
|
|
|
@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ShrinkStateOpInferShape : public framework::InferShapeBase {
|
|
|
|
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"));
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"));
|
|
|
@ -81,19 +75,18 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ShrinkStateGradOp : public ArrayOp {
|
|
|
|
class ShrinkRNNMemoryGradOp : public ArrayOp {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
ShrinkStateGradOp(const std::string &type,
|
|
|
|
ShrinkRNNMemoryGradOp(const std::string &type,
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
: ArrayOp(type, inputs, outputs, attrs) {}
|
|
|
|
: ArrayOp(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
const platform::DeviceContext &dev_ctx) const override {
|
|
|
|
const platform::DeviceContext &dev_ctx) const override {
|
|
|
|
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
|
|
|
|
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
|
|
|
|
auto dx_name = Output(framework::GradVarName("X"));
|
|
|
|
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
|
|
|
|
auto *dx_var = scope.FindVar(dx_name);
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
|
|
|
|
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
|
|
|
|
auto *x_var = scope.FindVar(Input("X"));
|
|
|
|
auto *x_var = scope.FindVar(Input("X"));
|
|
|
|
PADDLE_ENFORCE(x_var != nullptr);
|
|
|
|
PADDLE_ENFORCE(x_var != nullptr);
|
|
|
@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp {
|
|
|
|
auto height = dout_tensor.dims()[0];
|
|
|
|
auto height = dout_tensor.dims()[0];
|
|
|
|
dx_tensor.Slice(0, static_cast<int>(height))
|
|
|
|
dx_tensor.Slice(0, static_cast<int>(height))
|
|
|
|
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
|
|
|
|
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
|
|
|
|
if (height < dout_tensor.dims()[0]) {
|
|
|
|
if (dx_tensor.dims()[0] < height) {
|
|
|
|
auto rest_tensor = dx_tensor.Slice(
|
|
|
|
auto rest_tensor = dx_tensor.Slice(
|
|
|
|
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
|
|
|
|
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
|
|
|
|
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
|
|
|
|
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
|
|
|
@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ShrikStateGradInferShape : public framework::InferShapeBase {
|
|
|
|
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"));
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"));
|
|
|
@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
std::unique_ptr<framework::OpDescBind> Apply() const override {
|
|
|
|
std::unique_ptr<framework::OpDescBind> Apply() const override {
|
|
|
|
auto *op = new framework::OpDescBind();
|
|
|
|
auto *op = new framework::OpDescBind();
|
|
|
|
op->SetType("shrink_state_grad");
|
|
|
|
op->SetType("shrink_rnn_memory_grad");
|
|
|
|
op->SetInput("X", Input("X"));
|
|
|
|
op->SetInput("X", Input("X"));
|
|
|
|
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OPERATOR(shrink_state, ops::ShrinkStateOp,
|
|
|
|
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
|
|
|
|
ops::ShrinkStateOpInferShape, ops::ShrinkStateOpProtoMaker,
|
|
|
|
ops::ShrinkRNNMemoryInferShape,
|
|
|
|
ops::ShrinkStateGradOpMaker);
|
|
|
|
ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
|
|
|
|
REGISTER_OPERATOR(shrink_state_grad, ops::ShrinkStateGradOp,
|
|
|
|
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
|
|
|
|
ops::ShrikStateGradInferShape);
|
|
|
|
ops::ShrinkRNNMemoryGradInferShape);
|