Add dygraph execution context (#20157)

* add_dygraph_execution_context

* add dygraph infershape context and execution context; test=develop

* fix imperative bug; test=develop

* remove the inputs/outputs interface from execution context,
because it has the same function as InputNames;
test=develop

* remove tracer_test ctest; test=develop

* fix split op bug; test=develop

* fix unit test bugs; test=develop

* fix distribute test bug; test=develop

* fix ngraph compile bug; test=develop

* fix grad maker bug; test=develop

* fix load op bugs; test=develop

* fix operator.cc construct bug; test=develop

* remove useless name lookup in operator; test=develop

* add tracer_test; test=develop

* fix concat, split bug; test=develop

* remove tracer_test unit test; test=develop

* fix attribute check bug; test=develop

* add test code to fix coverage; test=develop

* remove useless code, move backward input check into the engine; test=develop

* unlock var type infer shape;test=develop

* add ShareAllLoD api; test=develop

* add dygraph infershape context unit test; test=develop

* remove increase and decrease lod in dygraph; test=develop

* add override; test=develop

* fix increase/decrease lod; test=develop

* fix paddle_enforce; test=develop

* disable lod op dygraph check; test=develop

* fix paddle enforce error; test=develop

* add comment for op_registry and OperatorBase; test=develop

* optimize the comment of op_registry; test=develop

* fix format of comment; test=develop

* fix format of comment; test=develop

* optimize the format of comment; test=develop

* optimize the format of the comment; test=develop

* optimize comment of op_registry; test=develop
hong, committed via GitHub — commit ac8546701d, parent a6b089c614

@@ -220,7 +220,7 @@ class DefaultValueSetter {
  public:
   explicit DefaultValueSetter(T default_value)
       : default_value_(default_value) {}
-  void operator()(T* value) const { *value = default_value_; }
+  const T& operator()() const { return default_value_; }

  private:
   T default_value_;

@@ -259,7 +259,7 @@ class EnumInContainer {
 // an attribute can have more than one limits
 template <typename T>
 class TypedAttrChecker {
-  typedef std::function<void(T*)> DefaultValueChecker;
+  typedef std::function<const T&()> DefaultValueChecker;
   typedef std::function<void(const T&)> ValueChecker;

  public:
@@ -297,18 +297,17 @@ class TypedAttrChecker {
   }

   void operator()(AttributeMap* attr_map) const {
-    if (!attr_map->count(attr_name_)) {
+    auto it = attr_map->find(attr_name_);
+    if (it == attr_map->end()) {
       // user do not set this attr
       PADDLE_ENFORCE(!default_value_setter_.empty(),
                      "Attribute '%s' is required!", attr_name_);
       // default_value_setter_ has no more than one element
-      T val;
-      (default_value_setter_[0])(&val);
-      (*attr_map)[attr_name_] = val;
+      attr_map->emplace(attr_name_, default_value_setter_[0]());
     }
-    Attribute& attr = attr_map->at(attr_name_);
+    it = attr_map->find(attr_name_);
     ExtractAttribute<T> extract_attr(attr_name_);
-    T* attr_value = extract_attr(attr);
+    T* attr_value = extract_attr(it->second);
     for (const auto& checker : value_checkers_) {
       checker(*attr_value);
     }

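To make the new default-value path concrete, here is a minimal standalone C++ sketch of the same pattern (illustrative only, not Paddle code; the attribute map is simplified to map strings to int): the setter is now a callable that returns a const reference to its stored default, and the default is emplaced into the map only when the user did not set the attribute.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using AttributeMap = std::unordered_map<std::string, int>;  // simplified stand-in

class DefaultValueSetter {
 public:
  explicit DefaultValueSetter(int default_value) : default_value_(default_value) {}
  // Returns the stored default instead of writing through an out-parameter.
  const int& operator()() const { return default_value_; }

 private:
  int default_value_;
};

void ApplyDefault(AttributeMap* attr_map, const std::string& name,
                  const std::function<const int&()>& default_value_setter) {
  auto it = attr_map->find(name);
  if (it == attr_map->end()) {
    // User did not set this attr: materialize the default in place.
    attr_map->emplace(name, default_value_setter());
  }
}

int main() {
  AttributeMap attrs;
  ApplyDefault(&attrs, "axis", DefaultValueSetter(1));
  std::cout << attrs["axis"] << std::endl;  // prints 1
  return 0;
}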
@@ -64,7 +64,7 @@ template <typename DeviceContext, typename T>
 class TestKernel : public OpKernel<float> {
  public:
   void Compute(const ExecutionContext& ctx) const {
-    std::cout << ctx.op().DebugString() << std::endl;
+    std::cout << ctx.DebugString() << std::endl;
     const Tensor* input = ctx.Input<Tensor>("input");

@@ -47,18 +47,16 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   AttrReader Attrs() const override;

-  const std::vector<std::string> &Inputs(
-      const std::string &name) const override;
+  std::vector<std::string> Inputs(const std::string &name) const override;

-  const std::vector<std::string> &Outputs(
-      const std::string &name) const override;
+  std::vector<std::string> Outputs(const std::string &name) const override;

   void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
                 size_t j = 0) override {
     PADDLE_ENFORCE_LT(i, Inputs(in).size());
     PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    const std::string &input_n = Inputs(in)[i];
-    const std::string &output_n = Outputs(out)[j];
+    std::string input_n = Inputs(in)[i];
+    std::string output_n = Outputs(out)[j];

     PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
                    in, i);

@@ -74,6 +72,33 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     SetDim(output_n, GetDim(input_n));
   }

+  void ShareAllLoD(const std::string &in,
+                   const std::string &out) const override {
+    auto &in_var_names = op_.Input(in);
+    auto &out_var_names = op_.Output(out);
+
+    PADDLE_ENFORCE_EQ(
+        in_var_names.size(), out_var_names.size(),
+        platform::errors::PreconditionNotMet(
+            "Op [%s]: Input var number shoule be equal with output var number",
+            op_.Type()));
+
+    for (size_t i = 0; i < in_var_names.size(); ++i) {
+      if (out_var_names[i] == framework::kEmptyVarName) {
+        continue;
+      }
+
+      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
+      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
+      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
+          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
+        VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
+        return;
+      }
+
+      out_var->SetLoDLevel(in_var->GetLoDLevel());
+    }
+  }
+
   void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                 size_t j = 0) const override {
     PADDLE_ENFORCE_LT(i, Inputs(in).size());

@@ -173,7 +198,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   }

   void SetOutputDim(const std::string &name, const DDim &dim) override {
-    auto &arg_names = Outputs(name);
+    auto arg_names = Outputs(name);
     PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
                       "Output(%s) should hold one element, but now it holds %d",
                       name, arg_names.size());

@@ -182,7 +207,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {

   void SetOutputsDim(const std::string &name,
                      const std::vector<DDim> &dims) override {
-    auto &names = Outputs(name);
+    auto names = Outputs(name);
     SetDims(names, dims);
   }

@@ -789,12 +814,12 @@ AttrReader CompileTimeInferShapeContext::Attrs() const {
   return AttrReader(op_.GetAttrMap());
 }

-const std::vector<std::string> &CompileTimeInferShapeContext::Inputs(
+std::vector<std::string> CompileTimeInferShapeContext::Inputs(
     const std::string &name) const {
   return op_.Input(name);
 }

-const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
+std::vector<std::string> CompileTimeInferShapeContext::Outputs(
     const std::string &name) const {
   return op_.Output(name);
 }

@@ -21,9 +21,9 @@ namespace framework {

 std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
     const std::string& type, const VariableNameMap& inputs,
-    const VariableNameMap& outputs, AttributeMap attrs) {
+    const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) {
   auto& info = OpInfoMap::Instance().Get(type);
-  if (info.Checker() != nullptr) {
+  if (attr_check && info.Checker() != nullptr) {
     info.Checker()->Check(&attrs);
   }
   auto op = info.Creator()(type, inputs, outputs, attrs);

@@ -67,10 +67,34 @@ struct OperatorRegistrar : public Registrar {

 class OpRegistry {
  public:
+  /**
+   * @brief Return an OperatorBase constructed by type, inputs, outputs, attrs.
+   *        In dygraph mode, inputs, output, attrs will be set to empty map to
+   *        improve the execution efficiency of dygraph.
+   *        Dygraph mode will use:
+   *        framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
+   *
+   * @param[str] type               The operator type.
+   * @param[map] inputs             Inputs map of the operator.
+   * @param[map] outputs            Outputs map of the operator.
+   * @param[unordered_map] attrs    Attributes map of the operator.
+   * @param[bool] attr_check
+   *            Whether do the attribute check before OperatorBase construction.
+   *            Default is true.
+   *            Attr_check is used to control the check of attribute map.
+   *            The check of attribute map have two purposes:
+   *            1. check whether the attribute item is valid or not.
+   *            2. add attribute item which has default value
+   *               if it is not in attrs.
+   *            In dygraph mode, attrs is an empty unordered_map,
+   *            attr_check is set to false, otherwise it will be failed
+   *            when check function called.
+   */
   static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VariableNameMap& inputs,
                                                 const VariableNameMap& outputs,
-                                                AttributeMap attrs);
+                                                AttributeMap attrs,
+                                                bool attr_check = true);

   static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);

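As a usage sketch of the two construction paths described above (assuming the Paddle framework headers are available; op_type, inputs, outputs and attrs stand for values already built by the caller):

  // Static-graph path: full name maps, attribute check enabled by default.
  auto op = framework::OpRegistry::CreateOp(op_type, inputs, outputs, attrs);

  // Dygraph path: empty maps and attr_check = false, because the attributes
  // were already checked when the imperative OpBase was constructed.
  auto dygraph_op = framework::OpRegistry::CreateOp(op_type, {}, {}, {}, false);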
@@ -321,8 +321,14 @@ OperatorBase::OperatorBase(const std::string& type,
       attrs_(attrs),
       // NOTE(zjl): why op_info may be nullptr?
       info_(OpInfoMap::Instance().GetNullable(type)) {
-  GenerateTemporaryNames();
-  CheckAllInputOutputSet();
+  // In dygraph mode, all the OperatorBase will be constructed by function:
+  // framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
+  // Inputs, outputs and attrs will be set to empty map
+  // to improve the execution efficiency of dygraph.
+  if (inputs_.size() > 0 || outputs_.size() > 0) {
+    GenerateTemporaryNames();
+    CheckAllInputOutputSet();
+  }
 }

 std::vector<std::string> OperatorBase::InputVars() const {
@@ -457,15 +463,14 @@ const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {

 template <>
 const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
     const std::string& name) const {
-  auto it = ctx_.inputs.find(name);
-  if (it == ctx_.inputs.end()) {
+  auto vars = MultiInputVar(name);
+
+  if (vars.size() == 0) {
     return {};
   }
-  const std::vector<Variable*>& vars = it->second;
   std::vector<const Tensor*> res;
   res.reserve(vars.size());
   std::transform(vars.begin(), vars.end(), std::back_inserter(res),
-                 [&](Variable* var) -> const Tensor* {
+                 [&](const Variable* var) -> const Tensor* {
                    if (var == nullptr) return nullptr;
                    PADDLE_ENFORCE(
                        var->IsType<LoDTensor>(),

@@ -484,11 +489,11 @@ Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {

 template <>
 std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
     const std::string& name) const {
-  auto it = ctx_.outputs.find(name);
-  if (it == ctx_.outputs.end()) {
+  auto vars = MultiOutputVar(name);
+
+  if (vars.size() == 0) {
     return {};
   }
-  const std::vector<Variable*>& vars = it->second;
   std::vector<Tensor*> res;
   res.reserve(vars.size());
   std::transform(vars.begin(), vars.end(), std::back_inserter(res),
@@ -580,13 +585,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
   AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

-  const std::vector<std::string>& Inputs(
-      const std::string& name) const override {
+  std::vector<std::string> Inputs(const std::string& name) const override {
     return op_.Inputs(name);
   }

-  const std::vector<std::string>& Outputs(
-      const std::string& name) const override {
+  std::vector<std::string> Outputs(const std::string& name) const override {
     return op_.Outputs(name);
   }

@@ -622,6 +625,51 @@ class RuntimeInferShapeContext : public InferShapeContext {
     }
   }

+  void ShareAllLoD(const std::string& in,
+                   const std::string& out) const override {
+    auto in_it = ctx_.inputs.find(in);
+    auto out_it = ctx_.outputs.find(out);
+    PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(),
+                      platform::errors::NotFound(
+                          "Input [%s] found error in Op [%s]", in, op_.Type()));
+    PADDLE_ENFORCE_NE(
+        out_it, ctx_.outputs.end(),
+        platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
+                                   op_.Type()));
+
+    auto& in_var_list = in_it->second;
+    auto& out_var_list = out_it->second;
+
+    PADDLE_ENFORCE_EQ(
+        in_var_list.size(), out_var_list.size(),
+        platform::errors::PreconditionNotMet(
+            "Op [%s]: Input var size should be equal with ouput var size",
+            op_.Type()));
+
+    auto& out_var_names = op_.Outputs(out);
+
+    for (size_t i = 0; i < in_var_list.size(); ++i) {
+      if (out_var_names[i] == framework::kEmptyVarName) {
+        continue;
+      }
+
+      Variable* in_var = in_var_list[i];
+      if (!in_var->IsType<LoDTensor>()) return;
+      Variable* out_var = out_var_list[i];
+      PADDLE_ENFORCE_EQ(out_var->IsType<LoDTensor>(), true,
+                        platform::errors::PreconditionNotMet(
+                            "The %d-th output of Output(%s) must be LoDTensor.",
+                            i, out_var_names[i]));
+      auto& in_tensor = in_var->Get<LoDTensor>();
+      auto* out_tensor = out_var->GetMutable<LoDTensor>();
+      out_tensor->set_lod(in_tensor.lod());
+#ifdef PADDLE_WITH_MKLDNN
+      if (in_tensor.layout() != DataLayout::kMKLDNN)
+#endif
+        out_tensor->set_layout(in_tensor.layout());
+    }
+  }
+
   void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) const override {
     auto in_it = ctx_.inputs.find(in);
@@ -1138,7 +1186,7 @@ void OperatorWithKernel::ParseInputDataType(
     proto::VarType::Type* data_type) const {
   proto::VarType::Type dafault_data_type =
       static_cast<proto::VarType::Type>(-1);
-  const std::vector<const Variable*> vars = ctx.MultiInputVar(name);
+  const std::vector<Variable*> vars = ctx.MultiInputVar(name);
   for (size_t i = 0; i < vars.size(); ++i) {
     const Variable* var = vars[i];
     if (var != nullptr) {

@@ -1156,7 +1204,7 @@ void OperatorWithKernel::ParseInputDataType(
           platform::errors::InvalidArgument(
               "The Tensor in the %s Op's Input Variable %s(%s) is "
               "not initialized.",
-              Type(), name, ctx.Inputs(name).at(i)));
+              Type(), name, ctx.InputNames(name).at(i)));
       proto::VarType::Type tmp = t->type();
       PADDLE_ENFORCE(
           tmp == *data_type || *data_type == dafault_data_type,

@@ -1177,8 +1225,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
   proto::VarType::Type dafault_data_type =
       static_cast<proto::VarType::Type>(-1);
   proto::VarType::Type data_type = dafault_data_type;
-  for (auto& input : ctx.Context().inputs) {
-    ParseInputDataType(ctx, input.first, &data_type);
+  for (auto& input : ctx.InNameList()) {
+    ParseInputDataType(ctx, input, &data_type);
   }
   PADDLE_ENFORCE_NE(data_type, dafault_data_type,
                     "DataType should be indicated by input Variable.");

@@ -238,35 +238,57 @@ class ExecutionContext {
         device_context_(device_context),
         ctx_(ctx),
         kernel_configs_(configs) {}
+  virtual ~ExecutionContext() {}

-  const OperatorBase& op() const { return op_; }
+  virtual std::string InputName(const std::string& name) const {
+    return op_.Input(name);
+  }
+  virtual std::vector<std::string> InputNames(const std::string& name) const {
+    return op_.Inputs(name);
+  }
+  virtual std::string OutputName(const std::string& name) const {
+    return op_.Output(name);
+  }
+  virtual std::vector<std::string> OutputNames(const std::string& name) const {
+    return op_.Outputs(name);
+  }
+  virtual bool HasAttr(const std::string& name) const {
+    return op_.HasAttr(name);
+  }
+  virtual const AttributeMap& Attrs() const { return op_.Attrs(); }
+
+  const std::string& Type() const { return op_.Type(); }

   const Scope& scope() const { return scope_; }

   template <typename T>
   inline const T& Attr(const std::string& name) const {
-    return op_.Attr<T>(name);
+    return boost::get<T>(GetAttr(name));
   }

-  bool HasAttr(const std::string& name) const { return op_.HasAttr(name); }
+  virtual const Attribute& GetAttr(const std::string& name) const {
+    return op_.Attrs().at(name);
+  }

-  bool HasInput(const std::string& name) const;
+  virtual bool HasInput(const std::string& name) const;

-  bool HasOutput(const std::string& name) const;
+  virtual bool HasOutput(const std::string& name) const;

-  size_t InputSize(const std::string& name) const {
+  virtual size_t InputSize(const std::string& name) const {
     return op_.Inputs(name).size();
   }

-  size_t OutputSize(const std::string& name) const {
+  virtual size_t OutputSize(const std::string& name) const {
     return op_.Outputs(name).size();
   }

-  const Variable* InputVar(const std::string& name) const;
+  virtual const Variable* InputVar(const std::string& name) const;

-  Variable* OutputVar(const std::string& name) const;
+  virtual Variable* OutputVar(const std::string& name) const;

-  const std::vector<const Variable*> MultiInputVar(
+  virtual const std::vector<Variable*> MultiInputVar(
       const std::string& name) const {
     auto it = ctx_.inputs.find(name);
     if (it == ctx_.inputs.end()) {

@@ -275,8 +297,7 @@ class ExecutionContext {
     return {it->second.begin(), it->second.end()};
   }

-  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
-    auto names = op_.Outputs(name);
+  virtual std::vector<Variable*> MultiOutputVar(const std::string& name) const {
     auto it = ctx_.outputs.find(name);
     if (it == ctx_.outputs.end()) {
       return {};
@@ -284,6 +305,17 @@ class ExecutionContext {
     return it->second;
   }

+  virtual std::vector<std::string> InNameList() const {
+    std::vector<std::string> vec_temp;
+    vec_temp.reserve(ctx_.inputs.size());
+
+    for (auto& input : ctx_.inputs) {
+      vec_temp.push_back(input.first);
+    }
+
+    return vec_temp;
+  }
+
   template <typename T>
   const T* Input(const std::string& name) const {
     auto* var = InputVar(name);

@@ -298,15 +330,14 @@ class ExecutionContext {

   template <typename T>
   const std::vector<const T*> MultiInput(const std::string& name) const {
-    auto it = ctx_.inputs.find(name);
-    if (it == ctx_.inputs.end()) {
+    auto vars = MultiInputVar(name);
+    if (vars.size() == 0) {
       return {};
     }
-    const std::vector<Variable*>& vars = it->second;
     std::vector<const T*> res;
     res.reserve(vars.size());
     std::transform(vars.begin(), vars.end(), std::back_inserter(res),
-                   [&](Variable* var) -> const T* {
+                   [&](const Variable* var) -> const T* {
                      return var == nullptr ? nullptr : &var->Get<T>();
                    });
     return res;

@@ -314,17 +345,19 @@ class ExecutionContext {

   template <typename T>
   std::vector<T*> MultiOutput(const std::string& name) const {
-    auto it = ctx_.outputs.find(name);
-    if (it == ctx_.outputs.end()) {
+    auto vars = MultiOutputVar(name);
+
+    if (vars.size() == 0) {
       return {};
     }
-    const std::vector<Variable*>& vars = it->second;
+
     std::vector<T*> res;
     res.reserve(vars.size());
     std::transform(vars.begin(), vars.end(), std::back_inserter(res),
                    [&](Variable* var) -> T* {
                      return var == nullptr ? nullptr : var->GetMutable<T>();
                    });
+
     return res;
   }
@@ -347,16 +380,6 @@ class ExecutionContext {
   }
 #endif

-  //! Get actual name vector for this input.
-  const std::vector<std::string>& Inputs(const std::string& name) const {
-    return op_.Inputs(name);
-  }
-
-  //! Get actual name vector for this output.
-  const std::vector<std::string>& Outputs(const std::string& name) const {
-    return op_.Outputs(name);
-  }
-
   template <typename T, typename DevContext>
   Tensor AllocateTmpTensor(const framework::DDim& dim,
                            const DevContext& dev_ctx) const {

@@ -385,7 +408,9 @@ class ExecutionContext {
     return *boost::get<std::shared_ptr<T>>((*kernel_configs_)[idx]);
   }

-  const RuntimeContext& Context() const { return ctx_; }
+  const RuntimeContext Context() const { return ctx_; }
+
+  std::string DebugString() const { return op_.DebugString(); }

  private:
   const OperatorBase& op_;

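With op() removed, kernels query names, attributes and debug output directly through the context. A minimal sketch of the new calls inside a Compute method (framework types assumed; "X", "Out" and "scale" are placeholder slot and attribute names):

  void Compute(const framework::ExecutionContext& ctx) const override {
    auto x_name = ctx.InputName("X");         // was ctx.op().Input("X")
    auto out_names = ctx.OutputNames("Out");  // was ctx.op().Outputs("Out")
    float scale = ctx.Attr<float>("scale");   // now routed through GetAttr()
    VLOG(3) << ctx.Type() << ": " << ctx.DebugString();
  }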
@@ -135,10 +135,10 @@ template <typename T1, typename T2>
 class CPUKernelTest : public OpKernel<float> {
  public:
   void Compute(const ExecutionContext& ctx) const {
-    std::cout << ctx.op().DebugString() << std::endl;
+    std::cout << ctx.DebugString() << std::endl;
     cpu_kernel_run_num++;
-    ASSERT_EQ(ctx.op().Input("x"), "IN1");
-    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
+    ASSERT_EQ(ctx.InputName("x"), "IN1");
+    ASSERT_EQ(ctx.OutputName("y"), "OUT1");
   }
 };

@@ -146,10 +146,10 @@ template <typename T1, typename T2>
 class CPUKernel2Test : public OpKernel<float> {
  public:
   void Compute(const ExecutionContext& ctx) const {
-    std::cout << ctx.op().DebugString() << std::endl;
+    std::cout << ctx.DebugString() << std::endl;
     cpu_kernel2_run_num++;
-    ASSERT_EQ(ctx.op().Input("x"), "IN1");
-    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
+    ASSERT_EQ(ctx.InputName("x"), "IN1");
+    ASSERT_EQ(ctx.OutputName("y"), "OUT1");
   }
 };

@@ -172,7 +172,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
 class CPUKernalMultiInputsTest : public OpKernel<float> {
  public:
   void Compute(const ExecutionContext& ctx) const {
-    auto xs = ctx.op().Inputs("xs");
+    auto xs = ctx.InputNames("xs");
     ASSERT_EQ(xs.size(), 3UL);
     ASSERT_EQ(xs[0], "x0");
     ASSERT_EQ(xs[1], "x1");

@@ -196,10 +196,10 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
     auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
     ASSERT_EQ(outTensor0.size(), 2U);

-    auto k = ctx.op().Input("k");
+    auto k = ctx.InputName("k");
     ASSERT_EQ(k, "k0");

-    auto ys = ctx.op().Outputs("ys");
+    auto ys = ctx.OutputNames("ys");
     ASSERT_EQ(ys.size(), 2UL);
     ASSERT_EQ(ys[0], "y0");
     ASSERT_EQ(ys[1], "y1");
@@ -496,6 +496,41 @@ TEST(IndicateVarDataTypeTest, other) {
   ASSERT_TRUE(caught);
 }

+TEST(ExecutionContextAttrAndInOut, new_api) {
+  paddle::framework::InitDevices(true);
+  paddle::framework::proto::OpDesc op_desc;
+  op_desc.set_type("test_operator");
+  BuildVar("input", {"IN1"}, op_desc.add_inputs());
+  BuildVar("output", {"OUT1"}, op_desc.add_outputs());
+
+  auto attr = op_desc.mutable_attrs()->Add();
+  attr->set_name("scale");
+  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
+  attr->set_f(3.14);
+
+  paddle::platform::CPUPlace cpu_place;
+  paddle::framework::Scope scope;
+
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto* var = scope.Var("OUT1");
+  var->GetMutable<paddle::framework::LoDTensorArray>();
+
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(cpu_place);
+
+  paddle::framework::RuntimeContext ctx({}, {});
+  paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx,
+                                                  ctx, nullptr);
+
+  ASSERT_EQ(exe_context.InputSize("input"), 1u);
+  ASSERT_EQ(exe_context.OutputSize("output"), 1u);
+
+  auto attr_map = exe_context.Attrs();
+  ASSERT_EQ(boost::get<float>(attr_map["scale"]), 3.14f);
+  ASSERT_EQ(exe_context.Type(), "test_operator");
+}
+
 namespace paddle {
 namespace framework {

@@ -54,16 +54,18 @@ class InferShapeContext {
                      const std::vector<DDim> &dims);

   virtual AttrReader Attrs() const = 0;
-  virtual const std::vector<std::string> &Inputs(
-      const std::string &name) const = 0;
-  virtual const std::vector<std::string> &Outputs(
-      const std::string &name) const = 0;
+  virtual std::vector<std::string> Inputs(const std::string &name) const = 0;
+  virtual std::vector<std::string> Outputs(const std::string &name) const = 0;

   virtual void ShareDim(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) = 0;

   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;

+  // share the lod information of all the tensor from in to out.
+  // out_vars[i].lod = in_vars[i].lod
+  virtual void ShareAllLoD(const std::string &in,
+                           const std::string &out) const = 0;
+
   virtual int32_t GetLoDLevel(const std::string &in, size_t i = 0) const = 0;

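For an op whose gradient has one output slot per input slot, InferShape can now forward LoD with a single call instead of a per-index loop; a sketch of the intended usage, mirroring the concat_grad change later in this commit:

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto in_x = "X";
    auto out_x_g_n = framework::GradVarName(in_x);
    ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
    ctx->ShareAllLoD(in_x, out_x_g_n);
  }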
@@ -152,8 +152,6 @@ void BasicEngine::PrepareDeps() {
     q.pop();
     VLOG(3) << "Checking grads of op " << cur_op->Type();

-    CheckBackwardInputs(cur_op);
-
     SetBackwardOutputs(cur_op);
     PrepareGradAccumulators(cur_op);

@@ -189,6 +187,9 @@ void BasicEngine::Execute() {
     OpBase* cur_op = q.front();
     q.pop();

+    // CheckBackWardInput
+    CheckBackwardInputs(cur_op);
+
     // Step 1: Run Backward
     auto& bwd_ins = cur_op->GetInsMap();
     auto& bwd_outs = cur_op->GetOutsMap();

@@ -210,7 +211,6 @@ void BasicEngine::Execute() {
         }
       }
     }
-
     VLOG(3) << "Start to execute grad op " << cur_op->Type();
     RunOp(cur_op, bwd_ins, tmp_outs, cur_op->place());
     // Step 2: Sum Gradient

@@ -190,6 +190,7 @@ void VarBase::AddGradOps(const std::weak_ptr<OpBase>& op) {
 void VarBase::ClearGradient() {
   if (grad_var_) {
     auto* grad_t = grad_var_->var_.GetMutable<framework::LoDTensor>();
+
     if (grad_t->IsInitialized()) {
       auto* dev_ctx =
           platform::DeviceContextPool::Instance().Get(grad_t->place());
@@ -241,18 +242,9 @@ OpBase::OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
     info.Checker()->Check(&attrs_);
   }

-  auto input_name_map = CreateVarNameMap(info, type, ins, true);
-  auto output_name_map = CreateVarNameMap(info, type, outs, false);
-  op_ = framework::OpRegistry::CreateOp(type, std::move(input_name_map),
-                                        std::move(output_name_map), attrs);
-  VLOG(3) << "Construct Op: " << type << std::endl;
-}
+  op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);

-// create OpBase from opdesc
-OpBase::OpBase(size_t id, const framework::OpDesc& op_desc,
-               const platform::Place& place)
-    : id_(id), op_(framework::OpRegistry::CreateOp(op_desc)), place_(place) {
-  VLOG(3) << "Construct Op: " << op_desc.Type() << std::endl;
+  VLOG(3) << "Construct Op: " << type << std::endl;
 }

 void OpBase::CreateOperatorBase() {
@@ -260,11 +252,7 @@ void OpBase::CreateOperatorBase() {
   if (info.Checker() != nullptr) {
     info.Checker()->Check(&attrs_);
   }
-
-  auto input_name_map = CreateVarNameMap(info, type_, ins_, true);
-  auto output_name_map = CreateVarNameMap(info, type_, outs_, false);
-  op_ = framework::OpRegistry::CreateOp(type_, std::move(input_name_map),
-                                        std::move(output_name_map), attrs_);
+  op_ = framework::OpRegistry::CreateOp(type_, {}, {}, {}, false);
 }

 void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
@@ -272,10 +260,9 @@ void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
   PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
   auto& info = op_->Info();
   if (info.infer_var_type_) {
-    RuntimeInferVarTypeContext infer_var_type_ctx(ins, &outs, op_->Attrs());
+    RuntimeInferVarTypeContext infer_var_type_ctx(ins, &outs, attrs_);
     info.infer_var_type_(&infer_var_type_ctx);
   }
-
   // Initialize output var type
   for (auto& var_pair : outs) {
     for (auto& var : var_pair.second) {

@@ -285,13 +272,11 @@ void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
   VLOG(3) << "Running Op " << Type();
   VLOG(5) << LayerDebugString(Type(), ins, outs);

-  auto runtime_ctx = PrepareRuntimeContext(ins, outs);
-
-  VLOG(6) << "start preparing op: " << Type();
-  auto prepared_op = PreparedOp::Prepare(runtime_ctx, *op_kernel, place(), ins);
-
-  VLOG(6) << "finish preparing op: " << Type();
-  prepared_op.Run();
+  framework::RuntimeContext runtime_ctx({}, {});
+  auto prepared_op =
+      PreparedOp::Prepare(ins, outs, *op_kernel, place(), &attrs_);
+
+  prepared_op.Run(&ins, &outs, &attrs_);

   VLOG(4) << LayerDebugString(Type(), ins, outs);
 }

(One file's diff is suppressed here because it is too large.)

@@ -70,10 +70,11 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
       dev_ctx_(dev_ctx),
       kernel_configs_(kernel_configs) {}

-PreparedOp PreparedOp::Prepare(const framework::RuntimeContext& ctx,
+PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
+                               const NameVarBaseMap& outs,
                                const framework::OperatorWithKernel& op,
                                platform::Place place,
-                               const NameVarBaseMap& ins) {
+                               const framework::AttributeMap* attrs) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);

@@ -88,9 +89,9 @@ PreparedOp PreparedOp::Prepare(const framework::RuntimeContext& ctx,
   auto& kernels = kernels_iter->second;

-  auto expected_kernel_key =
-      op.GetExpectedKernelType(framework::ExecutionContext(
-          op, framework::Scope(), *dev_ctx, ctx, nullptr));
+  framework::RuntimeContext ctx({}, {});
+  auto expected_kernel_key = op.GetExpectedKernelType(DygraphExecutionContext(
+      op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

   auto kernel_iter = kernels.find(expected_kernel_key);

@@ -111,13 +112,20 @@ PreparedOp PreparedOp::Prepare(const framework::RuntimeContext& ctx,
   return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
 }

-void PreparedOp::Run() {
+void PreparedOp::Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
+                     const framework::AttributeMap* attrs) {
   // TODO(zjl): remove scope in dygraph
   framework::Scope scope;
-  op_.RuntimeInferShape(scope, dev_ctx_->GetPlace(), ctx_);
-  VLOG(6) << "Finish Runtime infer shape";
-  func_(framework::ExecutionContext(op_, scope, *dev_ctx_, ctx_,
-                                    kernel_configs_));
+
+  DygraphInferShapeContext infer_shape_ctx(in, out, attrs);
+
+  framework::OperatorWithKernel* op_ker =
+      (framework::OperatorWithKernel*)(&op_);
+  op_ker->InferShape(&infer_shape_ctx);
+
+  func_(DygraphExecutionContext(op_, scope, *dev_ctx_, ctx_, kernel_configs_,
+                                *in, *out, attrs));
 }

 }  // namespace imperative

@@ -30,13 +30,16 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var);

 class PreparedOp {
  public:
-  static PreparedOp Prepare(const framework::RuntimeContext& ctx,
+  static PreparedOp Prepare(const NameVarBaseMap& ins,
+                            const NameVarBaseMap& outs,
                             const framework::OperatorWithKernel& op,
-                            platform::Place place, const NameVarBaseMap& ins);
+                            platform::Place place,
+                            const framework::AttributeMap* attrs);

   inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; }

-  void Run();
+  void Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
+           const framework::AttributeMap* attrs);

   static void PrepareData(const platform::Place& place,
                           const NameVarBaseMap& ins,
@@ -148,6 +148,67 @@ TEST(test_layer, test_varbase_basic) {
 }

 // TODO(jiabin): Add more ut here for layer

+TEST(test_layer, test_dygraph_execution_context) {
+  std::shared_ptr<imperative::VarBase> vin(
+      new imperative::VarBase(false, "vin"));
+  std::shared_ptr<imperative::VarBase> vout(
+      new imperative::VarBase(false, "vout"));
+  framework::OpDesc desc;
+  platform::CPUPlace place;
+  var_pair x_pair = var_pair("X", vb_vector(1, vin));
+  var_pair y_pair = var_pair("Y", vb_vector(1, vin));
+  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  imperative::NameVarBaseMap ins = {x_pair, y_pair};
+  imperative::NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap concat_att_map;
+  concat_att_map["axis"] = 1;
+
+  auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
+  paddle::platform::CPUPlace cpu_place;
+
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(cpu_place);
+  paddle::framework::RuntimeContext ctx({}, {});
+  framework::Scope scope;
+
+  DygraphExecutionContext dy_exe_context(*(op.get()), scope, *dev_ctx, ctx,
+                                         nullptr, ins, outs, &concat_att_map);
+
+  ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
+  ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
+  ASSERT_EQ(dy_exe_context.HasAttr("axis"), true);
+  auto attr_map = dy_exe_context.Attrs();
+  ASSERT_EQ(boost::get<int>(attr_map["axis"]), 1);
+  ASSERT_EQ(dy_exe_context.OutputSize("Out"), 1u);
+  ASSERT_EQ(dy_exe_context.HasOutput("Out"), true);
+}
+
+TEST(test_layer, test_dygraph_infershape_context) {
+  std::shared_ptr<imperative::VarBase> vin(
+      new imperative::VarBase(false, "vin"));
+  std::shared_ptr<imperative::VarBase> vout(
+      new imperative::VarBase(false, "vout"));
+  framework::OpDesc desc;
+  platform::CPUPlace place;
+  var_pair x_pair = var_pair("X", vb_vector(1, vin));
+  var_pair y_pair = var_pair("Y", vb_vector(1, vin));
+  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  imperative::NameVarBaseMap ins = {x_pair, y_pair};
+  imperative::NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap concat_att_map;
+  concat_att_map["axis"] = 1;
+
+  DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &concat_att_map);
+
+  bool have_x = infer_shape_ctx.HasOutputs("Out");
+  ASSERT_EQ(have_x, true);
+
+  bool have_z = infer_shape_ctx.HasOutputs("Z");
+  ASSERT_EQ(have_z, false);
+}
+
 }  // namespace imperative
 }  // namespace paddle

@@ -110,8 +110,8 @@ TEST(test_prepare_op, test_prepare_op) {
   framework::OperatorWithKernel op("split", var_in_map, var_out_map,
                                    split_attr_map);
   framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
-  ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp =
-                              PreparedOp::Prepare(ctx, op, place, ins));
+  ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
+                              ins, outs, op, place, &split_attr_map));
 }

 const framework::Tensor* GetTensorFromVar(const framework::Variable& var);

@@ -158,7 +158,8 @@ TEST(test_prepare_op, test_prepare_data) {
   framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);

   // test if it can be transformed to GPU place
-  PreparedOp prepared_op = PreparedOp::Prepare(ctx, assign_op, gpu_place, ins);
+  PreparedOp prepared_op =
+      PreparedOp::Prepare(ins, outs, assign_op, gpu_place, &assign_attr_map);
   for (const auto& name_pair : ins) {
     for (const auto& vb : name_pair.second) {
       ASSERT_TRUE(platform::is_same_place(

@@ -201,7 +202,8 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
   framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);

   // test if it never transfered on GPU place
-  PreparedOp prepared_op = PreparedOp::Prepare(ctx, assign_op, cpu_place, ins);
+  PreparedOp prepared_op =
+      PreparedOp::Prepare(ins, outs, assign_op, cpu_place, &assign_attr_map);
   for (const auto& name_pair : ins) {
     for (const auto& vb : name_pair.second) {
       ASSERT_TRUE(platform::is_same_place(

@@ -82,10 +82,9 @@ static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
 void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                      const NameVarBaseMap& outs, framework::AttributeMap attrs,
                      const platform::Place& place, bool trace_backward) {
-  platform::RecordEvent event(type);
   VLOG(1) << "Trace Op: " << type;

   size_t op_id = GenerateUniqueId();
-  auto op = OpBase::Create(op_id, type, ins, outs, std::move(attrs), place);
+  auto op = OpBase::Create(op_id, type, ins, outs, attrs, place);
   op->Run(ins, outs);

   if (enable_program_desc_tracing_) {

@@ -62,11 +62,11 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context,
   auto out_var = context.OutputVar("Out");
   PADDLE_ENFORCE(x_var != nullptr,
                  "Cannot get input Variable X, variable name = %s",
-                 context.op().Input("X"));
+                 context.InputName("X"));
   PADDLE_ENFORCE(out_var != nullptr,
                  "Cannot get output Variable Out, variable name = %s",
-                 context.op().Output("Out"));
+                 context.OutputName("Out"));
-  if (CanBeUsedBySelectedRows.count(context.op().Type())) {
+  if (CanBeUsedBySelectedRows.count(context.Type())) {
     *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
     *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
         out_var);

@@ -77,7 +77,7 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context,
   PADDLE_ENFORCE(*Out != nullptr,
                  "Cannot get output tensor Out, variable name = %s",
-                 context.op().Output("Out"));
+                 context.OutputName("Out"));
 }

 template <ActBwdOpFwdDeps kDepValue>

@@ -93,18 +93,18 @@ inline void ExtractActivationGradTensor(
     out_var = context.InputVar("Out");
     PADDLE_ENFORCE(out_var != nullptr,
                    "Cannot get input Variable Out, variable name = %s",
-                   context.op().Input("Out"));
+                   context.InputName("Out"));
   }
   PADDLE_ENFORCE(out_grad_var != nullptr,
                  "Cannot get input Variable %s, variable name = %s",
                  framework::GradVarName("Out"),
-                 context.op().Input(framework::GradVarName("Out")));
+                 context.InputName(framework::GradVarName("Out")));
   PADDLE_ENFORCE(x_grad_var != nullptr,
                  "Cannot get output Variable %s, variable name = %s",
                  framework::GradVarName("X"),
-                 context.op().Output(framework::GradVarName("X")));
+                 context.OutputName(framework::GradVarName("X")));
-  if (CanBeUsedBySelectedRows.count(context.op().Type())) {
+  if (CanBeUsedBySelectedRows.count(context.Type())) {
     *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
         *out_grad_var);
     *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(

@@ -132,20 +132,20 @@ inline void ExtractActivationGradTensor(
   PADDLE_ENFORCE(*dX != nullptr,
                  "Cannot get output tensor %s, variable name = %s",
                  framework::GradVarName("X"),
-                 context.op().Output(framework::GradVarName("X")));
+                 context.OutputName(framework::GradVarName("X")));
   if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
     auto x_var = context.InputVar("X");
     PADDLE_ENFORCE(x_var != nullptr,
                    "Cannot get input tensor X, variable name = %s",
-                   context.op().Input("X"));
+                   context.InputName("X"));
-    if (CanBeUsedBySelectedRows.count(context.op().Type())) {
+    if (CanBeUsedBySelectedRows.count(context.Type())) {
       *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
     } else {
       *X = context.Input<framework::Tensor>("X");
     }
   } else {
-    VLOG(10) << " Inplace activation of Op : " << context.op().Type();
+    VLOG(10) << " Inplace activation of Op : " << context.Type();
     *X = *dX;
   }
 }
@@ -1273,8 +1273,8 @@ inline void ExtractActivationDoubleGradTensor(
   auto ddo_var = ctx.OutputVar("DDOut");
   PADDLE_ENFORCE(ddx_var != nullptr,
                  "Cannot get input Variable Out, variable name = %s",
-                 ctx.op().Input("DDX"));
+                 ctx.InputName("DDX"));
-  if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
+  if (CanBeUsedBySelectedRows.count(ctx.Type())) {
     *ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var);
     if (ddo_var) {
       *ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(

@@ -1288,15 +1288,15 @@ inline void ExtractActivationDoubleGradTensor(
   }
   PADDLE_ENFORCE(*ddX != nullptr,
                  "Cannot get output tensor DDX, variable name = %s",
-                 ctx.op().Output("DDX"));
+                 ctx.OutputName("DDX"));
   if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
     auto x_var = ctx.InputVar("X");
     PADDLE_ENFORCE(x_var != nullptr,
                    "Cannot get input Variable Out, variable name = %s",
-                   ctx.op().Input("X"));
+                   ctx.InputName("X"));
     auto dx_var = ctx.OutputVar("DX");
-    if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
+    if (CanBeUsedBySelectedRows.count(ctx.Type())) {
       *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
       if (dx_var) {
         *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(

@@ -1309,16 +1309,16 @@ inline void ExtractActivationDoubleGradTensor(
       }
     }
   } else {
-    VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
+    VLOG(10) << "Inplace activation of Op: " << ctx.Type();
     *X = *ddX;
   }
   if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
     auto out_var = ctx.InputVar("Out");
     PADDLE_ENFORCE(out_var != nullptr,
                    "Cannot get input tensor Out, variable name = %s",
-                   ctx.op().Input("Out"));
+                   ctx.InputName("Out"));
     auto dout_var = ctx.OutputVar("DOut");
-    if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
+    if (CanBeUsedBySelectedRows.count(ctx.Type())) {
       *Out =
           paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
       if (dout_var) {

@@ -1333,7 +1333,7 @@ inline void ExtractActivationDoubleGradTensor(
       }
     }
   } else {
-    VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
+    VLOG(10) << "Inplace activation of Op: " << ctx.Type();
     *Out = *ddX;
   }
 }
@@ -1471,20 +1471,20 @@ inline void ExtractDoubleGradTensorWithInputDOut(
   auto ddo_var = ctx.OutputVar("DDOut");
   PADDLE_ENFORCE(ddx_var != nullptr,
                  "Cannot get input Variable Out, variable name = %s",
-                 ctx.op().Input("DDX"));
+                 ctx.InputName("DDX"));
   *ddX = ctx.Input<framework::Tensor>("DDX");
   if (ddo_var) {
     *ddOut = ctx.Output<framework::Tensor>("DDOut");
   }
   PADDLE_ENFORCE(*ddX != nullptr,
                  "Cannot get output tensor DDX, variable name = %s",
-                 ctx.op().Output("DDX"));
+                 ctx.OutputName("DDX"));

   // extract x(input), dx(output)
   auto x_var = ctx.InputVar("X");
   PADDLE_ENFORCE(x_var != nullptr,
                  "Cannot get input Variable Out, variable name = %s",
-                 ctx.op().Input("X"));
+                 ctx.InputName("X"));
   auto dx_var = ctx.OutputVar("DX");
   *X = ctx.Input<framework::Tensor>("X");
   if (dx_var) {

@@ -1537,20 +1537,20 @@ class SqrtDoubleGradKernel
     auto ddo_var = ctx.OutputVar("DDOut");
     PADDLE_ENFORCE(ddx_var != nullptr,
                    "Cannot get input Variable DDX, variable name = %s",
-                   ctx.op().Input("DDX"));
+                   ctx.InputName("DDX"));
     ddX = ctx.Input<framework::Tensor>("DDX");
     if (ddo_var) {
       ddOut = ctx.Output<framework::Tensor>("DDOut");
     }
     PADDLE_ENFORCE(ddX != nullptr,
                    "Cannot get input Variable DDX, variable name = %s",
-                   ctx.op().Input("DDX"));
+                   ctx.InputName("DDX"));

     // extract out(input), dout(output)
     auto out_var = ctx.InputVar("Out");
     PADDLE_ENFORCE(out_var != nullptr,
                    "Cannot get input Variable Out, variable name = %s",
-                   ctx.op().Input("Out"));
+                   ctx.InputName("Out"));
     auto dout_var = ctx.OutputVar("DOut");
     Out = ctx.Input<framework::Tensor>("Out");
     if (dout_var) {

@@ -1561,7 +1561,7 @@ class SqrtDoubleGradKernel
     auto dx_var = ctx.InputVar("DX");
     PADDLE_ENFORCE(dx_var != nullptr,
                    "Cannot get input Variable DX, variable name = %s",
-                   ctx.op().Input("DX"));
+                   ctx.InputName("DX"));
     if (dx_var) {
       dX = ctx.Input<framework::Tensor>("DX");
     }

@@ -27,8 +27,8 @@ template <typename DeviceContext, typename T>
 class CoalesceTensorOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto &in_var_names = context.Inputs("Input");
-    auto &out_var_names = context.Outputs("Output");
+    auto in_var_names = context.InputNames("Input");
+    auto out_var_names = context.OutputNames("Output");
     auto &in_vars = context.MultiInputVar("Input");
     auto out_vars = context.MultiOutputVar("Output");

@@ -32,6 +32,7 @@ class ConcatOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
                       "Inputs(X) of ConcatOp should not be empty.");
+
     PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                       "Output(Out) of ConcatOp should not be null.");

@@ -152,17 +153,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
     auto in_x = "X";
     auto out_x_g_n = framework::GradVarName(in_x);
     ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
-    auto &in_names = ctx->Inputs(in_x);
-    auto &out_names = ctx->Outputs(out_x_g_n);
-    PADDLE_ENFORCE_EQ(
-        in_names.size(), out_names.size(),
-        "The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
-        in_names.size(), out_x_g_n, out_names.size());
-    for (size_t i = 0; i < in_names.size(); ++i) {
-      if (out_names[i] != framework::kEmptyVarName) {
-        ctx->ShareLoD(in_x, out_x_g_n, i, i);
-      }
-    }
+
+    ctx->ShareAllLoD(in_x, out_x_g_n);
   }

  protected:

@@ -197,7 +189,9 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
     std::unique_ptr<T> op(new T());
     op->SetType("concat_grad");
     op->SetInput("X", this->Input("X"));
-    op->SetInput("AxisTensor", this->Input("AxisTensor"));
+    if (this->HasInput("AxisTensor")) {
+      op->SetInput("AxisTensor", this->Input("AxisTensor"));
+    }
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
     op->SetAttrMap(this->Attrs());

@@ -139,7 +139,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto ins = ctx.MultiInput<framework::LoDTensor>("X");
-    auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
+    auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
     auto outs =
         ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));

@@ -665,7 +665,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     Tensor* dX = ctx.Output<Tensor>("DInput");
     Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"),
                            "Cannot find input Filter(%s) in scope)",
-                           ctx.Inputs("Filter")[0]);
+                           ctx.InputNames("Filter")[0]);
     if (!ddY && !dW && !dX) return;

     const int groups = ctx.Attr<int>("groups");

@@ -62,7 +62,7 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
       // multi-devices before the first running.
       // use parent scope to make cache persistable
       auto *scope = const_cast<framework::Scope *>(ctx.scope().parent());
-      auto cache_var_name = ctx.Inputs("Cache")[0];
+      auto cache_var_name = ctx.InputNames("Cache")[0];
       cache_var = scope->Var(cache_var_name);
     }
     CudnnRNNCache *cudnn_rnn_cache = nullptr;

@@ -31,11 +31,11 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
                           "Cannot get input tensor X, variable name = %s",
-                          context.op().Input("X"));
+                          context.InputName("X"));

     auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
                             "Cannot get output tensor Out, variable name = %s",
-                            context.op().Output("Out"));
+                            context.OutputName("Out"));
     int axis = context.Attr<int>("axis");
     bool exclusive = context.Attr<bool>("exclusive");
     bool reverse = context.Attr<bool>("reverse");

@@ -295,10 +295,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     auto *im_info = context.Input<Tensor>("ImInfo");
     auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
                                "Cannot find input Anchors(%s) in scope",
-                               context.Inputs("Anchors")[0]);
+                               context.InputNames("Anchors")[0]);
     auto variances = detail::Ref(context.Input<Tensor>("Variances"),
                                  "Cannot find input Variances(%s) in scope",
-                                 context.Inputs("Variances")[0]);
+                                 context.InputNames("Variances")[0]);
     auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
     auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");

(Some files were not shown because too many files have changed in this diff.)
