@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include <vector>
 #include "paddle/fluid/framework/executor.h"
@@ -138,6 +138,10 @@ class WhileGradOp : public framework::OperatorBase {
         auto inside_og_name = inside_og_names[i];
         VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
                 << inside_og_name;
+        if (scope.FindVar(outside_og_name) == nullptr) {
+          continue;
+        }
+
         auto &og_outside =
             detail::Ref(scope.FindVar(outside_og_name),
                         "Cannot find Outside Gradient %s", outside_og_name);
@@ -167,20 +171,46 @@ class WhileGradOp : public framework::OperatorBase {
               PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
             }
           }
         } else {
           PADDLE_THROW("Currently only support LoDTensor and LoDTensorArray.");
         }
       }
       executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
                                   true);
 
-      auto &pg_names = Outputs(kXGRAD);
+      // The Outputs(kXGRAD) contains the names of the gradient of parameters
+      // and inputs.
+      auto &pg_ig_names = Outputs(kXGRAD);
       auto &p_names = Inputs(kX);
-      PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
-      for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
-        if (pg_names[param_id] == framework::kEmptyVarName) {
+      PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size());
+      for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
+        if (pg_ig_names[param_id] == framework::kEmptyVarName) {
           continue;  // parameter doesn't have gradient
         }
         auto inside_grad_name = framework::GradVarName(p_names[param_id]);
 
+        // for some grad_op, their input doesn't have gradient,
+        // for example lookup_table_grad_op, the input(Idx) doesn't have
+        // gradient.
+        auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
+        PADDLE_ENFORCE(pg_ig_var != nullptr);
+        if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
+          auto pg_ig_lod_t_arr =
+              pg_ig_var->GetMutable<framework::LoDTensorArray>();
+          bool empty = true;
+          for (auto &each : *pg_ig_lod_t_arr) {
+            if (each.numel() != 0) {
+              empty = false;
+              break;
+            }
+          }
+          if (empty) {
+            LOG(WARNING) << pg_ig_names[param_id]
+                         << " is not found in cur_scope.";
+            continue;
+          }
+        }
+
         // // TODO(tonyyang-svail): Not sure we need the following
         // // If does not compute gradient of that variable inside rnn,
         // just
@@ -194,6 +224,11 @@ class WhileGradOp : public framework::OperatorBase {
         if (cur_scope_iter == step_scopes->rbegin()) {
           auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
           PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
+          PADDLE_ENFORCE(var->IsType<framework::LoDTensorArray>() ||
+                             var->IsType<LoDTensor>(),
+                         "Currently the type of var only can be LoDTensorArray "
+                         "or LoDTensor.");
+
           if (var->IsType<LoDTensor>()) {
             auto &inside_tensor = var->Get<framework::LoDTensor>();
             framework::AttributeMap attrs;
@@ -201,7 +236,7 @@ class WhileGradOp : public framework::OperatorBase {
             attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
             attrs["value"] = 0.0f;
 
-            auto var_name = pg_names[param_id];
+            auto var_name = pg_ig_names[param_id];
             auto zero_op = framework::OpRegistry::CreateOp(
                 "fill_constant", framework::VariableNameMap{},
                 {{"Out", {var_name}}}, attrs);
@@ -213,8 +248,8 @@ class WhileGradOp : public framework::OperatorBase {
         }
         auto new_inside_name = cur_scope.Rename(inside_grad_name);
         auto sum_op = framework::OpRegistry::CreateOp(
-            "sum", {{"X", {pg_names[param_id], new_inside_name}}},
-            {{"Out", {pg_names[param_id]}}},
+            "sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
+            {{"Out", {pg_ig_names[param_id]}}},
             framework::AttributeMap{{"use_mkldnn", {false}}});
         sum_op->Run(cur_scope, dev_place);
         cur_scope.Rename(new_inside_name, inside_grad_name);
@@ -281,6 +316,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
           parent_block->FindVarRecursive(input_name) != nullptr)) {
        continue;
      }
+
      output_grads.insert(input_name);
    }
    for (auto &output_name : op->OutputArgumentNames()) {
@@ -309,13 +345,13 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     auto p_names = op_desc.Input(kX);
-    auto pg_names = op_desc.Output(framework::GradVarName(kX));
+    auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
 
     for (size_t i = 0; i < p_names.size(); ++i) {
       auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
-      auto *g_var = block->FindVarRecursive(pg_names[i]);
+      auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
       if (g_var != nullptr) {  // Gradient could be @EMPTY@
-        VLOG(5) << "Setting " << pg_names[i] << " following " << p_names[i]
+        VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
                 << " type: " << p_var.GetType();
         g_var->SetType(p_var.GetType());
         g_var->SetDataType(p_var.GetDataType());
@@ -333,21 +369,21 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
     ctx->HasInputs(framework::GradVarName(kOutputs));
 
     auto p_names = ctx->Inputs(kX);
-    auto pg_names = ctx->Outputs(kXGRAD);
+    auto pg_ig_names = ctx->Outputs(kXGRAD);
     auto var_types = ctx->GetInputsVarType(kX);
     std::vector<std::string> names_to_set;
     std::vector<framework::DDim> dims_to_set;
     for (size_t i = 0; i < p_names.size(); ++i) {
-      if (pg_names[i] == framework::kEmptyVarName) {
+      if (pg_ig_names[i] == framework::kEmptyVarName) {
        continue;
       }
       auto dims = ctx->GetInputsElementDim(kX, i);
       if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
-        names_to_set.push_back(pg_names[i]);
+        names_to_set.push_back(pg_ig_names[i]);
         dims_to_set.push_back(dims);
       } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
         // not sure how to set the dim of LOD_TENSOR_ARRAY
-        names_to_set.push_back(pg_names[i]);
+        names_to_set.push_back(pg_ig_names[i]);
         dims_to_set.push_back(dims);
       }
     }