Add dygraph double grad implementation (#22939)

* add double grad implementation for dygraph, test=develop
* polish code, add uts, test=develop
* fix place bug, test=develop
* polish codes, add more uts for coverages, test=develop
* add no_grad_set, test=develop
* add star gan ut, test=develop
* follow comments, test=develop

parent 995a6376f7
commit a31d7328b7

File diff suppressed because it is too large
@@ -0,0 +1,57 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

namespace paddle {
namespace imperative {

class VarBase;
class OpBase;

class BasicEngine : public Engine {
 public:
  void Init(VarBase* var, const detail::BackwardStrategy& strategy);

  void Execute() override;

 private:
  void PrepareDeps();

  void CheckBackwardInputs(const OpBase& op);

  void PrepareGradAccumulators(const OpBase& op);

  void Clear();

 private:
  std::shared_ptr<GradOpNode> init_node_;
  detail::BackwardStrategy backward_strategy_;
  std::unordered_map<GradOpNode*, size_t> node_deps_;
  std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
      accumulators_;
  std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
      need_accu_var_list_;
};

}  // namespace imperative
}  // namespace paddle
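
A minimal usage sketch for BasicEngine, assuming a Paddle build in which the declarations above are visible and that `loss` is a VarBase* produced by a traced forward pass; the real call sites live in the accompanying .cc files and the Python bindings, which are not shown in full here.

// Illustrative sketch only: `loss` and the default-constructed strategy are assumptions.
namespace paddle {
namespace imperative {

void RunBackwardSketch(VarBase* loss) {
  detail::BackwardStrategy strategy;  // default backward strategy (assumption)
  BasicEngine engine;
  engine.Init(loss, strategy);  // seed init_node_ from the variable backward starts at
  engine.Execute();             // walk the GradOpNode graph and accumulate gradients
}

}  // namespace imperative
}  // namespace paddle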
File diff suppressed because it is too large
@@ -0,0 +1,199 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/type_defs.h"

namespace paddle {
namespace imperative {

template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext {
  using Variable = framework::Variable;

 public:
  DygraphExecutionContext(const framework::OperatorBase& op,
                          const framework::Scope& scope,
                          const platform::DeviceContext& device_context,
                          const framework::RuntimeContext& ctx,
                          std::vector<framework::KernelConfig>* configs,
                          const NameVarMap<VarType>& var_base_map_in,
                          const NameVarMap<VarType>& var_base_map_out,
                          const framework::AttributeMap& attrs)
      : ExecutionContext(op, scope, device_context, ctx, configs),
        var_base_map_in_(var_base_map_in),
        var_base_map_out_(var_base_map_out),
        attrs_(attrs) {}

  std::string InputName(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    PADDLE_ENFORCE_NE(it, var_base_map_in_.end(),
                      platform::errors::PreconditionNotMet(
                          "Can not find [%s] in Input", name));
    return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName;
  }

  std::vector<std::string> InputNames(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_in_.end(),
        platform::errors::NotFound("Can not find [%s] in Input", name));
    std::vector<std::string> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        vec_res.push_back(it->second[i]->Name());
      } else {
        vec_res.push_back(framework::kEmptyVarName);
      }
    }
    return vec_res;
  }

  std::string OutputName(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName;
  }

  std::vector<std::string> OutputNames(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    std::vector<std::string> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        vec_res.push_back(it->second[i]->Name());
      } else {
        vec_res.push_back(framework::kEmptyVarName);
      }
    }
    return vec_res;
  }

  bool HasAttr(const std::string& name) const override {
    return attrs_.count(name) != 0;
  }

  const framework::AttributeMap& Attrs() const override { return attrs_; }

  const framework::Attribute& GetAttr(const std::string& name) const override {
    auto it = attrs_.find(name);

    PADDLE_ENFORCE_NE(
        it, attrs_.end(),
        platform::errors::NotFound("can not find [%s] in attrs", name));

    return it->second;
  }

  std::vector<std::string> InNameList() const override {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_in_.size());

    for (auto& v : var_base_map_in_) {
      vec_temp.push_back(v.first);
    }

    return vec_temp;
  }

  bool HasInput(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    return (it != var_base_map_in_.end() && it->second.size() > 0);
  }

  bool HasOutput(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    return (it != var_base_map_out_.end() && it->second.size() > 0);
  }

  size_t InputSize(const std::string& name) const override {
    return InputNames(name).size();
  }

  size_t OutputSize(const std::string& name) const override {
    return OutputNames(name).size();
  }

  const Variable* InputVar(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    if (it == var_base_map_in_.end()) {
      return nullptr;
    }

    return it->second.empty() || it->second[0] == nullptr
               ? nullptr
               : it->second[0]->MutableVar();
  }

  Variable* OutputVar(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    if (it == var_base_map_out_.end()) {
      return nullptr;
    }

    return it->second.empty() || it->second[0] == nullptr
               ? nullptr
               : it->second[0]->MutableVar();
  }

  const std::vector<Variable*> MultiInputVar(
      const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    if (it == var_base_map_in_.end()) {
      return {};
    }
    std::vector<Variable*> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
    }

    return vec_res;
  }

  std::vector<Variable*> MultiOutputVar(
      const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    if (it == var_base_map_out_.end()) {
      return {};
    }
    std::vector<Variable*> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
    }

    return vec_res;
  }

 private:
  const NameVarMap<VarType>& var_base_map_in_;
  const NameVarMap<VarType>& var_base_map_out_;
  const framework::AttributeMap& attrs_;
};

}  // namespace imperative
}  // namespace paddle
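
A sketch of how a caller would presumably wrap one traced op in DygraphExecutionContext and query it by slot name; every constructor argument is assumed to be prepared by the dygraph tracer, and "X" is a hypothetical input slot.

// Illustrative sketch only: all arguments are assumed to come from the tracer.
namespace paddle {
namespace imperative {

template <typename VarType>
void QueryContextSketch(const framework::OperatorBase& op,
                        const framework::Scope& scope,
                        const platform::DeviceContext& dev_ctx,
                        const framework::RuntimeContext& runtime_ctx,
                        std::vector<framework::KernelConfig>* configs,
                        const NameVarMap<VarType>& ins,
                        const NameVarMap<VarType>& outs,
                        const framework::AttributeMap& attrs) {
  DygraphExecutionContext<VarType> ctx(op, scope, dev_ctx, runtime_ctx,
                                       configs, ins, outs, attrs);
  if (ctx.HasInput("X")) {                 // "X" is a hypothetical slot name
    auto names = ctx.InputNames("X");      // names of variables bound to slot X
    const auto* var = ctx.InputVar("X");   // first underlying framework::Variable
    (void)names;
    (void)var;
  }
}

}  // namespace imperative
}  // namespace paddle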
File diff suppressed because it is too large

File diff suppressed because it is too large
@@ -0,0 +1,188 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"

namespace paddle {
namespace imperative {

// infer var type context for imperative mode
template <typename VarType>
class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
 public:
  RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
                             const NameVarMap<VarType>& outputs,
                             const framework::AttributeMap& attrs_map)
      : InferVarTypeContext(nullptr, nullptr),
        inputs_(inputs),
        outputs_(outputs),
        attrs_(attrs_map),
        input_names_(),
        output_names_(),
        var_set_() {
    input_names_.reserve(inputs_.size());
    for (auto& it : inputs_) {
      for (auto& var : it.second) {
        if (var) {
          input_names_[it.first].emplace_back(var->Name());
          var_set_[var->Name()] = var.get();
        }
      }
    }

    output_names_.reserve(outputs_.size());
    for (auto& it : outputs_) {
      for (auto& var : it.second) {
        if (var) {
          output_names_[it.first].emplace_back(var->Name());
          var_set_[var->Name()] = var.get();
        }
      }
    }
  }

  virtual ~RuntimeInferVarTypeContext() {}

  framework::Attribute GetAttr(const std::string& name) const override {
    auto iter = attrs_.find(name);
    PADDLE_ENFORCE_EQ(
        iter != attrs_.end(), true,
        platform::errors::NotFound("Cannot find attribute %s", name));
    return iter->second;
  }

  bool HasVar(const std::string& name) const override {
    return var_set_.count(name) > 0;
  }

  bool HasInput(const std::string& name) const override {
    auto it = inputs_.find(name);
    return (it != inputs_.end() && it->second.size() > 0);
  }

  bool HasOutput(const std::string& name) const override {
    auto it = outputs_.find(name);
    return (it != outputs_.end() && it->second.size() > 0);
  }

  const std::vector<std::string>& Input(
      const std::string& name) const override {
    auto iter = input_names_.find(name);
    PADDLE_ENFORCE_EQ(
        iter != input_names_.end(), true,
        platform::errors::NotFound("Cannot find input var %s", name));
    return iter->second;
  }

  const std::vector<std::string>& Output(
      const std::string& name) const override {
    auto iter = output_names_.find(name);

    PADDLE_ENFORCE_EQ(
        iter != output_names_.end(), true,
        platform::errors::NotFound("Cannot find output var %s", name));
    return iter->second;
  }

  framework::proto::VarType::Type GetType(
      const std::string& name) const override {
    auto iter = var_set_.find(name);

    PADDLE_ENFORCE_EQ(
        iter != var_set_.end(), true,
        platform::errors::NotFound("Cannot find var %s in GetType", name));
    return iter->second->Type();
  }

  void SetType(const std::string& name,
               framework::proto::VarType::Type type) override {
    if (name == "kLookupTablePath") {
      VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
    } else {
      var_set_[name]->SetType(type);
      if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
          (var_set_[name]->MutableVar()->Type() != type)) {
        var_set_[name]->MutableVar()->Clear();
      }
    }
  }

  framework::proto::VarType::Type GetDataType(
      const std::string& name) const override {
    auto iter = var_set_.find(name);

    PADDLE_ENFORCE_EQ(
        iter != var_set_.end(), true,
        platform::errors::NotFound("Cannot find var %s in GetDataType", name));
    return iter->second->DataType();
  }

  void SetDataType(const std::string& name,
                   framework::proto::VarType::Type type) override {
    var_set_[name]->SetDataType(type);
  }

  std::vector<framework::proto::VarType::Type> GetDataTypes(
      const std::string& name) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetDataTypes is not supported in runtime InferVarType"));
  }

  void SetDataTypes(const std::string& name,
                    const std::vector<framework::proto::VarType::Type>&
                        multiple_data_type) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "SetDataTypes is not supported in runtime InferVarType"));
  }

  std::vector<int64_t> GetShape(const std::string& name) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Do not handle Shape in runtime InferVarType"));
  }

  void SetShape(const std::string& name,
                const std::vector<int64_t>& dims) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Do not handle Shape in runtime InferVarType"));
  }

  int32_t GetLoDLevel(const std::string& name) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Do not handle LoDLevel in runtime InferVarType"));
  }

  void SetLoDLevel(const std::string& name, int32_t lod_level) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Do not handle LoDLevel in runtime InferVarType"));
  }

 private:
  const NameVarMap<VarType>& inputs_;
  const NameVarMap<VarType>& outputs_;
  const framework::AttributeMap& attrs_;
  std::unordered_map<std::string, std::vector<std::string>> input_names_;
  std::unordered_map<std::string, std::vector<std::string>> output_names_;
  std::unordered_map<std::string, VarType*> var_set_;
};

}  // namespace imperative
}  // namespace paddle
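
A sketch of how an op's InferVarType routine would presumably use this context during tracing, mirroring the first input's variable type and data type onto the first output; "X" and "Out" are hypothetical slot names.

// Illustrative sketch only: `ins`/`outs` are assumed to be the traced op's variable maps.
namespace paddle {
namespace imperative {

template <typename VarType>
void PropagateVarTypeSketch(const NameVarMap<VarType>& ins,
                            const NameVarMap<VarType>& outs,
                            const framework::AttributeMap& attrs) {
  RuntimeInferVarTypeContext<VarType> ctx(ins, outs, attrs);
  if (ctx.HasInput("X") && ctx.HasOutput("Out")) {  // hypothetical slots
    const auto& in_name = ctx.Input("X")[0];
    const auto& out_name = ctx.Output("Out")[0];
    ctx.SetType(out_name, ctx.GetType(in_name));           // mirror var type
    ctx.SetDataType(out_name, ctx.GetDataType(in_name));   // mirror dtype
  }
}

}  // namespace imperative
}  // namespace paddle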
File diff suppressed because it is too large
@@ -0,0 +1,211 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace imperative {

// TODO(zjl): to support py_func layer
class OpBase {
 public:
  OpBase() = default;

  OpBase(const OpBase&) = delete;

  OpBase(OpBase&&) = default;

  OpBase& operator=(const OpBase&) = delete;

  OpBase& operator=(OpBase&&) = default;

  ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }

  const std::string& Type() const { return op_->Type(); }

  const framework::AttributeMap& Attrs() const { return attrs_; }

  const framework::OpInfo& Info() const { return op_->Info(); }

  const framework::OperatorBase& InnerOp() const { return *op_; }

  void ClearBackwardTrace();

  NameVarMap<VariableWrapper>* GetMutableOutsMap() { return &outs_; }

  NameVarMap<VariableWrapper>* GetMutableInsMap() { return &ins_; }

  const NameVarMap<VariableWrapper>& GetInsMap() const { return ins_; }

  const NameVarMap<VariableWrapper>& GetOutsMap() const { return outs_; }

  void SetType(const std::string& type);

  void CheckAttrs() {
    auto& info = op_->Info();
    if (info.Checker() != nullptr) {
      info.Checker()->Check(&attrs_, true);
    }
  }

  void SetInput(const std::string& name, VariableWrapperList vars,
                bool is_grad) {
    auto& in_vars = ins_[name];
    *(in_vars.MutableVarList()) = std::move(vars);
    in_vars.SetIsGrad(is_grad);
  }

  void SetOutput(const std::string& name, VariableWrapperList vars,
                 bool is_grad) {
    auto& out_vars = outs_[name];
    *(out_vars.MutableVarList()) = std::move(vars);
    out_vars.SetIsGrad(is_grad);
  }

  void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }

  void SetAttr(const std::string& name, const framework::Attribute& v) {
    attrs_[name] = v;
  }

  void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "SetBlockAttr is not support in dygraph OpBase"));
  }

  const framework::AttributeMap& Attrs() { return attrs_; }

  bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }

  const framework::Attribute& GetAttr(const std::string& name) const {
    auto it = attrs_.find(name);
    PADDLE_ENFORCE_NE(
        it, attrs_.end(),
        platform::errors::NotFound("can not find attribute [%s]", name));
    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return boost::get<T>(GetAttr(name));
  }

  size_t id() const { return id_; }

  void SetId(size_t id) { id_ = id; }

  const platform::Place& place() const { return place_; }

  void SetPlace(const platform::Place& place) { place_ = place; }

  static size_t GenerateUniqueId() {
    static std::atomic<size_t> unique_id{0};
    return unique_id.fetch_add(1);
  }

  static void Run(const framework::OperatorBase& op,
                  const NameVarMap<VarBase>& ins,
                  const NameVarMap<VarBase>& outs,
                  const framework::AttributeMap& attrs,
                  const platform::Place& place);

  static void Run(const framework::OperatorBase& op,
                  const NameVarMap<VariableWrapper>& ins,
                  const NameVarMap<VariableWrapper>& outs,
                  const framework::AttributeMap& attrs,
                  const platform::Place& place);

 private:
  NameVarMap<VariableWrapper> ins_;
  NameVarMap<VariableWrapper> outs_;
  framework::AttributeMap attrs_;
  std::unique_ptr<framework::OperatorBase> op_;
  platform::Place place_;
  size_t id_{-1UL};

  std::vector<std::function<void()>> backward_hooks_;
};

class GradOpNode {
 public:
  GradOpNode() = default;

  void reserve(size_t size) { ops_.reserve(size); }

  size_t size() const { return ops_.size(); }

  bool empty() const { return ops_.empty(); }

  void clear() { ops_.clear(); }

  void pop_back() { ops_.pop_back(); }

  template <typename... ARGS>
  OpBase& emplace_back(ARGS&&... args) {  // NOLINT
    ops_.emplace_back(std::forward<ARGS>(args)...);
    return ops_.back();
  }

  const OpBase& back() const { return ops_.back(); }

  OpBase& back() { return ops_.back(); }

  OpBase& operator[](size_t idx) { return ops_[idx]; }

  const OpBase& operator[](size_t idx) const { return ops_[idx]; }

  /* Iterator related */
  using Iterator = std::vector<OpBase>::iterator;
  using ConstIterator = std::vector<OpBase>::const_iterator;

  Iterator begin() { return ops_.begin(); }

  Iterator end() { return ops_.end(); }

  ConstIterator begin() const { return ops_.begin(); }

  ConstIterator end() const { return ops_.end(); }

  void InsertGradPendingNode(const std::shared_ptr<GradOpNode>& node) {
    if (node &&
        std::find(grad_pending_nodes_.begin(), grad_pending_nodes_.end(),
                  node) == grad_pending_nodes_.end()) {
      grad_pending_nodes_.emplace_back(node);
    }
  }

  const std::vector<std::shared_ptr<GradOpNode>>& GradPendingNodes() const {
    return grad_pending_nodes_;
  }

 private:
  DISABLE_COPY_AND_ASSIGN(GradOpNode);

 private:
  std::vector<OpBase> ops_;
  std::vector<std::shared_ptr<GradOpNode>> grad_pending_nodes_;
};

}  // namespace imperative
}  // namespace paddle
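
A sketch of how a grad op might be recorded on a GradOpNode while the backward trace is built: append an OpBase slot, fill in its type, attributes, id and place, then link the node that has to run afterwards. The op type "matmul_grad", the CPU place and the incoming shared pointers are illustrative assumptions.

// Illustrative sketch only: `node` and `pending` are assumed to be created by the tracer.
namespace paddle {
namespace imperative {

void RecordGradOpSketch(const std::shared_ptr<GradOpNode>& node,
                        const std::shared_ptr<GradOpNode>& pending,
                        const framework::AttributeMap& attrs) {
  auto& op = node->emplace_back();    // default-constructed OpBase slot
  op.SetType("matmul_grad");          // placeholder grad op type
  op.SetAttrMap(attrs);
  op.SetPlace(platform::CPUPlace());  // place assumed for illustration
  op.SetId(OpBase::GenerateUniqueId());
  node->InsertGradPendingNode(pending);  // node that depends on this one
}

}  // namespace imperative
}  // namespace paddle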
File diff suppressed because it is too large
@@ -0,0 +1,58 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace imperative {

class VarBase;

class PartialGradEngine : public Engine {
 public:
  PartialGradEngine(const std::vector<std::shared_ptr<VarBase>> &input_targets,
                    const std::vector<std::shared_ptr<VarBase>> &output_targets,
                    const std::vector<std::shared_ptr<VarBase>> &output_grads,
                    const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
                    const platform::Place &place,
                    const detail::BackwardStrategy &strategy,
                    bool create_graph);

  void Execute() override;

  std::vector<std::shared_ptr<VarBase>> GetResult() const;

 private:
  void Clear();

 private:
  std::vector<std::shared_ptr<VarBase>> input_targets_;
  std::vector<std::shared_ptr<VarBase>> output_targets_;
  std::vector<std::shared_ptr<VarBase>> output_grads_;
  std::vector<std::shared_ptr<VarBase>> no_grad_vars_;
  platform::Place place_;
  detail::BackwardStrategy strategy_;
  bool create_graph_;

  std::vector<std::shared_ptr<VarBase>> results_;
};

}  // namespace imperative
}  // namespace paddle
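
The double-grad path of this PR presumably drives PartialGradEngine with the variables named in its constructor; below is a sketch under the assumption that the VarBase vectors arrive from the Python-facing grad call, with create_graph set to true so the newly traced backward ops stay differentiable (the "double grad" of the commit title).

// Illustrative sketch only: the VarBase vectors are assumed to come from the
// Python-facing grad API; their exact origin is outside this header.
namespace paddle {
namespace imperative {

std::vector<std::shared_ptr<VarBase>> ComputePartialGradSketch(
    const std::vector<std::shared_ptr<VarBase>>& inputs,
    const std::vector<std::shared_ptr<VarBase>>& outputs,
    const std::vector<std::shared_ptr<VarBase>>& output_grads,
    const std::vector<std::shared_ptr<VarBase>>& no_grad_vars,
    const platform::Place& place) {
  detail::BackwardStrategy strategy;  // default strategy (assumption)
  PartialGradEngine engine(inputs, outputs, output_grads, no_grad_vars, place,
                           strategy, /*create_graph=*/true);
  engine.Execute();
  return engine.GetResult();  // grads of `outputs` w.r.t. `inputs`
}

}  // namespace imperative
}  // namespace paddle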
@@ -0,0 +1,87 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <utility>
#include <vector>

namespace paddle {
namespace imperative {

class VariableWrapper;

class SavedVariableWrapperList {
 public:
  SavedVariableWrapperList() : vars_(), is_grad_(false) {}

  template <typename... Args>
  explicit SavedVariableWrapperList(bool is_grad, Args&&... args)
      : vars_(std::forward<Args>(args)...), is_grad_(is_grad) {}

  bool IsGrad() const { return is_grad_; }

  void SetIsGrad(bool is_grad) { is_grad_ = is_grad; }

  const std::vector<std::shared_ptr<VariableWrapper>>& VarList() const {
    return vars_;
  }

  std::vector<std::shared_ptr<VariableWrapper>>* MutableVarList() {
    return &vars_;
  }

  /* Borrow method from std::vector */
  size_t size() const { return vars_.size(); }

  bool empty() const { return vars_.empty(); }

  template <typename... ARGS>
  void emplace_back(ARGS&&... args) {
    vars_.emplace_back(std::forward<ARGS>(args)...);
  }

  using Iterator = std::vector<std::shared_ptr<VariableWrapper>>::iterator;

  using ConstIterator =
      std::vector<std::shared_ptr<VariableWrapper>>::const_iterator;

  Iterator begin() { return vars_.begin(); }

  Iterator end() { return vars_.end(); }

  ConstIterator begin() const { return vars_.begin(); }

  ConstIterator end() const { return vars_.end(); }

  std::shared_ptr<VariableWrapper>& operator[](size_t idx) {
    return vars_[idx];
  }

  const std::shared_ptr<VariableWrapper>& operator[](size_t idx) const {
    return vars_[idx];
  }

  operator const std::vector<std::shared_ptr<VariableWrapper>>&() const {
    return vars_;
  }

 private:
  std::vector<std::shared_ptr<VariableWrapper>> vars_;
  bool is_grad_;
};

}  // namespace imperative
}  // namespace paddle
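
A small sketch of the vector-like surface of SavedVariableWrapperList, which is what OpBase::SetInput and SetOutput above fill; `vars` is assumed to hold the wrappers saved for one input slot of a traced op.

// Illustrative sketch only: builds a saved-input slot the way OpBase::SetInput
// does, then iterates it like a plain vector.
namespace paddle {
namespace imperative {

size_t CountBoundVarsSketch(
    std::vector<std::shared_ptr<VariableWrapper>> vars) {
  SavedVariableWrapperList slot;
  *(slot.MutableVarList()) = std::move(vars);  // same pattern as OpBase::SetInput
  slot.SetIsGrad(false);                       // forward (non-gradient) variables

  size_t bound = 0;
  for (auto& var : slot) {  // iteration borrowed from std::vector
    if (var) {
      ++bound;              // count only non-null wrapper entries
    }
  }
  return bound;
}

}  // namespace imperative
}  // namespace paddle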
Some files were not shown because too many files have changed in this diff.