Merge pull request #16214 from velconia/imperative_infer_var_type

Implement imperative infer var type
Qiyang Min 7 years ago committed by GitHub
commit c7f1f3ed0c
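At a glance, this PR replaces the (const OpDesc&, BlockDesc*) pair taken by VarTypeInference with a single polymorphic InferVarTypeContext, and adds a RuntimeInferVarTypeContext so the imperative tracer can run var type inference directly on VarBase objects. The interface change, condensed from the hunks below:

// before: tied to the static graph's OpDesc/BlockDesc
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
// after: any context (compile-time or imperative runtime) can be passed in
virtual void operator()(InferVarTypeContext* ctx) const = 0;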

@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& inputs = ctx->Input("X");
auto type = ctx->GetType(inputs.front());
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, type);
}
};
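With the new context, an op-specific inference only talks to ctx. A minimal sketch against the interface shown in this diff (the op name and the "X"/"Out" slots are illustrative, not part of this PR):

class MyOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    // propagate the first input's variable type to every output
    auto& ins = ctx->Input("X");
    auto type = ctx->GetType(ins.front());
    for (auto& out_name : ctx->Output("Out")) {
      ctx->SetType(out_name, type);
    }
  }
};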

@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
@ -127,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template <typename T>
struct OpInfoFiller<T, kVarTypeInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
info->infer_var_type_ = [](InferVarTypeContext* context) {
T inference;
inference(fwd_op, block);
inference(context);
};
}
};
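Nothing changes on the registration side; OpInfoFiller is still driven by the operator registry. A hedged sketch of how a VarTypeInference class reaches this filler (the exact argument list differs per operator):

// e.g. in sum_op.cc: the VarTypeInference class listed here is the T that
// OpInfoFiller<T, kVarTypeInference> wraps into the type-erased InferVarTypeFN
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker,
                  ops::SumOpVarTypeInference);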

@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
void operator()(InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
}
};
@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
class DummyOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
} // namespace framework
} // namespace paddle

@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace framework {
@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
// var type inference. Hence, we don't do any "default" setting here.
auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) {
info.infer_var_type_(*this, block);
InferVarTypeContext context(this, block);
info.infer_var_type_(&context);
}
}

@ -27,6 +27,7 @@ namespace framework {
class OperatorBase;
class OpDesc;
class InferShapeContext;
class InferVarTypeContext;
class BlockDesc;
class Variable;
@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const std::vector<BlockDesc*>& grad_block)>;
using InferVarTypeFN =
std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
std::function<void(framework::InferVarTypeContext* /*context*/)>;
using InferShapeFN = std::function<void(InferShapeContext*)>;

@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
@ -21,26 +23,123 @@ limitations under the License. */
namespace paddle {
namespace framework {
class OpDesc;
class BlockDesc;
// default infer var type context
class InferVarTypeContext {
public:
InferVarTypeContext(const OpDesc* op, BlockDesc* block)
: op_(op), block_(block) {}
virtual ~InferVarTypeContext() {}
virtual Attribute GetAttr(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->GetAttr(name);
}
virtual bool HasVar(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindVarRecursive(name) != nullptr;
}
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Inputs().count(name) > 0;
}
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Outputs().count(name) > 0;
}
virtual const std::vector<std::string>& Input(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Input(name);
}
virtual const std::vector<std::string>& Output(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Output(name);
}
virtual proto::VarType::Type GetType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetType();
}
virtual void SetType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetType(type);
}
virtual proto::VarType::Type GetDataType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataType();
}
virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
}
virtual std::vector<proto::VarType::Type> GetDataTypes(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
}
virtual void SetDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
}
virtual std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetShape();
}
virtual void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetShape(dims);
}
virtual int32_t GetLoDLevel(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
}
virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
}
protected:
const OpDesc* op_;
BlockDesc* block_;
};
class VarTypeInference {
public:
virtual ~VarTypeInference() {}
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
virtual void operator()(InferVarTypeContext* context) const = 0; // NOLINT
};
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const final {
void operator()(framework::InferVarTypeContext* ctx) const final { // NOLINT
auto in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = op_desc.Input(i_o_n.first).at(0);
auto& out_name = op_desc.Output(i_o_n.second).at(0);
auto& x_name = ctx->Input(i_o_n.first).at(0);
auto& out_name = ctx->Output(i_o_n.second).at(0);
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
ctx->SetType(out_name, ctx->GetType(x_name));
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
}
}
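Subclasses of PassInDtypeAndVarTypeToOutput only supply the input-to-output pairing; a sketch, assuming GetInputOutputWithSameType() is the protected pure virtual hook this base class declares (not visible in this hunk):

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return {{"X", /*->*/ "Out"}};  // Out mirrors X's var type and dtype
  }
};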

@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
}
};
} // namespace framework

@ -218,7 +218,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
"%s has no backward implementation", Type());
VLOG(3) << "apply op grad: " << Type();
std::vector<framework::VariableValueMap> tmp_grad_outputs;
std::vector<VarBasePtrMap> tmp_grad_outputs;
if (backward_id_ > 0) {
VLOG(3) << "py_layer_grad";
tmp_grad_outputs.resize(1);
@ -241,26 +241,62 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
auto& outputs = tmp_grad_outputs[k][it.first];
outputs.reserve(it.second.size());
for (size_t i = 0; i < it.second.size(); ++i) {
VarBase* origin_grad_var_base = it.second[i];
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
outputs.emplace_back(tmp_var);
VarBase* tmp_grad_var_base = new VarBase(
string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
place_, true, false);
outputs.emplace_back(tmp_grad_var_base);
}
}
// Run grad op
framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
// grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc);
auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(
&grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
info.infer_var_type_(&infer_var_type_ctx);
}
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
// Run grad op
framework::VariableValueMap grad_invars_map;
framework::VariableValueMap grad_outvars_map;
for (const auto& it : grad_input_vars_[k]) {
auto& grad_invars = grad_invars_map[it.first];
grad_invars.reserve(it.second.size());
for (const VarBase* grad_inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
grad_op_desc->Type(), grad_inp->Name());
grad_invars.emplace_back(grad_inp->var_);
}
}
for (const auto& it : tmp_grad_outputs[k]) {
auto& grad_outvars = grad_outvars_map[it.first];
grad_outvars.reserve(it.second.size());
for (VarBase* grad_out : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
grad_op_desc->Type(), grad_out->Name());
grad_outvars.emplace_back(grad_out->var_);
}
}
framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
@ -277,8 +313,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i];
framework::Variable* orig_grad = origin_outputs[i];
framework::Variable* grad = outputs[i]->var_;
framework::Variable* orig_grad = origin_outputs[i]->var_;
AddTo(grad, orig_grad, place_);
delete grad;
}
@ -326,28 +362,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int PyLayer::NumFuncs() { return py_funcs_.size(); }
std::vector<Variable*> PyLayer::Apply(int func_id,
const std::vector<VarBase*>& inputs) {
std::vector<framework::Variable*> invars;
for (const VarBase* in : inputs) {
invars.push_back(in->var_);
}
std::vector<framework::Variable*> PyLayer::Apply(
int func_id, const std::vector<VarBase*>& inputs) {
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
return CallPythonFunc(py_funcs_[func_id], invars);
return CallPythonFunc(py_funcs_[func_id], inputs);
}
std::vector<Variable*> PyLayer::ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs) {
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
const std::vector<VarBase*>& inputs) {
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
return CallPythonFunc(py_funcs_[func_id], inputs);
auto rets = CallPythonFunc(py_funcs_[func_id], inputs);
std::vector<VarBase*> outs;
outs.reserve(rets.size());
for (size_t i = 0U; i != rets.size(); ++i) {
outs.emplace_back(new VarBase(
string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
i),
rets[i], nullptr, true));
}
return outs;
}
std::vector<framework::Variable*> PyLayer::CallPythonFunc(
const py::object& callable, const std::vector<framework::Variable*>& ins) {
const py::object& callable, const std::vector<VarBase*>& ins) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
const framework::LoDTensor& t = ins[i]->Get<framework::LoDTensor>();
const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
}
VLOG(3) << "pyfunc in " << py::len(in_args);
@ -357,6 +400,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
std::vector<framework::Variable*> outs;
outs.reserve(ret_num);
VLOG(3) << "pyfunc out " << ret_num;
for (size_t i = 0; i < ret_num; ++i) {
try {
@ -367,7 +411,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(*py_out_tensor);
tensor->set_lod(py_out_tensor->lod());
outs.push_back(var);
outs.emplace_back(var);
} catch (py::cast_error&) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}

@ -18,14 +18,16 @@
#include "paddle/fluid/framework/python_headers.h"
// clang-format on
#include <map> // NOLINT
#include <string> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include <map> // NOLINT
#include <string> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include <unordered_map> // NOLINT
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
@ -135,13 +137,13 @@ class VarBase {
persistable) {}
private:
// TODO(minqiyang): need support SelectedRows
VarBase(const std::string& name, framework::proto::VarType::Type dtype,
const framework::DDim& shape, const platform::Place& place,
framework::Variable* var, VarBase* grad, bool stop_gradient,
bool persistable)
: name_(name),
dtype_(dtype),
place_(place),
type_(framework::proto::VarType::LOD_TENSOR),
var_(var),
grads_(grad),
stop_gradient_(stop_gradient),
@ -151,10 +153,12 @@ class VarBase {
pre_op_out_idx_(-1) {
if (!var_) {
var_ = new framework::Variable();
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->Resize(shape);
tensor->mutable_data(place_, dtype_);
}
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->Resize(shape);
tensor->mutable_data(place, dtype);
VLOG(10) << "create varbase: " << name_ << " type: " << dtype
<< " place: " << place;
}
public:
@ -184,7 +188,23 @@ class VarBase {
}
}
inline framework::proto::VarType::Type DType() const { return dtype_; }
inline framework::DDim Dims() const {
return var_->Get<framework::LoDTensor>().dims();
}
// data type, e.g. FP32
inline void SetDataType(framework::proto::VarType::Type type) {
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->mutable_data(tensor->place(), type);
}
inline framework::proto::VarType::Type DataType() const {
auto tensor = var_->Get<framework::LoDTensor>();
return tensor.type();
}
// tensor type, e.g. LoDTensor
inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
inline framework::proto::VarType::Type Type() const { return type_; }
inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient;
@ -238,7 +258,7 @@ class VarBase {
}
std::string name_;
framework::proto::VarType::Type dtype_;
framework::proto::VarType::Type type_;
platform::Place place_;
framework::Variable* var_;
@ -334,11 +354,13 @@ class PYBIND11_HIDDEN OpBase {
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
// Inputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_input_vars_;
std::vector<VarBasePtrMap> grad_input_vars_;
// Outputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_output_vars_;
std::vector<VarBasePtrMap> grad_output_vars_;
std::vector<py::object> backward_hooks_;
framework::AttributeMap attrs_;
};
class Layer {
@ -365,12 +387,131 @@ class PyLayer {
static std::vector<framework::Variable*> Apply(
int func_id, const std::vector<VarBase*>& inputs);
static std::vector<framework::Variable*> ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs);
static std::vector<VarBase*> ApplyGrad(int func_id,
const std::vector<VarBase*>& inputs);
private:
static std::vector<framework::Variable*> CallPythonFunc(
const py::object& callable, const std::vector<framework::Variable*>& ins);
const py::object& callable, const std::vector<VarBase*>& ins);
};
// infer var type context for imperative mode
class PYBIND11_HIDDEN RuntimeInferVarTypeContext
: public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs,
imperative::VarBasePtrMap* outputs,
const framework::AttributeMap* attrs_map)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
input_names_(),
output_names_(),
var_set_() {
input_names_.reserve(inputs_->size());
for (auto& it : *inputs_) {
for (imperative::VarBase* var : it.second) {
input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
}
output_names_.reserve(outputs_->size());
for (auto& it : *outputs_) {
for (imperative::VarBase* var : it.second) {
output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
}
}
virtual ~RuntimeInferVarTypeContext() {}
framework::Attribute GetAttr(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(attrs_);
return attrs_->at(name);
}
bool HasVar(const std::string& name) const override {
return var_set_.count(name) > 0;
}
bool HasInput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(inputs_);
return inputs_->count(name) > 0;
}
bool HasOutput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(outputs_);
return outputs_->count(name) > 0;
}
const std::vector<std::string>& Input(
const std::string& name) const override {
return input_names_.at(name);
}
const std::vector<std::string>& Output(
const std::string& name) const override {
return output_names_.at(name);
}
framework::proto::VarType::Type GetType(
const std::string& name) const override {
return var_set_.at(name)->Type();
}
void SetType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetType(type);
}
framework::proto::VarType::Type GetDataType(
const std::string& name) const override {
return var_set_.at(name)->DataType();
}
void SetDataType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetDataType(type);
}
std::vector<framework::proto::VarType::Type> GetDataTypes(
const std::string& name) const override {
PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
}
void SetDataTypes(const std::string& name,
const std::vector<framework::proto::VarType::Type>&
multiple_data_type) override {
PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
}
std::vector<int64_t> GetShape(const std::string& name) const override {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
void SetShape(const std::string& name,
const std::vector<int64_t>& dims) override {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
int32_t GetLoDLevel(const std::string& name) const override {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
private:
const imperative::VarBasePtrMap* inputs_;
imperative::VarBasePtrMap* outputs_;
const framework::AttributeMap* attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, imperative::VarBase*> var_set_;
};
} // namespace imperative
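In imperative mode the same type-erased InferVarTypeFN is fed a RuntimeInferVarTypeContext built from VarBase maps instead of an OpDesc/BlockDesc. A hedged usage sketch (names, shapes, and the "sum" op are hypothetical; this mirrors what Tracer::Trace does further down in this diff):

// build one input and one output VarBase
auto* x = new imperative::VarBase("x", framework::proto::VarType::FP32,
                                  framework::make_ddim({2, 3}),
                                  platform::CPUPlace(), false, false);
auto* out = new imperative::VarBase("out", framework::proto::VarType::FP32,
                                    framework::make_ddim({2, 3}),
                                    platform::CPUPlace(), false, false);
imperative::VarBasePtrMap ins, outs;
ins["X"].emplace_back(x);
outs["Out"].emplace_back(out);
framework::AttributeMap attrs;

imperative::RuntimeInferVarTypeContext ctx(&ins, &outs, &attrs);
auto& info = framework::OpInfoMap::Instance().Get("sum");
if (info.infer_var_type_) {
  info.infer_var_type_(&ctx);  // writes the inferred types back onto "out"
}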

@ -19,6 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
@ -135,7 +136,7 @@ framework::VariableNameMap CreateOutputVarNameMap(
Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
VarBasePtrMap* outputs,
framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient) {
@ -163,7 +164,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->TrackPreOp(it.first, it.second);
}
op->output_vars_ = outputs;
op->output_vars_ = *outputs;
for (auto it : op->output_vars_) {
auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second;
@ -186,7 +187,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::VariableNameMap invars_name_map =
CreateInputVarNameMap(op, inputs);
framework::VariableNameMap outvars_name_map =
CreateOutputVarNameMap(op, outputs);
CreateOutputVarNameMap(op, *outputs);
auto& info = framework::OpInfoMap::Instance().Get(op->Type());
if (info.Checker() != nullptr) {
@ -197,6 +198,11 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
outvars_name_map, attrs_map);
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
info.infer_var_type_(&infer_var_type_ctx);
}
// TODO(minqiyang): Support infer var type in imperative mode
// Run forward op
VLOG(3) << "tracer running " << op->Type();
@ -221,6 +227,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VLOG(5) << "start construct backward op";
// construct grad op descs
op->attrs_ = attrs_map;
std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
op->Type(), invars_name_map, outvars_name_map, attrs_map));
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
@ -247,12 +254,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
auto fwd_var_it = current_vars_map.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
// Forward inputs or outputs.
grad_in_vars.emplace_back(fwd_var_it->second->var_);
grad_in_vars.emplace_back(fwd_var_it->second);
} else {
VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext());
// Douts.
grad_in_vars.emplace_back(var->grads_->var_);
grad_in_vars.emplace_back(var->grads_);
}
vars_saved_for_backward.insert(it.first);
@ -269,7 +276,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->Type());
VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext());
grad_out_vars.push_back(var->grads_->var_);
grad_out_vars.push_back(var->grads_);
}
}
}
@ -309,23 +316,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
auto& grad_output_vars =
op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
for (const VarBase* inp : inputs) {
grad_input_vars.push_back(inp->var_);
for (VarBase* inp : inputs) {
grad_input_vars.push_back(inp);
}
for (VarBase* out : outputs) {
grad_input_vars.push_back(out->var_);
grad_input_vars.push_back(out);
}
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
platform::CPUPlace place;
for (VarBase* out : outputs) {
InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
grad_input_vars.push_back(out->grads_->var_);
grad_input_vars.push_back(out->grads_);
}
for (VarBase* inp : inputs) {
InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
grad_output_vars.push_back(inp->grads_->var_);
grad_output_vars.push_back(inp->grads_);
}
}
return outputs;

@ -48,7 +48,7 @@ class Tracer {
virtual ~Tracer() {}
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
VarBasePtrMap* outputs, // NOLINT
framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient = false);

@ -25,6 +25,7 @@ class VarBase;
class OpBase;
typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
} // namespace imperative

@ -178,10 +178,10 @@ Beam Search Decode Operator. This Operator constructs the full hypotheses for
each source sentence by walking back along the LoDTensorArray Input(ids)
whose lods can be used to restore the path in the beam search tree.
The Output(SentenceIds) and Output(SentenceScores) separately contain the
generated id sequences and the corresponding scores. The shapes and lods of the
two LodTensor are same. The lod level is 2 and the two levels separately
indicate how many hypotheses each source sentence has and how many ids each
The Output(SentenceIds) and Output(SentenceScores) separately contain the
generated id sequences and the corresponding scores. The shapes and lods of the
two LodTensor are same. The lod level is 2 and the two levels separately
indicate how many hypotheses each source sentence has and how many ids each
hypothesis has.
)DOC");
}
@ -203,15 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
void operator()(framework::InferVarTypeContext* ctx) const override {
for (auto& o : ctx->Output("SentenceIds")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
for (auto& o : ctx->Output("SentenceScores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
}
};

@ -65,7 +65,7 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(true);
AddComment(R"DOC(
This operator does the search in beams for one time step.
This operator does the search in beams for one time step.
Specifically, it selects the top-K candidate word ids of current step from
Input(ids) according to their Input(scores) for all source sentences,
where K is Attr(beam_size) and Input(ids), Input(scores) are predicted results
@ -120,15 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class BeamSearchInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) {
auto &selected_ids = block->FindRecursiveOrCreateVar(o);
selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("selected_ids")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : op_desc.Output("selected_scores")) {
auto &selected_scores = block->FindRecursiveOrCreateVar(o);
selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
for (auto &o : ctx->Output("selected_scores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
}
};

@ -93,11 +93,9 @@ execution.
class GetPlacesInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o_name : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o_name).SetType(
framework::proto::VarType::PLACE_LIST);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o_name : ctx->Output("Out")) {
ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
}
}
};

@ -100,16 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0];
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = ctx->Input("X")[0];
auto out_name = ctx->Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) {
out.SetDataType(x->GetDataType());
ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
if (ctx->HasVar(x_name)) {
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
}
}
};

@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto p_names = op_desc.Input(kX);
auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
void operator()(framework::InferVarTypeContext *ctx) const override {
auto p_names = ctx->Input(kX);
auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
for (size_t i = 0; i < p_names.size(); ++i) {
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
if (g_var != nullptr) { // Gradient could be @EMPTY@
if (ctx->HasVar(pg_ig_names[i])) {
VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
<< " type: " << p_var.GetType();
g_var->SetType(p_var.GetType());
g_var->SetDataType(p_var.GetDataType());
<< " type: " << ctx->GetType(p_names[i]);
ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
}
}
}

@ -56,8 +56,7 @@ class FakeInitOp : public framework::OperatorBase {
class FakeInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {

@ -114,11 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(input_var->GetType());
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
}
};

@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
#include <memory>
namespace paddle {
namespace operators {
@ -71,11 +73,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(input_var->GetType());
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
}
};

@ -39,12 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
class FillConstantOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
auto& out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetDataType(data_type);
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
}
};

@ -138,22 +138,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
class FusedEmbeddingSeqPoolOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto attr = op_desc.GetAttr("is_sparse");
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to SelectedRows";
block->Var(out_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
}
};

@ -81,15 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
class GetTensorFromSelectedRowsOpVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const final {
auto out_var_name = op_desc.Output("Out").front();
auto in_var_name = op_desc.Input("X").front();
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
out_var.SetDataType(in_var.GetDataType());
void operator()(framework::InferVarTypeContext *ctx) const { // NOLINT
auto out_var_name = ctx->Output("Out").front();
auto in_var_name = ctx->Input("X").front();
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
}
};

@ -197,38 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
class HierarchicalSigmoidGradOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto bias_grad_var_name_vec =
op_desc.Output(framework::GradVarName("Bias"));
void operator()(framework::InferVarTypeContext* ctx) const override {
auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
auto bias_grad_var_name_vec = ctx->Output(framework::GradVarName("Bias"));
std::string bias_grad_var_name;
bool hasBias = false;
if (bias_grad_var_name_vec.size()) {
hasBias = true;
bias_grad_var_name =
op_desc.Output(framework::GradVarName("Bias")).front();
bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
}
auto attr = op_desc.GetAttr("is_sparse");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(w_grad_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(w_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
}
if (hasBias) {
VLOG(30) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
block->Var(bias_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
}
block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
}
};

@ -64,11 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class LoDRankTableInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarType::LOD_RANK_TABLE);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("Out")) {
ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
}
}
};
