Support memory eager deletion on recurrent OP (#17710)
Tested PaddingRNN on a V100 GPU. Test configuration: large model, padding mode (the mode that uses RecurrentOp), one GPU.
GPU memory (MiB): 6414 (this PR) vs 6837 (without this PR).
Speed (steps/s): 10.28 (this PR) vs 9.89 (without this PR).
parent 0d8e6c9b8b
commit 89bc3fd841
paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
@@ -0,0 +1,76 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h"

#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace framework {
namespace ir {

using paddle::operators::OpVariant;
using paddle::operators::OpVariantSet;
using paddle::operators::OpAndGradOpPair;

void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
  // Find all recurrent_op and recurrent_grad_op in the graph
  // Note: the graph only contains ops of block 0
  std::unordered_map<size_t, OpAndGradOpPair> target_ops =
      DeviceIdToRecurrentAndRecurrentGradOp(*graph);

  for (auto &entry : target_ops) {
    // Prepare safe eager deletion per device because garbage collection
    // may differ across devices
    OpAndGradOpPair &op_pair = entry.second;
    PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
  }
}

// Returns a std::unordered_map mapping from the device id to the recurrent op
// and grad op pair
std::unordered_map<size_t, OpAndGradOpPair>
RecurrentOpEagerDeletionPass::DeviceIdToRecurrentAndRecurrentGradOp(
    const Graph &graph) const {
  std::unordered_map<size_t, OpAndGradOpPair> ret;
  std::vector<details::OpHandleBase *> all_ops =
      FilterByNodeWrapper<details::OpHandleBase>(graph);

  for (auto *op : all_ops) {
    auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
    if (compute_op == nullptr) continue;

    if (compute_op->Name() == "recurrent") {
      // GetScopeIdx() returns the device/place id
      ret[compute_op->GetScopeIdx()].first.emplace(compute_op->GetOp());
    } else if (compute_op->Name() == "recurrent_grad") {
      // GetScopeIdx() returns the device/place id
      ret[compute_op->GetScopeIdx()].second.emplace(compute_op->GetOp());
    }
  }
  return ret;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(recurrent_op_eager_deletion_pass,
              paddle::framework::ir::RecurrentOpEagerDeletionPass);
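For orientation, here is a minimal sketch (not part of this diff) of how a named pass such as this one is typically fetched and applied to the SSA graph. It assumes the ir::PassRegistry / Pass::Apply interfaces of this Paddle version; the wrapper function is hypothetical.

#include "paddle/fluid/framework/ir/pass.h"

// Hypothetical call site: run the pass on a built graph before execution.
// A caller in another translation unit would normally also declare
// USE_PASS(recurrent_op_eager_deletion_pass); to force registration.
void RunRecurrentEagerDeletionPass(paddle::framework::ir::Graph *graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "recurrent_op_eager_deletion_pass");
  // Groups recurrent/recurrent_grad ops per device and records the variables
  // they must skip during eager deletion (see ApplyImpl above).
  pass->Apply(graph);
}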
paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h
@@ -0,0 +1,43 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <unordered_map>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Pass that sets the skip-eager-deletion vars for recurrent ops
class RecurrentOpEagerDeletionPass : public Pass {
 protected:
  void ApplyImpl(Graph *graph) const override;

 private:
  // Returns a std::unordered_map mapping from the device id to the recurrent
  // op and grad op pair
  std::unordered_map<size_t, paddle::operators::OpAndGradOpPair>
  DeviceIdToRecurrentAndRecurrentGradOp(const Graph &graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/operators/controlflow/CMakeLists.txt
@@ -1,5 +1,7 @@
include(operators)
register_operators(DEPS naive_executor)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
cc_library(op_variant SRCS op_variant.cc DEPS operator proto_desc)
cc_library(recurrent_op_helper SRCS recurrent_op_helper.cc DEPS operator op_variant recurrent_op)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator op_variant)

file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
paddle/fluid/operators/controlflow/op_variant.cc
@@ -0,0 +1,72 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/controlflow/op_variant.h"

namespace paddle {
namespace operators {

struct InputsVisitor
    : public boost::static_visitor<const framework::VariableNameMap *> {
  template <typename OpType>
  const framework::VariableNameMap *operator()(const OpType *op) const {
    return &(op->Inputs());
  }
};

struct OutputsVisitor
    : public boost::static_visitor<const framework::VariableNameMap *> {
  template <typename OpType>
  const framework::VariableNameMap *operator()(const OpType *op) const {
    return &(op->Outputs());
  }
};

struct AttributeMapVisitor
    : public boost::static_visitor<const framework::AttributeMap *> {
  const framework::AttributeMap *operator()(const framework::OpDesc *op) const {
    return &(op->GetAttrMap());
  }

  const framework::AttributeMap *operator()(
      const framework::OperatorBase *op) const {
    return &(op->Attrs());
  }
};

struct RawPointerVisitor : public boost::static_visitor<const void *> {
  template <typename OpType>
  const void *operator()(const OpType *op) const {
    return op;
  }
};

const framework::VariableNameMap &OpVariant::Inputs() const {
  return *boost::apply_visitor(InputsVisitor(), op_);
}

const framework::VariableNameMap &OpVariant::Outputs() const {
  return *boost::apply_visitor(OutputsVisitor(), op_);
}

const framework::AttributeMap &OpVariant::Attrs() const {
  return *boost::apply_visitor(AttributeMapVisitor(), op_);
}

const void *OpVariant::RawPointer() const {
  return boost::apply_visitor(RawPointerVisitor(), op_);
}

}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/controlflow/op_variant.h
@@ -0,0 +1,69 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"

namespace paddle {
namespace operators {

// OpVariant is a wrapper class of OpDesc and OperatorBase pointers,
// so that both can be accessed through the same API.
class OpVariant {
 public:
  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT

  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT

  const framework::VariableNameMap &Inputs() const;

  const framework::VariableNameMap &Outputs() const;

  const framework::AttributeMap &Attrs() const;

  const void *RawPointer() const;

  template <typename AttrType>
  const AttrType &Attr(const std::string &name) const {
    auto &attrs = Attrs();
    auto it = attrs.find(name);
    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
    return boost::get<AttrType>(it->second);
  }

  bool operator==(const OpVariant &other) const {
    return RawPointer() == other.RawPointer();
  }

  int which() const { return static_cast<int>(op_.which()); }

  struct Hasher {
    size_t operator()(const OpVariant &op) const {
      return reinterpret_cast<size_t>(op.RawPointer());
    }
  };

 private:
  const boost::variant<const framework::OperatorBase *,
                       const framework::OpDesc *>
      op_;
};

}  // namespace operators
}  // namespace paddle
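As an aside (illustrative sketch, not part of this diff): the two implicit constructors let compile-time OpDesc pointers and runtime OperatorBase pointers share a single wrapper type, which is how the OpVariantSet used by the recurrent helper stores them. The function below and the attribute name "reverse" are only examples.

#include <unordered_set>

#include "paddle/fluid/operators/controlflow/op_variant.h"

namespace paddle {
namespace operators {

// Hypothetical helper: both wrappers expose the same read-only API,
// regardless of whether they hold an OpDesc or an OperatorBase.
void InspectOps(const framework::OpDesc *desc_op,
                const framework::OperatorBase *runtime_op) {
  OpVariant a(desc_op);     // wraps a compile-time op description
  OpVariant b(runtime_op);  // wraps a runtime operator instance

  // Attr<T>() dispatches through the visitors in op_variant.cc;
  // "reverse" is only an example attribute name.
  const framework::VariableNameMap &inputs = a.Inputs();
  bool reverse = b.Attr<bool>("reverse");

  // Equality and hashing use the raw pointer, so a set de-duplicates
  // wrappers built from the same underlying op.
  std::unordered_set<OpVariant, OpVariant::Hasher> ops{a, b};
  (void)inputs;
  (void)reverse;
  (void)ops;
}

}  // namespace operators
}  // namespace paddle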
File diff suppressed because it is too large
paddle/fluid/operators/controlflow/recurrent_op_helper.h
@@ -0,0 +1,52 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
#include "paddle/fluid/operators/recurrent_op.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace operators {

using OpVariantSet = std::unordered_set<OpVariant, OpVariant::Hasher>;
using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;

// Set the vars that the given recurrent and recurrent_grad ops should skip
// during eager deletion, so that eager deletion stays safe. The input pair
// contains all recurrent and recurrent_grad ops of block 0; the function
// finds all recurrent and recurrent_grad ops across blocks from it.
void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
    OpAndGradOpPair *op_pair);

// Same as above, but takes the ops of a block directly. The input block_id
// must be 0 and the caller may pass all ops of that block; the function
// finds all recurrent and recurrent_grad ops across blocks from them.
void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
    int block_id,
    const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
        &all_ops);

}  // namespace operators
}  // namespace paddle
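A sketch of how these two entry points are meant to be driven (illustrative only; the wrapper functions below are hypothetical): the graph-based pass in this PR builds an OpAndGradOpPair per device and calls the first overload, while an executor-style caller can hand all ops of block 0 to the second.

#include <utility>

#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"

namespace paddle {
namespace operators {

// Graph-based path (what the eager-deletion pass in this PR does): the ops
// have already been grouped into forward and grad sets for one device.
void PrepareFromCollectedOps(OpVariantSet fwd_ops, OpVariantSet grad_ops) {
  OpAndGradOpPair pair{std::move(fwd_ops), std::move(grad_ops)};
  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&pair);
}

// Executor-style path: hand over every op of block 0 and let the helper pick
// out the recurrent and recurrent_grad ops itself.
void PrepareFromBlock0(
    const std::vector<std::unique_ptr<framework::OperatorBase>> &block0_ops) {
  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(/*block_id=*/0,
                                                          block0_ops);
}

}  // namespace operators
}  // namespace paddle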
File diff suppressed because it is too large
paddle/fluid/operators/recurrent_op.h
@@ -0,0 +1,226 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// StepScopes manages the scopes inside an RNN.
// StepScopes::CurScope() gets the scope of the current time step.
// StepScopes::ExScope() gets the ex-scope, i.e. the scope of the previous
// time step.
// StepScopes::Next() moves to the next time step.
//
// If is_train == false, there are only two scopes for the RNN and only the
// forward pass is supported; otherwise len(scopes) == seq_len.
//
// If is_backward == true, the scopes are accessed in reverse order;
// otherwise they are accessed from begin to end.
class StepScopes {
 public:
  StepScopes(const platform::DeviceContext &dev_ctx,
             const framework::Scope &parent,
             std::vector<framework::Scope *> *scopes, bool is_train,
             size_t seq_len, bool is_backward = false);

  framework::Scope &CurScope();

  framework::Scope &ExScope();

  void Next();

 private:
  framework::Scope &GetScope(size_t scope_id) const;

  size_t counter_;
  std::vector<framework::Scope *> *scopes_;
  bool is_train_;
  bool is_backward_;
};

// Base class for RecurrentOp/RecurrentGradOp.
// Provides common protected functions for RecurrentOp/RecurrentGradOp.
class RecurrentBase : public framework::OperatorBase {
 public:
  static const char kInputs[];
  static const char kInitialStates[];
  static const char kParameters[];
  static const char kOutputs[];
  static const char kStepScopes[];
  static const char kHasStates[];
  static const char kExStates[];
  static const char kStates[];
  static const char kStepBlock[];
  static const char kReverse[];
  static const char kIsTrain[];
  static const char kSkipEagerDeletionVars[];
  static const char kInputGrads[];
  static const char kOutputGrads[];
  static const char kParamGrads[];
  static const char kInitStateGrads[];

  RecurrentBase(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs);

 protected:
  // Get the sequence length from the scope.
  // The sequence length is taken from the input tensor, whose dimensions
  // should be [SEQ_LEN, ..., ...]. The first dimension is SEQ_LEN; the
  // second could be the batch size or a nested sequence length.
  int64_t GetSequenceLength(const framework::Scope &scope) const;

  // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
  //                                   map(dst_scope.Var, dst_vars)):
  //   dst_tensor.ShareDataWith(src_tensor)
  static void LinkTensor(const framework::Scope &src_scope,
                         const std::vector<std::string> &src_vars,
                         framework::Scope *dst_scope,
                         const std::vector<std::string> &dst_vars);

  // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
  //                                   map(dst_scope.Var, dst_vars)):
  //   callback(src_tensor, &dst_tensor)
  template <typename Callback>
  static void LinkTensorWithCallback(const framework::Scope &src_scope,
                                     const std::vector<std::string> &src_vars,
                                     framework::Scope *dst_scope,
                                     const std::vector<std::string> &dst_vars,
                                     Callback callback,
                                     bool is_backward = false) {
    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
    for (size_t i = 0; i < dst_vars.size(); ++i) {
      VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
                   is_backward);
    }
  }

  // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
  //                                   map(dst_scope.FindVar, dst_vars)):
  //   callback(src_tensor, &dst_tensor)
  template <typename Callback>
  static void LinkTensorWithCallback(const framework::Scope &src_scope,
                                     const std::vector<std::string> &src_vars,
                                     const framework::Scope &dst_scope,
                                     const std::vector<std::string> &dst_vars,
                                     Callback callback,
                                     bool is_backward = false) {
    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
    for (size_t i = 0; i < dst_vars.size(); ++i) {
      VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
                   is_backward);
    }
  }

  // (seq_len, shape) -> return [seq_len] + list(shape)
  static framework::DDim PrependDims(size_t seq_len,
                                     const framework::DDim &src);

 private:
  template <typename Callback>
  static void AccessTensor(const framework::Scope &src_scope,
                           const std::string &src_var_name,
                           framework::Scope *dst_scope,
                           const std::string &dst_var_name, Callback callback,
                           bool is_backward = false) {
    auto *src_var = src_scope.FindVar(src_var_name);
    if (is_backward && src_var == nullptr) {
      return;
    }
    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
    auto &src_tensor = src_var->Get<framework::LoDTensor>();

    auto *dst_var = dst_scope->Var(dst_var_name);
    auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
    callback(src_tensor, dst_tensor);
  }

  template <typename Callback>
  static void AccessTensor(const framework::Scope &src_scope,
                           const std::string &src_var_name,
                           const framework::Scope &dst_scope,
                           const std::string &dst_var_name, Callback callback,
                           bool is_backward = false) {
    auto *dst_var = dst_scope.FindVar(dst_var_name);
    if (is_backward && dst_var == nullptr) {
      return;
    }
    auto *src_var = src_scope.FindVar(src_var_name);
    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
    auto &src_tensor = src_var->Get<framework::LoDTensor>();
    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
    auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
    callback(src_tensor, dst_tensor);
  }
};

class RecurrentOp : public RecurrentBase {
 public:
  RecurrentOp(const std::string &type, const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs);

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override;

 private:
  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
                              const framework::Scope &scope,
                              size_t seq_len) const;
};

class RecurrentGradOp : public RecurrentBase {
 public:
  RecurrentGradOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs);

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override;

  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
                              const framework::Scope &scope,
                              size_t seq_len) const;

  std::unordered_set<std::string> List2Set(
      const std::vector<std::string> &list) const;

  std::unordered_set<std::string> LocalVarNames(
      const framework::Scope &scope) const;

  static std::vector<std::string> GradVarLists(
      const std::vector<std::string> &var_names);
};

}  // namespace operators
}  // namespace paddle
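The pseudo-Python comments above describe the callback contract of LinkTensorWithCallback. A simplified fragment follows (a hypothetical caller inside a RecurrentBase subclass, with step_scope and parent_scope as assumed locals; it is not the op's actual implementation, which uses more involved callbacks):

// Fragment from a hypothetical member function of a RecurrentBase subclass:
// expose the current step scope's outputs in the parent scope by sharing
// storage rather than copying.
LinkTensorWithCallback(
    step_scope, Outputs(kOutputs), &parent_scope, Outputs(kOutputs),
    [](const framework::LoDTensor &src, framework::LoDTensor *dst) {
      dst->ShareDataWith(src);  // dst aliases src's memory; no copy
    });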
File diff suppressed because it is too large
File diff suppressed because it is too large