Merge pull request #4630 from jacquesqiao/merge-infershapecontext

Merge infershapecontext and ExecutionContext
revert-4814-Add_sequence_project_op
Qiao Longfei 8 years ago committed by GitHub
commit e12ec95ac1

@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
}
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
}
template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const {
auto names = op().Inputs(name);
std::vector<const Tensor*> res;
@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
}
template <>
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const {
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto var = OutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<LoDTensor>();
}
template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const {
auto names = op().Outputs(name);
std::vector<Tensor*> res;

@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
}
class OperatorBase;
class InferShapeContext;
class ExecutionContext;
extern const Tensor* GetTensorFromVar(const Variable* var);
@ -169,10 +168,11 @@ class NOP : public OperatorBase {
}
};
class InferShapeContext {
class ExecutionContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const OperatorBase& op() const { return op_; }
@ -278,31 +278,6 @@ class InferShapeContext {
out_tensor->set_lod(in_tensor.lod());
}
private:
const OperatorBase& op_;
const Scope& scope_;
};
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;
template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const;
template <>
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
const std::string& name) const;
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context)
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
typename DeviceType = typename platform::EigenDeviceConverter<
PlaceType>::EigenDeviceType>
@ -315,10 +290,26 @@ class ExecutionContext : public InferShapeContext {
}
private:
const OperatorBase& op_;
const Scope& scope_;
const platform::DeviceContext& device_context_;
};
class CompileTimeInferShapeContext : public InferShapeContextBase {
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const;
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
: op_(op), block_(block) {}
@ -414,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
const BlockDescBind& block_;
};
class RuntimeInferShapeContext : public InferShapeContextBase {
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
@ -612,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
virtual void InferShape(InferShapeContext* ctx) const = 0;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be

@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {}
void InferShape(framework::InferShapeContext* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}

@ -20,11 +20,11 @@ namespace paddle {
namespace framework {
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into
// RuntimeInferShapeContext get merged, we can rename InferShapeContext into
// InferShapeContext so to replace the current InferShapeContext.
class InferShapeContextBase {
class InferShapeContext {
public:
virtual ~InferShapeContextBase() {}
virtual ~InferShapeContext() {}
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;

@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input(Inference) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),

@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
}
@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
};

@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),

@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),

@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");

@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.")
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
};

@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of Conv2DOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {

@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
// notnull check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CosSimOp should not be null.");
@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
// notnull check
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");

@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");

@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),

@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
"GradOp is only callable when is_training is true");

@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected:
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"),
@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),

@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),

@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GatherOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Index"),
@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GaussianRandomOp should not be null.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dims");

@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
}

@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C_prev"),
"Input(C_prev) of LSTM should not be null.");
@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")),
"Input(C@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")),

@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MeanOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MinusOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),

@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"),

@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save