Implement infer var type context

Branch: revert-16190-refine_parallel_executor
Author: minqiyang, 6 years ago
Parent: 0b49e43d3a
Commit: ca392c7e97
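In short, this commit changes the signature of every `VarTypeInference` functor from a `(const OpDesc&, BlockDesc*)` pair to a single `InferVarTypeContext&` that wraps both, so the same inference code can later be driven in imperative mode (see `RuntimeInferVarTypeContext` below). A minimal sketch of an inference written against the new context API — `CopyTypeVarTypeInference` and the "X"/"Out" slots are illustrative, while the context methods are the ones this commit adds in `var_type_inference.h`:

```cpp
#include "paddle/fluid/framework/var_type_inference.h"

namespace paddle {
namespace operators {

// Hypothetical example: forward the type and dtype of input "X" to
// output "Out" using the new context-based API.
class CopyTypeVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext& ctx) const override {
    auto& in_name = ctx.Input("X").front();      // name of the first input var
    auto& out_name = ctx.Output("Out").front();  // name of the first output var
    ctx.SetType(out_name, ctx.GetType(in_name));
    ctx.SetDataType(out_name, ctx.GetDataType(in_name));
  }
};

}  // namespace operators
}  // namespace paddle
```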

@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
 class DummyVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
-    auto& inputs = op_desc.Input("X");
-    auto type = block->Var(inputs.front())->GetType();
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(type);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto& inputs = ctx.Input("X");
+    auto type = ctx.GetType(inputs.front());
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, type);
   }
 };

@@ -127,9 +127,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
+    info->infer_var_type_ = [](InferVarTypeContext& context) {
       T inference;
-      inference(fwd_op, block);
+      inference(context);
     };
   }
 };
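Spelled out, the filler above installs a type-erased lambda matching the new `InferVarTypeFN` signature (see the `type_defs.h` hunk below). A rough hand-written equivalent of what the filler registers, sketched with the `DummyVarTypeInference` test class from the first hunk:

```cpp
// Approximately what OpInfoFiller<DummyVarTypeInference, kVarTypeInference>
// now installs, written out by hand for illustration.
paddle::framework::OpInfo info;
info.infer_var_type_ = [](paddle::framework::InferVarTypeContext& context) {
  DummyVarTypeInference inference;  // stateless functor, constructed per call
  inference(context);
};
```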

@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
-    auto &inputs = op_desc.Input("X");
+  void operator()(InferVarTypeContext &ctx) const override {
+    auto &inputs = ctx.Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, default_var_type);
   }
 };
@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
 class DummyOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace framework
 }  // namespace paddle

@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/fluid/framework/var_type_inference.h"
 
 namespace paddle {
 namespace framework {
@@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   // var type inference. Hence, we don't do any "default" setting here.
   auto &info = OpInfoMap::Instance().Get(this->Type());
   if (info.infer_var_type_) {
-    info.infer_var_type_(*this, block);
+    InferVarTypeContext context(this, block);
+    info.infer_var_type_(context);
   }
 }

@@ -27,6 +27,7 @@ namespace framework {
 class OperatorBase;
 class OpDesc;
 class InferShapeContext;
+class InferVarTypeContext;
 class BlockDesc;
 class Variable;

@@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;
 using InferVarTypeFN =
-    std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
+    std::function<void(framework::InferVarTypeContext& /*context*/)>;
 using InferShapeFN = std::function<void(InferShapeContext*)>;

@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
@@ -21,26 +22,113 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
+class OpDesc;
+class BlockDesc;
+
+// default infer var type context
+class InferVarTypeContext {
+ public:
+  InferVarTypeContext(const OpDesc* op, BlockDesc* block)
+      : op_(op), block_(block) {}
+
+  Attribute GetAttr(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->GetAttr(name);
+  }
+
+  inline bool HasVar(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindVarRecursive(name) != nullptr;
+  }
+
+  inline bool HasInput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Inputs().count(name) > 0;
+  }
+
+  inline bool HasOutput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Outputs().count(name) > 0;
+  }
+
+  inline const std::vector<std::string>& Input(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Input(name);
+  }
+
+  inline const std::vector<std::string>& Output(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Output(name);
+  }
+
+  inline proto::VarType::Type GetType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetType();
+  }
+
+  inline void SetType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetType(type);
+  }
+
+  inline proto::VarType::Type GetDataType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataType();
+  }
+
+  inline void SetDataType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataType(type);
+  }
+
+  inline std::vector<int64_t> GetShape(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetShape();
+  }
+
+  inline void SetShape(const std::string& name,
+                       const std::vector<int64_t>& dims) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetShape(dims);
+  }
+
+  inline int32_t GetLoDLevel(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
+  }
+
+  inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
+  }
+
+ private:
+  const OpDesc* op_;
+  BlockDesc* block_;
+};
+
+// infer var type context for imperative mode
+class RuntimeInferVarTypeContext : public InferVarTypeContext {
+ public:
+  RuntimeInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {}
+};
+
 class VarTypeInference {
  public:
   virtual ~VarTypeInference() {}
-  virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
+  virtual void operator()(InferVarTypeContext& context) const = 0;  // NOLINT
 };
 
 class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const final {
+  void operator()(framework::InferVarTypeContext& ctx) const final {  // NOLINT
     auto in_out_var_names = this->GetInputOutputWithSameType();
     for (auto& i_o_n : in_out_var_names) {
-      auto& x_name = op_desc.Input(i_o_n.first).at(0);
-      auto& out_name = op_desc.Output(i_o_n.second).at(0);
-      auto& x = block->FindRecursiveOrCreateVar(x_name);
-      auto& out = block->FindRecursiveOrCreateVar(out_name);
-      out.SetType(x.GetType());
-      out.SetDataType(x.GetDataType());
+      auto& x_name = ctx.Input(i_o_n.first).at(0);
+      auto& out_name = ctx.Output(i_o_n.second).at(0);
+      ctx.SetType(out_name, ctx.GetType(x_name));
+      ctx.SetDataType(out_name, ctx.GetDataType(x_name));
     }
   }
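As the `op_desc.cc` hunk above shows, the framework now builds the context from an `OpDesc` and its `BlockDesc` and hands it to the registered functor. A sketch of driving an inference directly, in the style of the framework's unit tests — the op setup and the variable names "a", "b", "out" are assumptions for illustration:

```cpp
// Build a one-op program, then run var type inference through the context.
paddle::framework::ProgramDesc program;
paddle::framework::BlockDesc* block = program.MutableBlock(0);

block->Var("a")->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
block->Var("b")->SetType(paddle::framework::proto::VarType::SELECTED_ROWS);

paddle::framework::OpDesc* op = block->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a", "b"});
op->SetOutput("Out", {"out"});

// The new call path: wrap op and block in a context, then invoke the functor.
paddle::framework::InferVarTypeContext ctx(op, block);
SumOpVarTypeInference infer;  // the test class from the hunks above
infer(ctx);  // "out" becomes LOD_TENSOR because one input is a LoDTensor
```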

@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
-    auto &inputs = op_desc.Input("X");
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto &inputs = ctx.Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, default_var_type);
   }
 };
 }  // namespace framework

@@ -178,10 +178,10 @@ Beam Search Decode Operator. This Operator constructs the full hypotheses for
 each source sentence by walking back along the LoDTensorArray Input(ids)
 whose lods can be used to restore the path in the beam search tree.
 
 The Output(SentenceIds) and Output(SentenceScores) separately contain the
 generated id sequences and the corresponding scores. The shapes and lods of the
 two LodTensor are same. The lod level is 2 and the two levels separately
 indicate how many hypotheses each source sentence has and how many ids each
 hypothesis has.
 )DOC");
   }
@@ -203,15 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    for (auto& o : op_desc.Output("SentenceIds")) {
-      auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
-      sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    for (auto& o : ctx.Output("SentenceIds")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto& o : op_desc.Output("SentenceScores")) {
-      auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
-      sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
+    for (auto& o : ctx.Output("SentenceScores")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };

@@ -65,7 +65,7 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
   AddComment(R"DOC(
 This operator does the search in beams for one time step.
 Specifically, it selects the top-K candidate word ids of current step from
 Input(ids) according to their Input(scores) for all source sentences,
 where K is Attr(beam_size) and Input(ids), Input(scores) are predicted results
@@ -120,15 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
 class BeamSearchInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o : op_desc.Output("selected_ids")) {
-      auto &selected_ids = block->FindRecursiveOrCreateVar(o);
-      selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &o : ctx.Output("selected_ids")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto &o : op_desc.Output("selected_scores")) {
-      auto &selected_scores = block->FindRecursiveOrCreateVar(o);
-      selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
+    for (auto &o : ctx.Output("selected_scores")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };

@@ -93,11 +93,9 @@ execution.
 class GetPlacesInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o_name : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o_name).SetType(
-          framework::proto::VarType::PLACE_LIST);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &o_name : ctx.Output("Out")) {
+      ctx.SetType(o_name, framework::proto::VarType::PLACE_LIST);
     }
   }
 };

@@ -100,16 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
 class WriteToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto x_name = op_desc.Input("X")[0];
-    auto out_name = op_desc.Output("Out")[0];
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto x_name = ctx.Input("X")[0];
+    auto out_name = ctx.Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
-    auto &out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
-    auto *x = block->FindVarRecursive(x_name);
-    if (x != nullptr) {
-      out.SetDataType(x->GetDataType());
+    ctx.SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
+    if (ctx.HasVar(x_name)) {
+      ctx.SetDataType(out_name, ctx.GetDataType(x_name));
     }
   }
 };

@@ -114,11 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
 class MergeIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(input_var->GetType());
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto input_type = ctx.GetType(ctx.Input("Ids")[0]);
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, input_type);
     }
   }
 };

@@ -71,11 +71,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
 class SplitIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(input_var->GetType());
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto input_type = ctx.GetType(ctx.Input("Ids")[0]);
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, input_type);
     }
   }
 };

@@ -39,12 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
 class FillConstantOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
+  void operator()(framework::InferVarTypeContext& ctx) const override {
     auto data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(op_desc.GetAttr("dtype")));
-    auto& out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetDataType(data_type);
+        boost::get<int>(ctx.GetAttr("dtype")));
+    auto& out_var_name = ctx.Output("Out").front();
+    ctx.SetDataType(out_var_name, data_type);
   }
 };

@@ -137,22 +137,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
 class FusedEmbeddingSeqPoolOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto out_var_name = ctx.Output(framework::GradVarName("W")).front();
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
               << framework::GradVarName("W") << " is set to SelectedRows";
-      block->Var(out_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
               << framework::GradVarName("W") << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0]));
   }
 };

@@ -81,15 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
 class GetTensorFromSelectedRowsOpVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const final {
-    auto out_var_name = op_desc.Output("Out").front();
-    auto in_var_name = op_desc.Input("X").front();
-
-    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
-    out_var.SetType(framework::proto::VarType::LOD_TENSOR);
-    out_var.SetDataType(in_var.GetDataType());
+  void operator()(framework::InferVarTypeContext &ctx) const {  // NOLINT
+    auto out_var_name = ctx.Output("Out").front();
+    auto in_var_name = ctx.Input("X").front();
+    ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+    ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name));
   }
 };

@@ -197,38 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 class HierarchicalSigmoidGradOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto bias_grad_var_name_vec =
-        op_desc.Output(framework::GradVarName("Bias"));
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto w_grad_var_name = ctx.Output(framework::GradVarName("W")).front();
+    auto bias_grad_var_name_vec = ctx.Output(framework::GradVarName("Bias"));
     std::string bias_grad_var_name;
     bool hasBias = false;
     if (bias_grad_var_name_vec.size()) {
       hasBias = true;
-      bias_grad_var_name =
-          op_desc.Output(framework::GradVarName("Bias")).front();
+      bias_grad_var_name = ctx.Output(framework::GradVarName("Bias")).front();
     }
-    auto attr = op_desc.GetAttr("is_sparse");
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      block->Var(w_grad_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      block->Var(w_grad_var_name)
-          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
     if (hasBias) {
       VLOG(30) << "hierarchical_sigmoid_grad op "
              << framework::GradVarName("Bias") << " is set to LoDTensor";
-      block->Var(bias_grad_var_name)
-          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx.SetDataType(w_grad_var_name, ctx.GetDataType(ctx.Input("W")[0]));
   }
 };

@@ -64,11 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
 class LoDRankTableInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o).SetType(
-          framework::proto::VarType::LOD_RANK_TABLE);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &o : ctx.Output("Out")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
     }
   }
 };

@@ -201,10 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
 class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };

@@ -147,22 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto out_var_name = ctx.Output(framework::GradVarName("W")).front();
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      block->Var(out_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0]));
   }
 };

@@ -237,23 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 class NCEOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto weight_grad = ctx.Output(framework::GradVarName("Weight")).front();
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to SelectedRows";
-      block->Var(weight_grad)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to LoDTensor";
-      block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
+    ctx.SetDataType(weight_grad, ctx.GetDataType(ctx.Input("Input")[0]));
   }
 };

@@ -37,8 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 class NgraphEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace operators

@@ -56,9 +56,9 @@ This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each
 weight using a local learning rate:
 
 $$
 local\_lr = \eta  *
     \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
 velocity = mu * velocity +
     local\_lr * (grad + \beta * param) \\
 param = param - velocity. \\
 $$
@@ -72,8 +72,7 @@ use L2 regularizers in case of using LARS.
 class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace operators
 }  // namespace paddle

@@ -21,18 +21,14 @@ using Tensor = framework::Tensor;
 class MomentumOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto input_var = op_desc.Input("Param")[0];
-    for (auto& out_var : op_desc.Output("ParamOut")) {
-      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
-          framework::proto::VarType::SELECTED_ROWS) {
-        block->FindRecursiveOrCreateVar(out_var).SetType(
-            framework::proto::VarType::SELECTED_ROWS);
-      } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
-                 framework::proto::VarType::LOD_TENSOR) {
-        block->FindRecursiveOrCreateVar(out_var).SetType(
-            framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto& input_var = ctx.Input("Param")[0];
+    for (auto& out_var : ctx.Output("ParamOut")) {
+      if (ctx.GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
+        ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
+      } else if (ctx.GetType(input_var) ==
+                 framework::proto::VarType::LOD_TENSOR) {
+        ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR);
       } else {
         PADDLE_THROW(
             "Only support LodTensor and SelectedRows, Unexpected Input Type.");

@@ -50,20 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
 class SGDOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto input_var_n = op_desc.Input("Param")[0];
-    auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto &input_var_n = ctx.Input("Param")[0];
+    auto in_var_type = ctx.GetType(input_var_n);
     PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
                        in_var_type == framework::proto::VarType::LOD_TENSOR,
                    "The input Var's type should be LoDtensor or SelectedRows,"
                    " but the received var(%s)'s type is %s",
                    input_var_n, in_var_type);
-    for (auto &out_var_n : op_desc.Output("ParamOut")) {
-      auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
-      if (out_var.GetType() != in_var_type) {
-        out_var.SetType(in_var_type);
+    for (auto &out_var_n : ctx.Output("ParamOut")) {
+      if (ctx.GetType(out_var_n) != in_var_type) {
+        ctx.SetType(out_var_n, in_var_type);
       }
     }
   }

Some files were not shown because too many files have changed in this diff.
