@@ -157,11 +157,13 @@ class RecurrentBase : public framework::OperatorBase {
                                      const std::vector<std::string> &src_vars,
                                      framework::Scope *dst_scope,
                                      const std::vector<std::string> &dst_vars,
-                                     Callback callback) {
+                                     Callback callback,
+                                     bool is_backward = false) {
     PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
+      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
+                   is_backward);
     }
   }
 
@@ -173,11 +175,13 @@ class RecurrentBase : public framework::OperatorBase {
                                      const std::vector<std::string> &src_vars,
                                      const framework::Scope &dst_scope,
                                      const std::vector<std::string> &dst_vars,
-                                     Callback callback) {
+                                     Callback callback,
+                                     bool is_backward = false) {
     PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
+      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
+                   is_backward);
     }
   }
 
@@ -194,9 +198,13 @@ class RecurrentBase : public framework::OperatorBase {
   static void AccessTensor(const framework::Scope &src_scope,
                            const std::string &src_var_name,
                            framework::Scope *dst_scope,
-                           const std::string &dst_var_name, Callback callback) {
+                           const std::string &dst_var_name, Callback callback,
+                           bool is_backward = false) {
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr);
+    if (is_backward && src_var == nullptr) {
+      return;
+    }
+    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
 
     auto *dst_var = dst_scope->Var(dst_var_name);
@@ -208,12 +216,16 @@ class RecurrentBase : public framework::OperatorBase {
   static void AccessTensor(const framework::Scope &src_scope,
                            const std::string &src_var_name,
                            const framework::Scope &dst_scope,
-                           const std::string &dst_var_name, Callback callback) {
+                           const std::string &dst_var_name, Callback callback,
+                           bool is_backward = false) {
+    auto *dst_var = dst_scope.FindVar(dst_var_name);
+    if (is_backward && dst_var == nullptr) {
+      return;
+    }
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr);
+    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    auto *dst_var = dst_scope.FindVar(dst_var_name);
-    PADDLE_ENFORCE(dst_var != nullptr);
+    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
     auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
     callback(src_tensor, dst_tensor);
   }
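
The two AccessTensor overloads above are where the new is_backward flag actually takes effect: when the flag is set and the variable to be linked is absent (the source variable in the first overload, the destination variable in the second), the function returns early instead of tripping PADDLE_ENFORCE. Below is a minimal, self-contained sketch of that skip-if-missing pattern; the Scope and Tensor types, the variable names, and the assert are simplified stand-ins for illustration, not the Paddle framework API.

    #include <cassert>
    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    // Toy stand-ins for framework::Scope and framework::LoDTensor.
    using Tensor = std::string;
    using Scope = std::map<std::string, Tensor>;

    // Link src_scope[name] into (*dst_scope)[name] through `callback`.
    // With is_backward == true, a missing source variable is skipped
    // silently; with the default false it stays a hard error.
    void AccessTensor(const Scope &src_scope, Scope *dst_scope,
                      const std::string &name,
                      const std::function<void(const Tensor &, Tensor *)> &callback,
                      bool is_backward = false) {
      auto it = src_scope.find(name);
      if (is_backward && it == src_scope.end()) {
        return;  // no gradient was produced for this variable, nothing to link
      }
      assert(it != src_scope.end() && "variable is not found");
      callback(it->second, &(*dst_scope)[name]);
    }

    int main() {
      Scope grads{{"x@GRAD", "dL/dx"}};
      Scope linked;
      auto copy = [](const Tensor &src, Tensor *dst) { *dst = src; };
      AccessTensor(grads, &linked, "x@GRAD", copy, /*is_backward=*/true);  // linked
      AccessTensor(grads, &linked, "y@GRAD", copy, /*is_backward=*/true);  // skipped
      std::cout << "linked " << linked.size() << " variable(s)\n";  // prints: linked 1 variable(s)
      return 0;
    }

In this sketch the second call is skipped because "y@GRAD" was never produced, which is exactly the situation the backward pass has to tolerate.
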
@@ -345,7 +357,8 @@ class RecurrentGradOp : public RecurrentBase {
             auto dims = framework::vectorize(inside->dims());
             dims.erase(dims.begin());
             inside->Resize(framework::make_ddim(dims));
-          });
+          },
+          true /*is_backward*/);
       auto og_set = List2Set(Inputs(kOutputGrads));
 
       if (VLOG_IS_ON(10)) {
@@ -454,7 +467,8 @@ class RecurrentGradOp : public RecurrentBase {
 
             auto dst = outside->Slice(seq_offset, seq_offset + 1);
             framework::TensorCopy(inside, place, dev_ctx, &dst);
-          });
+          },
+          true /*is_backward*/);
       VLOG(5) << "Link outside gradient finished ";
 
       if (step_id + 1 == seq_len) {  // at_end
@@ -467,7 +481,8 @@ class RecurrentGradOp : public RecurrentBase {
               outside->Resize(inside.dims());
               outside->mutable_data(place, inside.type());
               framework::TensorCopy(inside, place, dev_ctx, outside);
-            });
+            },
+            true /*is_backward*/);
         VLOG(5) << "Link initialize state gradient finished ";
       }
       scopes.Next();
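
Note that only these gradient-pass call sites pass true /*is_backward*/; forward-pass callers keep the default of false, so a missing variable there still fails the PADDLE_ENFORCE as before. The relaxation is needed because, as the NOTE added in the shape-inference hunk below points out, some of kInputs may simply have no gradient.
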
@@ -608,10 +623,8 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
     std::vector<std::string> input{kInputs, kInitialStates};
     std::vector<std::string> output{kOutputs};
     for (auto &s : input) {
+      // NOTE(zcd): In some case, some of kInputs doesn't have gradient.
       PADDLE_ENFORCE(ctx->HasInputs(s));
-      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
-                     "Cannot find the gradient variable %s",
-                     framework::GradVarName(s));
     }
     for (auto &s : output) {
       PADDLE_ENFORCE(ctx->HasInputs(s));