|
|
@ -25,12 +25,12 @@ namespace operators {
|
|
|
|
// Alias for a vector of raw framework::Scope pointers.
// NOTE(review): going by the name, presumably one scope per loop step --
// its uses are not visible in this excerpt; confirm.
using StepScopeVar = std::vector<framework::Scope *>;
|
|
|
|
// Alias for a vector of raw framework::Scope pointers.
// NOTE(review): going by the name, presumably one scope per loop step --
// its uses are not visible in this excerpt; confirm.
using StepScopeVar = std::vector<framework::Scope *>;
|
|
|
|
// Shorthand for framework::LoDTensor.
using LoDTensor = framework::LoDTensor;
|
|
|
|
// Shorthand for framework::LoDTensor.
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
|
|
|
|
|
|
|
// String key "sub_block".
// NOTE(review): presumably the name under which the loop-body block is
// stored on the op -- no use is visible in this excerpt; confirm.
constexpr char kStepBlock[] = "sub_block";
|
|
|
|
// String key "sub_block".
// NOTE(review): presumably the name under which the loop-body block is
// stored on the op -- no use is visible in this excerpt; confirm.
static constexpr char kStepBlock[] = "sub_block";
|
|
|
|
// String key "Condition".
// NOTE(review): presumably the slot name of the loop-condition variable --
// no use is visible in this excerpt; confirm.
constexpr char kCondition[] = "Condition";
|
|
|
|
// String key "Condition".
// NOTE(review): presumably the slot name of the loop-condition variable --
// no use is visible in this excerpt; confirm.
static constexpr char kCondition[] = "Condition";
|
|
|
|
// String key "StepScopes".
// NOTE(review): presumably the slot name of the variable holding the
// per-step scopes (see StepScopeVar) -- no use is visible in this excerpt;
// confirm.
constexpr char kStepScopes[] = "StepScopes";
|
|
|
|
// String key "StepScopes".
// NOTE(review): presumably the slot name of the variable holding the
// per-step scopes (see StepScopeVar) -- no use is visible in this excerpt;
// confirm.
static constexpr char kStepScopes[] = "StepScopes";
|
|
|
|
// Name of the duplicable input slot "X" registered by WhileOpMaker: the set
// of variables required by the operators inside the While block. The
// gradient-side code reads the same slot back through Input(...)/Inputs(...).
constexpr char kParameters[] = "X";
|
|
|
|
// Name of the duplicable input slot "X" registered by WhileOpMaker: the set
// of variables required by the operators inside the While block. The
// gradient-side code reads the same slot back through Input(...)/Inputs(...).
static constexpr char kX[] = "X";
|
|
|
|
// Output slot name "X@GRAD", read through Outputs(...) by the gradient op and
// its shape inference. The number of names in this slot is enforced to equal
// the number of names in the "X" input slot.
constexpr char kParamGrads[] = "X@GRAD";
|
|
|
|
// Output slot name "X@GRAD", read through Outputs(...) by the gradient op and
// its shape inference. The number of names in this slot is enforced to equal
// the number of names in the "X" input slot.
static constexpr char kXGRAD[] = "X@GRAD";
|
|
|
|
// Slot name "Out". The grad-op desc maker forwards the forward op's outputs
// under this name to the gradient op as an input slot of the same name.
constexpr char kOutputs[] = "Out";
|
|
|
|
// Slot name "Out". The grad-op desc maker forwards the forward op's outputs
// under this name to the gradient op as an input slot of the same name.
static constexpr char kOutputs[] = "Out";
|
|
|
|
|
|
|
|
|
|
|
|
class WhileOp : public framework::OperatorBase {
|
|
|
|
class WhileOp : public framework::OperatorBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
@ -67,7 +67,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput(kParameters,
|
|
|
|
AddInput(kX,
|
|
|
|
"A set of variables, which are required by operators inside the "
|
|
|
|
"A set of variables, which are required by operators inside the "
|
|
|
|
"block of While Op.")
|
|
|
|
"block of While Op.")
|
|
|
|
.AsDuplicable();
|
|
|
|
.AsDuplicable();
|
|
|
@ -158,8 +158,8 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
|
|
executor.Run(*program, *cur_scope_iter, block->ID(), false);
|
|
|
|
executor.Run(*program, *cur_scope_iter, block->ID(), false);
|
|
|
|
|
|
|
|
|
|
|
|
auto &pg_names = Outputs(kParamGrads);
|
|
|
|
auto &pg_names = Outputs(kXGRAD);
|
|
|
|
auto &p_names = Inputs(kParameters);
|
|
|
|
auto &p_names = Inputs(kX);
|
|
|
|
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
|
|
|
|
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
|
|
|
|
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
|
|
|
|
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
|
|
|
|
if (pg_names[param_id] == framework::kEmptyVarName) {
|
|
|
|
if (pg_names[param_id] == framework::kEmptyVarName) {
|
|
|
@ -213,11 +213,11 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
auto *grad = new framework::OpDesc();
|
|
|
|
auto *grad = new framework::OpDesc();
|
|
|
|
grad->SetType("while_grad");
|
|
|
|
grad->SetType("while_grad");
|
|
|
|
grad->SetInput(kParameters, Input(kParameters));
|
|
|
|
grad->SetInput(kX, Input(kX));
|
|
|
|
|
|
|
|
|
|
|
|
// Not all of IGs will be generated by inner gradient operators of while op.
|
|
|
|
// Not all of IGs will be generated by inner gradient operators of while op.
|
|
|
|
// Ignore IGs that is not generated by the inside block.
|
|
|
|
// Ignore IGs that is not generated by the inside block.
|
|
|
|
auto igs = InputGrad(kParameters, /*do not drop empty gradient*/ false);
|
|
|
|
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
|
|
|
|
std::unordered_set<std::string> all_outs;
|
|
|
|
std::unordered_set<std::string> all_outs;
|
|
|
|
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
|
|
|
|
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
|
|
|
|
for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
|
|
|
|
for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
|
|
|
@ -231,7 +231,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
grad->SetOutput(framework::GradVarName(kParameters), igs);
|
|
|
|
grad->SetOutput(framework::GradVarName(kX), igs);
|
|
|
|
|
|
|
|
|
|
|
|
grad->SetInput(kOutputs, Output(kOutputs));
|
|
|
|
grad->SetInput(kOutputs, Output(kOutputs));
|
|
|
|
|
|
|
|
|
|
|
@ -240,7 +240,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
std::unordered_set<std::string> block_ins;
|
|
|
|
std::unordered_set<std::string> block_ins;
|
|
|
|
auto *fwd_block = this->grad_block_[0]->ParentBlock();
|
|
|
|
auto *fwd_block = this->grad_block_[0]->ParentBlock();
|
|
|
|
{
|
|
|
|
{
|
|
|
|
for (auto &p : Input(kParameters)) {
|
|
|
|
for (auto &p : Input(kX)) {
|
|
|
|
block_ins.insert(p);
|
|
|
|
block_ins.insert(p);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &o : Output(kOutputs)) {
|
|
|
|
for (auto &o : Output(kOutputs)) {
|
|
|
@ -288,8 +288,8 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
auto p_names = op_desc.Input(kParameters);
|
|
|
|
auto p_names = op_desc.Input(kX);
|
|
|
|
auto pg_names = op_desc.Output(framework::GradVarName(kParameters));
|
|
|
|
auto pg_names = op_desc.Output(framework::GradVarName(kX));
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < p_names.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < p_names.size(); ++i) {
|
|
|
|
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
|
|
|
|
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
|
|
|
@ -307,21 +307,21 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
class WhileGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
class WhileGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
ctx->HasInputs(kParameters);
|
|
|
|
ctx->HasInputs(kX);
|
|
|
|
ctx->HasOutputs(framework::GradVarName(kParameters));
|
|
|
|
ctx->HasOutputs(framework::GradVarName(kX));
|
|
|
|
ctx->HasInputs(kOutputs);
|
|
|
|
ctx->HasInputs(kOutputs);
|
|
|
|
ctx->HasInputs(framework::GradVarName(kOutputs));
|
|
|
|
ctx->HasInputs(framework::GradVarName(kOutputs));
|
|
|
|
|
|
|
|
|
|
|
|
auto p_names = ctx->Inputs(kParameters);
|
|
|
|
auto p_names = ctx->Inputs(kX);
|
|
|
|
auto pg_names = ctx->Outputs(kParamGrads);
|
|
|
|
auto pg_names = ctx->Outputs(kXGRAD);
|
|
|
|
auto var_types = ctx->GetInputsVarType(kParameters);
|
|
|
|
auto var_types = ctx->GetInputsVarType(kX);
|
|
|
|
std::vector<std::string> names_to_set;
|
|
|
|
std::vector<std::string> names_to_set;
|
|
|
|
std::vector<framework::DDim> dims_to_set;
|
|
|
|
std::vector<framework::DDim> dims_to_set;
|
|
|
|
for (size_t i = 0; i < p_names.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < p_names.size(); ++i) {
|
|
|
|
if (pg_names[i] == framework::kEmptyVarName) {
|
|
|
|
if (pg_names[i] == framework::kEmptyVarName) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto dims = ctx->GetInputsElementDim(kParameters, i);
|
|
|
|
auto dims = ctx->GetInputsElementDim(kX, i);
|
|
|
|
if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
|
|
|
|
if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
|
|
|
|
names_to_set.push_back(pg_names[i]);
|
|
|
|
names_to_set.push_back(pg_names[i]);
|
|
|
|
dims_to_set.push_back(dims);
|
|
|
|
dims_to_set.push_back(dims);
|
|
|
|