Imperative tracer refactoring (#22457)

* refine grad maker, test=develop

* refactor tracer stage 1, test=develop

* merge develop to resolve conflicts for the third time, test=develop
revert-22710-feature/integrated_ps_api
Zeng Jinle, 5 years ago, committed by GitHub
parent 08a772cb46
commit d33c4343e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -242,10 +242,11 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
"GradOpBaseMaker of %s has been registered", op_type));
info->dygraph_grad_op_maker_ = [](
const imperative::OpBase* fw_op_base,
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out) {
T maker(fw_op_base, var_base_map_in, var_base_map_out);
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs) {
T maker(type, var_base_map_in, var_base_map_out, attrs);
return maker();
};
}

@ -28,6 +28,26 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace details {
template <typename T>
struct GradOpPtrTrait {};
template <>
struct GradOpPtrTrait<OpDesc> {
using Type = OpDesc*;
};
template <>
struct GradOpPtrTrait<imperative::OpBase> {
using Type = imperative::TracedGradOp*;
};
} // namespace details
template <typename T>
using GradOpPtr = typename details::GradOpPtrTrait<T>::Type;
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op. After it is called (through operator()), the pairs of
@ -47,6 +67,10 @@ class GradOpDescMakerBase {
grad_to_var_(grad_to_var),
grad_block_(grad_block) {}
static std::unique_ptr<OpDesc> CreateOp() {
return std::unique_ptr<OpDesc>(new OpDesc());
}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDesc>> operator()() const = 0;
@ -100,7 +124,13 @@ class GradOpDescMakerBase {
return ret_val;
}
std::vector<std::string> Empty() const { return {}; }
static std::vector<std::string> EmptyInput() { return {}; }
static std::vector<std::string> EmptyOutput() { return {}; }
static std::vector<std::string> EmptyInputGrad() { return {}; }
static std::vector<std::string> EmptyOutputGrad() { return {}; }
std::vector<std::string> InputNames() const {
return this->fwd_op_.InputNames();
@ -155,16 +185,7 @@ class GradOpDescMakerBase {
};
template <typename T>
class SingleGradOpMaker {
public:
std::vector<std::unique_ptr<T>> operator()() const {
PADDLE_ENFORCE(false, "should not call this function");
return {};
}
protected:
virtual std::unique_ptr<T> Apply() const = 0;
};
class SingleGradOpMaker {};
template <>
class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
@ -173,12 +194,13 @@ class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
std::vector<std::unique_ptr<OpDesc>> operator()() const {
std::vector<std::unique_ptr<OpDesc>> retv;
retv.emplace_back(this->Apply());
retv.emplace_back(new OpDesc());
this->Apply(retv.front().get());
return retv;
}
protected:
virtual std::unique_ptr<OpDesc> Apply() const = 0;
virtual void Apply(GradOpPtr<OpDesc> op) const = 0;
};
template <>
@ -187,16 +209,18 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
public:
std::vector<std::unique_ptr<imperative::OpBase>> operator()() const {
std::vector<std::unique_ptr<imperative::OpBase>> retv;
retv.emplace_back(this->Apply());
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const {
std::vector<std::shared_ptr<imperative::OpBase>> retv{
std::make_shared<imperative::OpBase>()};
{
imperative::TracedGradOp grad_op(retv.front());
this->Apply(&grad_op);
}
return retv;
}
protected:
virtual std::unique_ptr<imperative::OpBase> Apply() const = 0;
virtual void Apply(GradOpPtr<imperative::OpBase> op) const = 0;
};
template <typename T, bool DropEmptyIG = true>
@ -205,8 +229,7 @@ class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
using SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const final {
auto* grad = new T();
void Apply(GradOpPtr<T> grad) const final {
grad->SetType(this->ForwardOpType() + "_grad");
for (auto& input_param : this->InputNames()) {
@ -221,19 +244,11 @@ class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
}
grad->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad);
}
};
template <typename T>
class EmptyGradOpMaker {
public:
virtual std::vector<std::unique_ptr<T>> operator()()
const final { /* NOLINT */
return {};
}
};
class EmptyGradOpMaker {};
template <>
class EmptyGradOpMaker<OpDesc> final : public GradOpDescMakerBase {
@ -247,10 +262,18 @@ class EmptyGradOpMaker<imperative::OpBase> final
: public imperative::GradOpBaseMakerBase {
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::unique_ptr<imperative::OpBase>> operator()() const final {
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const final {
return {};
}
};
} // namespace framework
namespace operators {
template <typename T>
using GradOpPtr = framework::GradOpPtr<T>;
} // namespace operators
} // namespace paddle

@ -45,8 +45,9 @@ bool StaticGraphInferNoNeedBufferVarsContext::HasOutput(
}
DyGraphInferNoNeedBufferVarsContext::DyGraphInferNoNeedBufferVarsContext(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs, const AttributeMap &attrs)
const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs)
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
bool DyGraphInferNoNeedBufferVarsContext::HasOutput(

@ -56,15 +56,16 @@ class StaticGraphInferNoNeedBufferVarsContext final
class DyGraphInferNoNeedBufferVarsContext final
: public InferNoNeedBufferVarsContext {
public:
DyGraphInferNoNeedBufferVarsContext(const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const AttributeMap &attr);
DyGraphInferNoNeedBufferVarsContext(
const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs);
bool HasOutput(const std::string &slot) const final;
private:
const imperative::NameVarBaseMap &inputs_;
const imperative::NameVarBaseMap &outputs_;
const imperative::NameVarMap<imperative::VariableWrapper> &inputs_;
const imperative::NameVarMap<imperative::VariableWrapper> &outputs_;
};
class NoNeedBufferVarsInference {
@ -106,8 +107,8 @@ class InferNoNeedBufferVarsFN {
}
inline const std::unordered_set<std::string> &operator()(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_);
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);

@ -35,10 +35,10 @@ TEST(test_no_need_buffer_vars_inference, test_static_graph) {
TEST(test_no_need_buffer_vars_inference, test_dygraph) {
AttributeMap attrs{{"is_test", true}};
imperative::NameVarBaseMap inputs;
imperative::NameVarBaseMap outputs;
imperative::NameVarMap<imperative::VariableWrapper> inputs;
imperative::NameVarMap<imperative::VariableWrapper> outputs;
outputs["Out"].emplace_back(nullptr);
outputs["Out"].emplace_back(new imperative::VarBase("tmp_0"));
outputs["Out"].emplace_back(new imperative::VariableWrapper("tmp_0"));
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);

@ -1310,8 +1310,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
for (auto& input : ctx.InNameList()) {
ParseInputDataType(ctx, input, &data_type);
}
PADDLE_ENFORCE_NE(data_type, dafault_data_type,
"DataType should be indicated by input Variable.");
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
platform::errors::NotFound(
"DataType should be indicated by input Variable at %s.", Type()));
return data_type;
}

@ -56,10 +56,11 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const std::vector<BlockDesc*>& grad_block)>;
using DygraphGradOpMakerFN =
std::function<std::vector<std::unique_ptr<imperative::OpBase>>(
const imperative::OpBase* fw_op_base,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out)>;
std::function<std::vector<std::shared_ptr<imperative::OpBase>>(
const std::string& /*op_type*/,
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/)>;
using InferVarTypeFN =
std::function<void(framework::InferVarTypeContext* /*context*/)>;

File diff suppressed because it is too large Load Diff

@ -30,16 +30,9 @@
namespace paddle {
namespace imperative {
void Engine::RunOp(paddle::imperative::OpBase* op,
const paddle::imperative::NameVarBaseMap& ins,
const paddle::imperative::NameVarBaseMap& outs,
const paddle::platform::Place& place) {
op->Run(ins, outs);
}
void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
backward_strategy_ = strategy;
const std::vector<OpBase*> ops = var->GradVarBase()->GradOps();
const auto& ops = var->GradVarBase()->GradOps();
var->ClearGradOps();
if (ops.empty() || var->OverridedStopGradient()) {
@ -59,7 +52,9 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
return;
}
}
init_ops_ = ops;
var->GradVarBase()->ClearGradOps();
VLOG(3) << "start backward";
PADDLE_ENFORCE_EQ(var->HasGradVar(), true,
@ -71,7 +66,6 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
<< " as stop_gradient false";
var->GradVarBase()->InnerSetOverridedStopGradient(false);
var->GradVarBase()->SetGradGenerated(true);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
grad_var->Resize(fwd_var.dims());
grad_var->mutable_data(fwd_var.place(), fwd_var.type());
@ -81,35 +75,29 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
void BasicEngine::CheckBackwardInputs(OpBase* op) {
for (auto& pair : op->GetInsMap()) {
for (auto& var : pair.second) {
if (var && IsGrad(var.get())) {
// if grad var has OverridedStopGradient skip this Op
if (!var->GradGenerated()) {
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(op->place());
auto* tensor = var->MutableVar()->GetMutable<framework::LoDTensor>();
tensor->mutable_data(op->place(), var->DataType());
operators::math::set_constant(*dev_ctx, tensor, 0.0);
} else {
continue;
}
if (!var || op->IsAllowedEmptyVar(var.get())) {
continue;
}
}
}
}
void BasicEngine::SetBackwardOutputs(paddle::imperative::OpBase* op) {
for (auto& pair : op->GetOutsMap()) {
for (auto& var : pair.second) {
if (var) {
// Set Backward outputs's generate_grad as true
var->SetGradGenerated(true);
VLOG(6) << "Set backward output: " << var->Name()
<< "'s SetGeneratedGrad as True";
auto* inner_var = var->MutableVar();
framework::Tensor* tensor = nullptr;
if (!inner_var->IsInitialized() ||
inner_var->IsType<framework::LoDTensor>()) {
tensor = inner_var->GetMutable<framework::LoDTensor>();
}
if (tensor && !tensor->IsInitialized()) {
// if grad var has OverridedStopGradient skip this Op
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(op->place());
tensor->mutable_data(op->place(), var->DataType());
operators::math::set_constant(*dev_ctx, tensor, 0.0);
}
}
}
}
void BasicEngine::PrepareGradAccumulators(OpBase* op) {
for (const auto& pair : op->GetOutsMap()) {
for (const auto& var : pair.second) {
@ -140,50 +128,63 @@ void BasicEngine::PrepareDeps() {
std::queue<OpBase*> q;
std::unordered_set<OpBase*> visited;
for (const auto& init_op : init_ops_) {
q.push(init_op);
visited.insert(init_op);
q.push(init_op.get());
visited.insert(init_op.get());
}
while (!q.empty()) {
auto* cur_op = q.front();
q.pop();
VLOG(3) << "Checking grads of op " << cur_op->Type();
SetBackwardOutputs(cur_op);
PADDLE_ENFORCE_NE(
cur_op->GetInsMap().empty() && cur_op->GetOutsMap().empty(), true,
platform::errors::NotFound(
"Inputs and outputs of %s do not exist. "
"This may be because you call \"backward()\" twice for the same "
"subgraph. Please try to call \"stop_gradient = True\" or "
"\"detach()\" if you use some same vars between two \"backward()\" "
"calls.",
cur_op->Type()));
PrepareGradAccumulators(cur_op);
auto& grad_pending_ops = cur_op->GradPendingOps();
for (auto* grad_pending_op : grad_pending_ops) {
const auto& grad_pending_ops = cur_op->GradPendingOps();
for (auto& grad_pending_op : grad_pending_ops) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
++op_deps_[grad_pending_op];
if (visited.count(grad_pending_op) == 0) {
visited.insert(grad_pending_op);
q.push(grad_pending_op);
++op_deps_[grad_pending_op.get()];
if (visited.count(grad_pending_op.get()) == 0) {
visited.insert(grad_pending_op.get());
q.push(grad_pending_op.get());
}
}
}
}
void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VarBase> src,
VarBase* dst) {
void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
VariableWrapper* dst) {
auto iter = accumulators_.find(dst);
PADDLE_ENFORCE_EQ(iter != accumulators_.end(), true,
"Cannot find gradient of variable %s", dst->Name());
iter->second->Add(std::move(src), op->id());
}
void BasicEngine::Execute() {
PrepareDeps();
// Start execute Computation graph
std::queue<OpBase*> q;
std::queue<std::shared_ptr<OpBase>> q;
for (const auto& init_op : init_ops_) {
q.push(init_op);
q.push(std::move(init_op));
}
size_t op_num = 0;
while (!q.empty()) {
OpBase* cur_op = q.front();
auto shared_cur_op = std::move(q.front());
q.pop();
auto* cur_op = shared_cur_op.get();
++op_num;
// CheckBackWardInput
CheckBackwardInputs(cur_op);
@ -191,26 +192,28 @@ void BasicEngine::Execute() {
auto& bwd_ins = cur_op->GetInsMap();
auto& bwd_outs = cur_op->GetOutsMap();
NameVarBaseMap tmp_outs(bwd_outs);
NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
// 1. construct the output map 2. replace the element in the map
// A var may be coresponding to several grad var in one op
for (auto it = tmp_outs.begin(); it != tmp_outs.end(); ++it) {
for (size_t i = 0; i < it->second.size(); ++i) {
auto tmp_var =
std::make_shared<VarBase>(false, "Gtmp@"); // Do not need grad
std::make_shared<VariableWrapper>("Gtmp@"); // Do not need grad
auto var = it->second[i];
it->second[i] = tmp_var;
if (var) {
need_accu_var_list_.emplace_back(
make_pair(var.get(), std::move(tmp_var)));
var->ClearGradOps();
need_accu_var_list_.emplace_back(var.get(), std::move(tmp_var));
}
}
}
VLOG(3) << "Start to execute grad op " << cur_op->Type();
RunOp(cur_op, bwd_ins, tmp_outs, cur_op->place());
{
VLOG(3) << "Start to execute grad op " << cur_op->Type();
OpBase::Run(cur_op->InnerOp(), bwd_ins, tmp_outs, cur_op->Attrs(),
cur_op->place());
}
// Step 2: Sum Gradient
if (need_accu_var_list_.size() > 0) {
@ -223,9 +226,9 @@ void BasicEngine::Execute() {
// Step 3: Collect ready ops
for (auto* grad_pending_op : cur_op->GradPendingOps()) {
for (auto& grad_pending_op : cur_op->GradPendingOps()) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
auto iter = op_deps_.find(grad_pending_op);
auto iter = op_deps_.find(grad_pending_op.get());
if (iter == op_deps_.end()) {
continue;
}
@ -242,10 +245,11 @@ void BasicEngine::Execute() {
// Step 4: Delete op to collect unused variables
VLOG(3) << "Remove op after op " << cur_op->Type() << " runs";
RemoveOp(cur_op);
cur_op->ClearBackwardTrace();
}
VLOG(3) << "Clean properties of BasicEngine";
CleanEngine();
Clear();
VLOG(1) << "Backward op number: " << op_num;
}
} // namespace imperative
} // namespace paddle

@ -37,49 +37,12 @@ class Engine {
virtual ~Engine() = default;
virtual void Execute() = 0;
virtual void Init(VarBase* var, const detail::BackwardStrategy& strategy) = 0;
virtual void RunOp(imperative::OpBase* op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const platform::Place& place);
virtual void RemoveOp(OpBase* op) {
PADDLE_ENFORCE_NOT_NULL(op, "Cannot remove null op");
auto iter = grad_ops_.find(op);
PADDLE_ENFORCE_EQ(iter != grad_ops_.end(), true, "Op is not inside tracer");
grad_ops_.erase(iter);
}
void InsertOp(OpBase* op, std::shared_ptr<OpBase> op_shared) {
grad_ops_[op] = std::move(op_shared);
}
const std::unordered_set<VarBase*>& GradVars() const { return grad_vars_; }
const std::unordered_map<OpBase*, std::shared_ptr<OpBase>>& GradOps() const {
return grad_ops_;
}
void InsertGradVar(VarBase* grad) { grad_vars_.emplace(grad); }
bool IsGrad(VarBase* var) { return grad_vars_.count(var) > 0; }
void Clear() {
grad_ops_.clear();
grad_vars_.clear();
}
private:
std::unordered_map<OpBase*, std::shared_ptr<OpBase>>
grad_ops_; // opBase for remove - grad_op
std::unordered_set<VarBase*> grad_vars_;
};
class BasicEngine : public Engine {
public:
BasicEngine() = default;
void Init(VarBase* var, const detail::BackwardStrategy& strategy) override;
~BasicEngine() override = default;
void Execute() override;
private:
@ -87,28 +50,26 @@ class BasicEngine : public Engine {
void CheckBackwardInputs(OpBase* op);
void SetBackwardOutputs(OpBase* op);
void PrepareGradAccumulators(OpBase* op);
void SumGradient(OpBase* op, std::shared_ptr<VarBase> src, VarBase* dst);
void SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
VariableWrapper* dst);
// TODO(jiabin): maybe we can optimize the performance of engine by cache the
// result
void CleanEngine() {
void Clear() {
init_ops_.clear();
op_deps_.clear();
accumulators_.clear();
Clear();
}
std::vector<OpBase*> init_ops_;
std::vector<std::shared_ptr<OpBase>> init_ops_;
detail::BackwardStrategy backward_strategy_;
std::unordered_map<OpBase*, size_t> op_deps_;
std::unordered_map<VarBase*, std::unique_ptr<GradientAccumulator>>
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
std::vector<std::pair<VarBase*, std::shared_ptr<VarBase>>>
std::vector<std::pair<VariableWrapper*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_;
};

@ -144,8 +144,8 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
// Note(chenweihang): when two selected rows need to be added,
// adding one to another is not equal to merging two selected rows
// to one then add it to a empty selected rows, the after is correct
std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
const framework::Variable& src2) {
std::shared_ptr<VariableWrapper> SelectedRowsMerge(
const framework::Variable& src1, const framework::Variable& src2) {
auto& src_selected_rows1 = src1.Get<framework::SelectedRows>();
auto& src_selected_rows2 = src2.Get<framework::SelectedRows>();
auto place = src_selected_rows1.value().place();
@ -155,7 +155,7 @@ std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
std::vector<const framework::SelectedRows*> src_selected_rows;
src_selected_rows.emplace_back(&src_selected_rows1);
src_selected_rows.emplace_back(&src_selected_rows2);
auto dst_var = std::make_shared<VarBase>(false, "Temp");
auto dst_var = std::make_shared<VariableWrapper>("Temp");
auto* dst_selected_rows =
dst_var->MutableVar()->GetMutable<framework::SelectedRows>();
@ -188,7 +188,8 @@ std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
framework::DataTypeToString(data_type)));
}
void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
VariableWrapper* var_) {
auto& src = var->Var();
auto* dst = var_->MutableVar();
if (dst->IsType<framework::LoDTensor>()) {
@ -208,7 +209,7 @@ void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
*dst = std::move(*(var->MutableVar()));
var_->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (src.IsType<framework::SelectedRows>()) {
std::shared_ptr<VarBase> temp = SelectedRowsMerge(src, *dst);
auto temp = SelectedRowsMerge(src, *dst);
*dst = std::move(*(temp->MutableVar()));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
@ -218,7 +219,8 @@ void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
}
}
platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
static platform::Place GetPlaceOfVar(
const std::shared_ptr<VariableWrapper>& var) {
platform::Place place;
if (var->Var().IsType<framework::LoDTensor>()) {
place = var->Var().Get<framework::LoDTensor>().place();
@ -231,10 +233,10 @@ platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
return place;
}
void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
size_t trace_id) {
auto* dst_var = var_->MutableVar();
platform::Place place = GetPlaceOfVarBase(var);
platform::Place place = GetPlaceOfVar(var);
if (!var_->OverridedStopGradient()) {
VLOG(3) << "Sum Gradient for: " << var_->Name();
if (cur_cnt_ == 0) {
@ -243,7 +245,7 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
}
*dst_var = std::move(*(var->MutableVar()));
} else {
VarBaseAdd(var, var_);
VariableWrapperAdd(var, var_);
}
} else {
if (!var_->Var().IsInitialized() ||
@ -268,10 +270,10 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
++cur_cnt_;
}
void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
size_t trace_id) {
auto* dst_var = var_->MutableVar();
platform::Place place = GetPlaceOfVarBase(var);
platform::Place place = GetPlaceOfVar(var);
if (!var_->OverridedStopGradient()) {
if (ref_cnt_ == 1) {
if (var->Var().IsType<framework::SelectedRows>()) {
@ -291,11 +293,12 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
return;
}
std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
[](const std::pair<std::shared_ptr<VarBase>, size_t>& p1,
const std::pair<std::shared_ptr<VarBase>, size_t>& p2) {
return p1.second > p2.second;
});
std::sort(
tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
[](const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p1,
const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p2) {
return p1.second > p2.second;
});
#ifdef PADDLE_WITH_CUDA
if (paddle::platform::is_gpu_place(place)) {
@ -310,7 +313,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
*dst_var = std::move(*(tmp_grad_vars_[i].first->MutableVar()));
} else {
VarBaseAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
}
}
@ -321,7 +324,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
}
if (tmp_grad_vars_[i].first->Var().IsType<framework::LoDTensor>()) {
VarBaseAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
}
} else {
@ -333,7 +336,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
}
for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
VarBaseAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
#ifdef PADDLE_WITH_CUDA
}

@ -24,9 +24,9 @@ namespace imperative {
class GradientAccumulator {
public:
explicit GradientAccumulator(VarBase* var) : var_(var) {}
explicit GradientAccumulator(VariableWrapper* var) : var_(var) {}
virtual void Add(std::shared_ptr<VarBase> var, size_t trace_id) = 0;
virtual void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) = 0;
virtual ~GradientAccumulator() = default;
@ -35,7 +35,7 @@ class GradientAccumulator {
inline size_t RefCnt() const { return ref_cnt_; }
protected:
VarBase* var_;
VariableWrapper* var_;
size_t ref_cnt_{0};
};
@ -43,7 +43,7 @@ class EagerGradientAccumulator : public GradientAccumulator {
public:
using GradientAccumulator::GradientAccumulator;
void Add(std::shared_ptr<VarBase> var, size_t trace_id) override;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;
private:
size_t cur_cnt_{0};
@ -53,10 +53,11 @@ class SortedGradientAccumulator : public GradientAccumulator {
public:
using GradientAccumulator::GradientAccumulator;
void Add(std::shared_ptr<VarBase> var, size_t trace_id) override;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;
private:
std::vector<std::pair<std::shared_ptr<VarBase>, size_t>> tmp_grad_vars_;
std::vector<std::pair<std::shared_ptr<VariableWrapper>, size_t>>
tmp_grad_vars_;
};
} // namespace imperative

@ -113,9 +113,10 @@ static framework::RuntimeContext PrepareRuntimeContext(
return framework::RuntimeContext(std::move(inputs), std::move(outputs));
}
template <typename VarType>
static std::string DebugString(
const std::string& name,
const std::vector<std::shared_ptr<VarBase>>& vars) {
const std::vector<std::shared_ptr<VarType>>& vars) {
std::stringstream ss;
ss << name << "{";
@ -127,7 +128,7 @@ static std::string DebugString(
continue;
}
ss << vars[i]->Name() << "[";
auto& var = vars[i]->Var();
const framework::Variable& var = vars[i]->Var();
if (!var.IsInitialized()) {
ss << "NOT_INITED_VAR";
} else if (var.IsType<framework::LoDTensor>()) {
@ -167,9 +168,10 @@ static std::string DebugString(
return ss.str();
}
std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs) {
template <typename VarType>
static std::string LayerDebugStringImpl(const std::string& op_type,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs) {
std::stringstream ss;
ss << "Op(" << op_type << "): ";
@ -192,28 +194,30 @@ std::string LayerDebugString(const std::string& op_type,
return ss.str();
}
void VarBase::AddGradOps(const std::weak_ptr<OpBase>& op) {
if (op.lock() == nullptr) {
return;
}
for (const auto& cur_op : grad_ops_) {
if (cur_op.lock() == op.lock()) {
return;
}
}
grad_ops_.emplace_back(op);
std::string LayerDebugString(const std::string& op_type,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs) {
return LayerDebugStringImpl<VarBase>(op_type, ins, outs);
}
std::string LayerDebugString(const std::string& op_type,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs) {
return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
}
void VarBase::ClearGradient() {
if (grad_var_) {
if (grad_var_->var_.IsType<framework::SelectedRows>()) {
auto* grad_t = grad_var_->var_.GetMutable<framework::SelectedRows>();
if (grad_var_->Var().IsType<framework::SelectedRows>()) {
auto* grad_t =
grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
if (grad_t->mutable_value()->IsInitialized()) {
grad_t->mutable_rows()->clear();
grad_t->mutable_value()->clear();
}
} else {
auto* grad_t = grad_var_->var_.GetMutable<framework::LoDTensor>();
auto* grad_t =
grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
if (grad_t->IsInitialized()) {
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(grad_t->place());
@ -226,19 +230,20 @@ void VarBase::ClearGradient() {
std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
const bool blocking) const {
PADDLE_ENFORCE_EQ(
var_.IsInitialized() && (var_.IsType<framework::LoDTensor>() ||
var_.IsType<framework::SelectedRows>()),
Var().IsInitialized() && (Var().IsType<framework::LoDTensor>() ||
Var().IsType<framework::SelectedRows>()),
true, platform::errors::InvalidArgument(
"Variable is not initialized or Variable's type is not "
"LoDTensor or SelectedRows when getting numpy tensor"));
if (var_.IsType<framework::LoDTensor>()) {
auto& src_tensor = var_.Get<framework::LoDTensor>();
if (Var().IsType<framework::LoDTensor>()) {
auto& src_tensor = Var().Get<framework::LoDTensor>();
// TODO(Jiabin): change this after move unique_name generator to CXX
auto new_var = std::make_shared<VarBase>(
true, Name() + std::to_string(copied_counter_++));
auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
auto* dst_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
new_var->SetPersistable(Persistable());
new_var->SetDataType(DataType());
@ -257,12 +262,12 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
}
return new_var;
} else {
auto& src_selected_rows = var_.Get<framework::SelectedRows>();
auto& src_selected_rows = Var().Get<framework::SelectedRows>();
auto new_var = std::make_shared<VarBase>(
false, "Itmp" + std::to_string(copied_counter_++));
new_var->SetType(framework::proto::VarType::SELECTED_ROWS);
auto* dst_selected_rows =
new_var->var_.GetMutable<framework::SelectedRows>();
new_var->MutableVar()->GetMutable<framework::SelectedRows>();
framework::TensorCopy(src_selected_rows.value(), dst_place,
dst_selected_rows->mutable_value());
@ -281,39 +286,32 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
return new_var;
}
}
// create OpBase from optype
OpBase::OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place)
: id_(id), place_(place), attrs_(attrs) {
const auto& info = framework::OpInfoMap::Instance().Get(type);
// Step 1: Run forward
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
void OpBase::SetType(const std::string& type) {
op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
VLOG(3) << "Construct Op: " << type << std::endl;
}
void OpBase::CreateOperatorBase() {
const auto& info = framework::OpInfoMap::Instance().Get(type_);
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
op_ = framework::OpRegistry::CreateOp(type_, {}, {}, {}, false);
void OpBase::ClearBackwardTrace() {
grad_pending_ops_.clear();
allow_empty_vars_.clear();
ins_.clear();
outs_.clear();
}
void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
auto* op_kernel = dynamic_cast<framework::OperatorWithKernel*>(op_.get());
template <typename VarType>
static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
auto& info = op_->Info();
auto& info = op.Info();
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(ins, &outs, attrs_);
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, &outs, attrs);
info.infer_var_type_(&infer_var_type_ctx);
}
// Initialize output var type
for (auto& var_pair : outs) {
for (auto& var : var_pair.second) {
@ -321,20 +319,29 @@ void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
}
}
VLOG(3) << "Running Op " << Type();
VLOG(5) << LayerDebugString(Type(), ins, outs);
auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place(), &attrs_);
// VLOG(3) << "Running Op " << op.Type();
VLOG(5) << LayerDebugString(op.Type(), ins, outs);
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs);
prepared_op.Run(&ins, &outs, &attrs_);
prepared_op.Run(ins, outs, attrs);
VLOG(4) << LayerDebugString(Type(), ins, outs);
VLOG(4) << LayerDebugString(op.Type(), ins, outs);
}
void OpBase::ClearBackwardTrace() {
grad_pending_ops_.clear();
ins_.clear();
outs_.clear();
void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, place);
}
void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place);
}
} // namespace imperative

File diff suppressed because it is too large Load Diff

@ -28,8 +28,9 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
}
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarBaseMap& ins,
template <typename VarType>
static void PrepareDataImpl(
const platform::Place& place, const NameVarMap<VarType>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
for (const auto& name_pair : ins) {
@ -59,22 +60,37 @@ void PreparedOp::PrepareData(
}
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
PrepareDataImpl<VarBase>(place, ins, op, expected_kernel_key);
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarMap<VariableWrapper>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
PrepareDataImpl<VariableWrapper>(place, ins, op, expected_kernel_key);
}
PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs)
: op_(op),
ctx_(ctx),
func_(std::move(func)),
func_(func),
dev_ctx_(dev_ctx),
kernel_configs_(kernel_configs) {}
PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::OperatorWithKernel& op,
platform::Place place,
const framework::AttributeMap* attrs) {
template <typename VarType>
PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op,
platform::Place place,
const framework::AttributeMap& attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
@ -90,8 +106,9 @@ PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
auto& kernels = kernels_iter->second;
framework::RuntimeContext ctx({}, {});
auto expected_kernel_key = op.GetExpectedKernelType(DygraphExecutionContext(
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
@ -108,24 +125,57 @@ PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
place = dev_ctx->GetPlace();
}
PrepareData(place, ins, op, expected_kernel_key);
PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
}
void PreparedOp::Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
const framework::AttributeMap* attrs) {
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VarBase>(ins, outs, op, place, attrs);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VariableWrapper>(ins, outs, op, place, attrs);
}
template <typename VarType>
static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs) {
// TODO(zjl): remove scope in dygraph
framework::Scope scope;
DygraphInferShapeContext infer_shape_ctx(in, out, attrs);
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs);
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
framework::OperatorWithKernel* op_ker =
(framework::OperatorWithKernel*)(&op_);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx,
kernel_configs, ins, outs, attrs));
}
op_ker->InferShape(&infer_shape_ctx);
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins,
outs, attrs);
}
func_(DygraphExecutionContext(op_, scope, *dev_ctx_, ctx_, kernel_configs_,
*in, *out, attrs));
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_,
kernel_configs_, ins, outs, attrs);
}
} // namespace imperative

@ -30,28 +30,42 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
class PreparedOp {
public:
static PreparedOp Prepare(const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs);
static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
platform::Place place,
const framework::AttributeMap* attrs);
const platform::Place& place,
const framework::AttributeMap& attrs);
inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; }
void Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
const framework::AttributeMap* attrs);
void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap& attrs);
void Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs);
static void PrepareData(const platform::Place& place,
const NameVarBaseMap& ins,
const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key);
private:
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static void PrepareData(const platform::Place& place,
const NameVarMap<VariableWrapper>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key);
private:
const framework::OperatorBase& op_;

@ -44,7 +44,8 @@ TEST(test_layer, test_runtime_context) {
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attrs;
auto* ctx = new imperative::RuntimeInferVarTypeContext(ins, &outs, attrs);
auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
ins, &outs, attrs);
ASSERT_TRUE(ctx->HasVar("vin"));
ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out"));
@ -57,9 +58,9 @@ TEST(test_layer, test_runtime_context) {
ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
}
std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs);
std::string LayerDebugString(const std::string &op_type,
const NameVarBaseMap &ins,
const NameVarBaseMap &outs);
TEST(test_layer, test_debug_string) {
platform::CPUPlace place;
@ -67,7 +68,7 @@ TEST(test_layer, test_debug_string) {
new imperative::VarBase(false, "vin"));
var_pair in_pair = var_pair("X", vb_vector(1, vin));
auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
@ -119,6 +120,34 @@ TEST(test_layer, test_debug_string) {
ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos);
}
static std::shared_ptr<imperative::OpBase> CreateOpBase(
size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
const imperative::NameVarBaseMap &outs,
const framework::AttributeMap &attrs, const platform::Place &place) {
auto op = std::make_shared<imperative::OpBase>();
op->SetId(id);
op->SetPlace(place);
op->SetType(type);
op->SetAttrMap(attrs);
for (auto &pair : ins) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetInput(pair.first, vars);
}
for (auto &pair : outs) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetOutput(pair.first, vars);
}
return op;
}
TEST(test_layer, test_clear_backward_info) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
@ -133,13 +162,11 @@ TEST(test_layer, test_clear_backward_info) {
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1;
std::shared_ptr<imperative::OpBase> op(
OpBase::Create(0, "mul", ins, outs, concat_att_map, place));
std::shared_ptr<imperative::OpBase> preceding_op(
OpBase::Create(0, "mul", ins, outs, concat_att_map, place));
op->InsertGradPendingOps(preceding_op.get());
*(op->GetMutableInsMap()) = ins;
*(op->GetMutableOutsMap()) = outs;
auto op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
auto preceding_op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
op->SetGradPendingOps({preceding_op});
ASSERT_GT(op->GetInsMap().size(), 0UL);
ASSERT_GT(op->GetOutsMap().size(), 0UL);
ASSERT_GT(op->GradPendingOps().size(), 0UL);
@ -163,10 +190,10 @@ TEST(test_layer, test_varbase_basic) {
std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin"));
ASSERT_ANY_THROW(vin->MutableGradVar());
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>(
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0));
ASSERT_TRUE(
dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0);
ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0);
vin_with_grad->SetOverridedStopGradient(false);
ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
@ -195,14 +222,14 @@ TEST(test_layer, test_dygraph_execution_context) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
paddle::platform::CPUPlace cpu_place;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(cpu_place);
auto *dev_ctx = pool.Get(cpu_place);
paddle::framework::RuntimeContext ctx({}, {});
framework::Scope scope;
DygraphExecutionContext dy_exe_context(*(op.get()), scope, *dev_ctx, ctx,
nullptr, ins, outs, &concat_att_map);
DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map);
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
@ -229,7 +256,8 @@ TEST(test_layer, test_dygraph_infershape_context) {
framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1;
DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &concat_att_map);
DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx(
&ins, &outs, &concat_att_map);
bool have_x = infer_shape_ctx.HasOutputs("Out");
ASSERT_EQ(have_x, true);

@ -114,7 +114,7 @@ TEST(test_prepare_op, test_prepare_op) {
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op),
place, &split_attr_map));
place, split_attr_map));
}
const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
@ -165,7 +165,7 @@ TEST(test_prepare_op, test_prepare_data) {
// test if it can be transformed to GPU place
PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
&attr_map);
attr_map);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
@ -213,7 +213,7 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
// test if it never transferred on GPU place
PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
&attr_map);
attr_map);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(

@ -18,6 +18,7 @@
#include <paddle/fluid/framework/op_registry.h>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@ -147,9 +148,9 @@ TEST(test_tracer, test_track_backward_output) {
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
}
TEST(test_tracer, test_track_backward_input) {
@ -186,9 +187,10 @@ TEST(test_tracer, test_track_backward_input) {
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
}
#if defined(PADDLE_WITH_CUDA)
TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
@ -344,10 +346,12 @@ TEST(test_tracer, test_var_without_grad_var) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
}
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
detail::BackwardStrategy back_st;
imperative::Engine* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
engine->Init(vout.get(), back_st);
engine->Execute();
@ -369,10 +373,137 @@ TEST(test_tracer, test_var_without_grad_var) {
}
}
template <typename T>
using WeakPtrSet =
std::set<std::weak_ptr<T>, std::owner_less<std::weak_ptr<T>>>;
static void TestVarOpDestructionMain(const platform::Place& place,
int64_t tensor_size = 10,
size_t loop_num = 10) {
WeakPtrSet<VariableWrapper> var_wrappers;
WeakPtrSet<VarBase> var_bases;
WeakPtrSet<OpBase> op_bases;
Tracer tracer;
{
auto x = std::make_shared<VarBase>("x");
auto y = std::make_shared<VarBase>("y");
x->MutableVar()
->GetMutable<framework::LoDTensor>()
->Resize({tensor_size, tensor_size})
.mutable_data<float>(place);
y->MutableVar()
->GetMutable<framework::LoDTensor>()
->Resize({tensor_size, tensor_size})
.mutable_data<float>(place);
x->SetOverridedStopGradient(false);
y->SetOverridedStopGradient(true);
for (size_t i = 0; i < loop_num; ++i) {
size_t var_wrapper_num = var_wrappers.size();
size_t var_base_num = var_bases.size();
size_t op_base_num = op_bases.size();
auto z = std::make_shared<VarBase>("z_" + std::to_string(i));
tracer.TraceOp("mul", NameVarBaseMap{{"X", {x}}, {"Y", {y}}},
NameVarBaseMap{{"Out", {z}}}, framework::AttributeMap{},
place, true);
ASSERT_EQ(z->GradOps().size(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOps().size(), 1UL);
auto new_op = z->GradVarBase()->GradOps()[0];
ASSERT_EQ(x->GradOps().size(), 0UL);
ASSERT_EQ(y->GradOps().size(), 0UL);
std::unordered_set<std::shared_ptr<OpBase>> expected_pending_ops;
if (i == 0) {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
} else {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
for (auto& op : x->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
}
for (auto& op : y->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
}
std::unordered_set<std::shared_ptr<OpBase>> actual_pending_ops;
for (auto& op : new_op->GradPendingOps()) {
actual_pending_ops.emplace(op);
}
ASSERT_TRUE(expected_pending_ops == actual_pending_ops);
ASSERT_EQ(expected_pending_ops.count(new_op), 0UL);
}
var_wrappers.emplace(x->SharedVar());
var_wrappers.emplace(x->GradVarBase()->SharedVar());
var_wrappers.emplace(y->SharedVar());
var_wrappers.emplace(y->GradVarBase()->SharedVar());
var_wrappers.emplace(z->SharedVar());
var_wrappers.emplace(z->GradVarBase()->SharedVar());
var_bases.emplace(x);
var_bases.emplace(x->GradVarBase());
var_bases.emplace(y);
var_bases.emplace(y->GradVarBase());
var_bases.emplace(z);
var_bases.emplace(z->GradVarBase());
for (auto& op : expected_pending_ops) {
op_bases.emplace(op);
}
if (i == 0) {
ASSERT_EQ(var_wrapper_num, 0UL);
ASSERT_EQ(var_base_num, 0UL);
ASSERT_EQ(op_base_num, 0UL);
ASSERT_EQ(var_wrappers.size(), 6UL);
ASSERT_EQ(var_bases.size(), 6UL);
ASSERT_EQ(op_bases.size(), 0UL);
} else {
ASSERT_EQ(var_wrappers.size(), var_wrapper_num + 2);
ASSERT_EQ(var_bases.size(), var_base_num + 2);
ASSERT_EQ(op_bases.size(), op_base_num + 1);
}
x = z; // recurrent usage
}
}
for (auto& var : var_wrappers) {
ASSERT_TRUE(var.expired());
}
for (auto& var : var_bases) {
ASSERT_TRUE(var.expired());
}
for (auto& op : op_bases) {
ASSERT_TRUE(op.expired());
}
}
TEST(test_tracer, test_var_op_destruction) {
TestVarOpDestructionMain(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
TestVarOpDestructionMain(platform::CUDAPlace(0));
#endif
}
} // namespace imperative
} // namespace paddle
USE_OP(mul);
USE_OP(mul_grad);
USE_OP(reduce_sum);
USE_OP(reduce_sum_grad);
USE_OP(elementwise_add);

@ -15,7 +15,10 @@
#include <set>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace imperative {
@ -48,22 +51,24 @@ static void ClearNoNeedBufferInputs(OpBase* op) {
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
"Only support LoDTensor");
// TODO(zjl): support higher order derivatives
auto new_var = new VarBase(false, each_var->Name());
auto new_var = new VariableWrapper(each_var->Name());
auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod());
each_var.reset(new_var);
op->AddAllowedEmptyVar(new_var);
}
}
}
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
const OpBase* fw_op_base, const NameVarBaseMap& in,
const NameVarBaseMap& out) {
if (fw_op_base->Info().dygraph_grad_op_maker_) {
return fw_op_base->Info().dygraph_grad_op_maker_(fw_op_base, in, out);
static std::vector<std::shared_ptr<OpBase>> CreateGradOpBases(
const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& in, const NameVarBaseMap& out,
const framework::AttributeMap& attrs) {
if (info.dygraph_grad_op_maker_) {
return info.dygraph_grad_op_maker_(type, in, out, attrs);
} else {
return {};
}
@ -83,17 +88,22 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward) {
VLOG(1) << "Trace Op: " << type;
size_t op_id = GenerateUniqueId();
auto op = OpBase::Create(op_id, type, ins, outs, attrs, place);
op->Run(ins, outs);
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker();
if (attr_checker) {
attr_checker->Check(&attrs, true);
}
OpBase::Run(*op, ins, outs, attrs, place);
if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
program_desc_tracer_->InsertOp(type, ins, outs, attrs);
}
if (ComputeRequiredGrad(ins, outs, trace_backward)) {
TraceBackward(op, ins, outs);
TraceBackward(op_info, type, ins, outs, attrs, place);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
@ -102,22 +112,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
framework::AttributeMap attrs) {
VLOG(1) << "Trace Op: " << type;
size_t op_id = GenerateUniqueId();
auto op =
OpBase::Create(op_id, type, ins, outs, std::move(attrs), expected_place_);
op->Run(ins, outs);
if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
}
if (ComputeRequiredGrad(ins, outs, no_grad_)) {
TraceBackward(op, ins, outs);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_);
}
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
@ -138,78 +133,19 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
return false;
}
void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs) {
// grad_to_var is a map of framework::GradVarName(in_var_name/out_var_name) ->
// in_var_name/out_var_name
std::unordered_map<std::string, std::string> grad_to_var;
// Get grad_op_desc using fwd_op_desc
std::vector<std::unique_ptr<OpBase>> grad_op_bases_ =
CreateGradOpBases(fwd_op.get(), ins, outs);
size_t grad_op_num = grad_op_bases_.size();
std::set<VarBase*> set_input_vars;
for (auto& fwd_in_it : ins) {
for (auto& var_base_it : fwd_in_it.second) {
set_input_vars.insert(var_base_it.get());
}
}
for (auto& fwd_out_it : outs) {
for (auto& var_base_it : fwd_out_it.second) {
set_input_vars.insert(var_base_it.get());
}
}
for (size_t i = 0; i < grad_op_num; ++i) {
size_t trace_id = fwd_op->id();
std::shared_ptr<OpBase> grad_op = std::move(grad_op_bases_[i]);
void Tracer::TraceBackward(const framework::OpInfo& info,
const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
auto grad_op_bases = CreateGradOpBases(info, type, ins, outs, attrs);
auto grad_op_num = grad_op_bases.size();
if (grad_op_num == 0) return;
size_t trace_id = GenerateUniqueId();
for (auto& grad_op : grad_op_bases) {
grad_op->SetPlace(place);
grad_op->SetId(trace_id);
grad_op->SetPlace(fwd_op->place());
grad_op->CreateOperatorBase();
auto& grad_in = *(grad_op->GetMutableInsMap());
auto& grad_out = *(grad_op->GetMutableOutsMap());
for (auto& grad_in_it : grad_in) {
for (auto& var_base_it : grad_in_it.second) {
if (set_input_vars.count(var_base_it.get()) == 0) {
var_base_it->AddGradOps(grad_op);
engine_->InsertGradVar(var_base_it.get());
}
}
}
std::set<OpBase*> visited_preceding_ops;
for (auto& grad_out_it : grad_out) {
bool flag_clear_list = false;
for (auto& var_base_it : grad_out_it.second) {
if ((!var_base_it->OverridedStopGradient()) ||
(grad_out_it.second.size() > 1)) {
auto preceding_ops = var_base_it->GradOps();
if (!preceding_ops.empty()) {
for (const auto& op : preceding_ops) {
visited_preceding_ops.insert(op);
}
}
} else {
flag_clear_list = true;
}
}
if (flag_clear_list) {
grad_out_it.second.clear();
}
}
std::vector<OpBase*> vec_preceding_ops(visited_preceding_ops.begin(),
visited_preceding_ops.end());
grad_op->SetGradPendingOps(std::move(vec_preceding_ops));
// this OpBase* is just used to manage op's life time
engine_->InsertOp(grad_op.get(), grad_op);
ClearNoNeedBufferInputs(grad_op.get());
}
}

@ -64,9 +64,6 @@ class Tracer {
bool ComputeRequiredGrad(const NameVarBaseMap& ins,
const NameVarBaseMap& outs, bool trace_backward);
void TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
const NameVarBaseMap& ins, const NameVarBaseMap& outs);
Engine* GetDefaultEngine() const { return engine_.get(); }
void SetEnableProgramDescTracing(bool enabled) {
@ -94,6 +91,11 @@ class Tracer {
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
private:
void TraceBackward(const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& ins, const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static size_t GenerateUniqueId() {
static std::atomic<size_t> id{0};
return id.fetch_add(1);

@ -22,12 +22,16 @@ limitations under the License. */
namespace paddle {
namespace imperative {
class VariableWrapper;
class VarBase;
class OpBase;
class Tracer;
using NameVarBaseMap =
std::map<std::string, std::vector<std::shared_ptr<VarBase>>>;
template <typename T>
using NameVarMap = std::map<std::string, std::vector<std::shared_ptr<T>>>;
using NameVarBaseMap = NameVarMap<VarBase>;
using NameVariableWrapperMap = NameVarMap<VariableWrapper>;
using WeakNameVarBaseMap =
std::map<std::string, std::vector<std::weak_ptr<VarBase>>>;

@ -0,0 +1,102 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace imperative {
class VariableWrapper {
public:
explicit VariableWrapper(const std::string& name) : name_(name) {}
const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; }
// This is used for python api
void SetOverridedStopGradient(bool stop_gradient) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
}
// This is used for python api
bool OverridedStopGradient() const { return overrided_stop_gradient_ != 0; }
// This is used inside C++
int InnerOverridedStopGradient() const { return overrided_stop_gradient_; }
// This is used inside C++
void InnerSetOverridedStopGradient(bool stop_gradient) {
if (overrided_stop_gradient_ == -1) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
} else {
VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name()
<< "Set value is: " << overrided_stop_gradient_;
}
}
void SetPersistable(bool persistable) { persistable_ = persistable; }
bool Persistable() const { return persistable_; }
const std::string& Name() const { return name_; }
void SetName(const std::string& name) { name_ = name; }
void SetType(framework::proto::VarType::Type type) { type_ = type; }
framework::proto::VarType::Type Type() const { return type_; }
void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type;
}
framework::proto::VarType::Type DataType() const {
const framework::Tensor* tensor = nullptr;
if (var_.IsInitialized()) {
if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value());
} else {
VLOG(6) << "Variable " << name_ << " is not initialized";
return data_type_;
}
}
if (tensor && tensor->IsInitialized()) {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
return data_type_;
}
}
private:
framework::Variable var_;
std::string name_;
// add this property for users may set stop_gradient themselves and this
// should override the frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
bool persistable_{false};
framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
};
} // namespace imperative
} // namespace paddle

@ -68,8 +68,7 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
@ -86,8 +85,6 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", this->Output("Out"));
}
return op;
}
};
@ -727,8 +724,7 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("relu_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
@ -737,7 +733,6 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
// output: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
@ -750,8 +745,7 @@ class LeakyReluDoubleGradMaker
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("leaky_relu_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
@ -760,7 +754,6 @@ class LeakyReluDoubleGradMaker
op->SetAttrMap(this->Attrs());
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
@ -772,8 +765,7 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("elu_grad_grad");
op->SetInput("X", this->Input("X"));
@ -785,7 +777,6 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
// Out@GRAD@GRAD: ddy
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
@ -797,8 +788,7 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("sqrt_grad_grad");
op->SetInput("Out", this->Input("Out"));
op->SetInput("DX", this->Output(framework::GradVarName("X")));
@ -806,7 +796,6 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput("DOut", this->InputGrad("Out"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
@ -818,8 +807,7 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("square_grad_grad");
op->SetInput("X", this->Input("X"));
// Out@GRAD: dy
@ -833,7 +821,6 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
op->SetOutput("DX", this->InputGrad("X"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
@ -849,16 +836,13 @@ class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("pow_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetInput("FactorTensor", this->Input("FactorTensor"));
op->SetAttrMap(this->Attrs());
return op;
}
};
class PowOp : public framework::OperatorWithKernel {

@ -93,13 +93,11 @@ class AddPositionEncodingGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("add_position_encoding_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save