Refactoring InferShape (#3946)

* init Infershape

* add static InferShape interface

* refactor add-op infershape

* add AttrReader

* add all maker's infershape

* add all InferShape

* add python infer api

* add VarDesc interface

* add python VarDesc and OpDesc interface

* update python code

* use infershape function to do shape inference

* clean code

* do not use pointer

* refine code of op_proto_maker

* add get_dims to VarDesc

* refine the code

* remove the dependency from operator to op registry

* remove OpProtoAndCheckerMaker from operator

* restore complete_add_op

* add shape_infer_impl.h

* code optimization

* remove const return value

* add fake BlockDesc class

* optimize code

* remove infer function in op_info

* move InferShapeContextImpl to operator.h

* optimize the interface of InferShapeContextBase

* add temperary interface of new infershape

* change add_op, clip_op, conv2d_op and activation_op

* change all operators InferShape

* fix SetDim

* update cos_sim_op

* update crop_op

* update lookup_table_op

* allocate tensor when call GetDim in InferShapeContext

* update modified_huber_loss_op

* update rowwise_add_op

* update mean_op

* update sequence_avg_pool_op

* typo

* remove old InferShape interface

* can compile

* fix or unit test

* clean code

* clean code

* remove const before InferShapeContext

* change InferenceContextBase to pointer

* rename RunTime to Runtime, code clean
update-doc-pybind
Qiao Longfei 8 years ago committed by GitHub
parent 86351037c9
commit 9a9d50a6ee

@ -45,6 +45,21 @@ inline AttrType AttrTypeID() {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}
private:
const AttributeMap& attrs_;
};
// check whether a value(attribute) fit a certain limit
template <typename T>
class GreaterThanChecker {

@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx);
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
}

@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif
const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}
Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return var->GetMutable<Tensor>();
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,

@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
@ -56,6 +57,9 @@ class OperatorBase;
class InferShapeContext;
class ExecutionContext;
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@ -262,15 +266,6 @@ class InferShapeContext {
return res;
}
const Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, InputSize(in));
@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};
class RuntimeInferShapeContext : public InferShapeContextBase {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const {
auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasOutput(const std::string& name) const {
auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
DDim GetInputDim(const std::string& name) const {
return GetDim(op_.Input(name));
}
void SetInputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Input(name), dim);
}
DDim GetOutputDim(const std::string& name) const {
return GetDim(op_.Output(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Output(name), dim);
}
AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
private:
template <bool Allocate>
Tensor* GetTensor(const std::string& name) const {
Tensor* t = nullptr;
auto* var = scope_.FindVar(name);
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
if (Allocate) {
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable(%s) should be tensor", name);
}
} else {
t = GetTensorFromVar(scope_.FindVar(name));
}
return t;
}
DDim GetDim(const std::string& name) const {
return GetTensor<false>(name)->dims();
}
void SetDim(const std::string& name, const DDim& dim) {
GetTensor<true>(name)->Resize(dim);
}
const OperatorBase& op_;
const Scope& scope_;
};
class OpKernel {
public:
/**
@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
// runtime infershape
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
auto c = RuntimeInferShapeContext(*this, scope);
InferShape(&c);
}
void Run(const Scope& scope,
@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0;
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
};
} // namespace framework

@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
template <typename T1, typename T2>

@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/ddim.h"
namespace paddle {
namespace framework {
class InferShapeContextBase {
public:
virtual ~InferShapeContextBase() {}
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
virtual framework::DDim GetInputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetInputsDim(const std::string &name) const {
const std::vector<std::string> &names = Inputs(name);
return GetDims(names);
}
virtual void SetInputDim(const std::string &name,
const framework::DDim &dim) = 0;
void SetInputsDim(const std::string &name,
const std::vector<framework::DDim> &dims) {
auto &names = Inputs(name);
SetDims(names, dims);
}
virtual framework::DDim GetOutputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetOutputsDim(const std::string &name) const {
const std::vector<std::string> &names = Outputs(name);
return GetDims(names);
}
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name,
const std::vector<framework::DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}
virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs(
const std::string &name) const = 0;
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;
// TODO(qiao) implement this function
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const {}
protected:
virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
SetDim(names[i], dims[i]);
}
}
};
} // namespace framework
} // namespace paddle

@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("Inference"),
"Input(Inference) of AccuracyOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) of AccuracyOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Accuracy"),
"Output(Accuracy) of AccuracyOp should not be null.");
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input(Inference) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
"Output(Accuracy) of AccuracyOp should not be null.");
auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *label = ctx.Input<framework::Tensor>("Label");
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"inference size must be the same as label size");
ctx.Output<framework::Tensor>("Accuracy")->Resize({1});
ctx.ShareLoD("Inference", /*->*/ "Accuracy");
ctx->SetOutputDim("Accuracy", {1});
ctx->ShareLoD("Inference", /*->*/ "Accuracy");
}
};

@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Y")->Resize(
ctx.Input<framework::Tensor>("X")->dims());
ctx.ShareLoD("X", /*->*/ "Y");
void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
}
};
@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<framework::Tensor>("Y")->dims());
void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
};

@ -22,25 +22,23 @@ class AddOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of AddOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
"Input(Y) of AddOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of AddOp should not be null.");
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of AddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(),
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims,
"Two input of Add Op's dimension must be same.");
ctx.Output<framework::Tensor>("Out")->Resize(
ctx.Input<Tensor>("X")->dims());
ctx->SetOutputDim("Out", x_dims);
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
} // namespace operators

@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of ClipOp should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto max = Attr<float>("max");
auto min = Attr<float>("min");
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto max = ctx->Attrs().Get<float>("max");
auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
ctx.Output<Tensor>("Out")->Resize(x_dims);
ctx.ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor)The input of clip op."
@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
};

@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of ConcatOp should not be null.");
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null.");
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
auto ins = ctx->GetInputsDim("X");
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
auto out_dims = ins[0]->dims();
auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
out_dims[axis] += ins[i]->dims()[j];
out_dims[axis] += ins[i][j];
continue;
}
PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.")
}
}
out->Resize(out_dims);
ctx->SetOutputDim("Out", out_dims);
}
};

@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
The equation is:
The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false
)DOC");

@ -27,27 +27,25 @@ class Conv2DOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
"Input(Input) of Conv2DOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
"Input(Filter) of Conv2DOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
"Output(Output) of Conv2DOp should not be null.");
auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto out = ctx.Output<framework::Tensor>("Output");
std::vector<int> strides = Attr<std::vector<int>>("strides");
std::vector<int> paddings = Attr<std::vector<int>>("paddings");
int groups = Attr<int>("groups");
int input_channels = in->dims()[1];
int output_channels = filter->dims()[0];
PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
"Conv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of Conv2DOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of Conv2DOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of Conv2DOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups");
int input_channels = in_dims[1];
int output_channels = filter_dims[0];
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel {
"The number of output channels should be divided by groups.");
auto output_height =
outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
auto output_width =
outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
out->Resize(
{in->dims()[0], filter->dims()[0], output_height, output_width});
outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
ctx->SetOutputDim(
"Output", {in_dims[0], filter_dims[0], output_height, output_width});
}
};
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
auto d_filter =
ctx.Output<framework::Tensor>(framework::GradVarName("Filter"));
if (d_in) d_in->Resize(in->dims());
if (d_filter) d_filter->Resize(filter->dims());
void InferShape(framework::InferShapeContextBase* ctx) const override {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
}
};

@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
"Input(Y) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"),
"Output(XNorm) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"),
"Output(YNorm) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
"Output(XNorm) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
"Output(YNorm) of CosSimOp should not be null.");
// shape check
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel {
" just 1 (which will be broadcasted to match Input(X)).");
// resize tensor
ctx.Output<framework::Tensor>("Out")->Resize({x_dims[0], 1});
ctx.Output<framework::Tensor>("XNorm")->Resize({x_dims[0], 1});
ctx.Output<framework::Tensor>("YNorm")->Resize({y_dims[0], 1});
ctx.ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", {x_dims[0], 1});
ctx->SetOutputDim("XNorm", {x_dims[0], 1});
ctx->SetOutputDim("YNorm", {y_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
"Input(XNorm) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
"Input(YNorm) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
"Input(Out) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
// shape check
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
auto out_dims = ctx.Input<Tensor>("Out")->dims();
auto out_grad_dims =
ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto xnorm_dims = ctx->GetInputDim("XNorm");
auto ynorm_dims = ctx->GetInputDim("YNorm");
auto out_dims = ctx->GetInputDim("Out");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
// resize tensor
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};

@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of CropOp should not be null.");
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out");
if (y == nullptr) {
auto shape = Attr<std::vector<int>>("shape");
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CropOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
if (!ctx->HasInput("Y")) {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimention size of input tensor.");
@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel {
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
}
out->Resize(framework::make_ddim(tensor_shape));
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
} else {
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim),
"Tensor rank of both CropOp's "
"inputs must be same.");
out->Resize(y->dims());
ctx->SetOutputDim("Out", y_dim);
}
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
@ -78,12 +77,12 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
Crop Operator.
Crop input into output, as specified by offsets and shape.
There are two ways to set shape:
There are two ways to set shape:
1. referenc input: crop input X as shape as reference input.
The dimension of reference input should
The dimension of reference input should
be as same as input X.
2. shape list: crop input X by shape described by a list<int>.
The size of shape list should be as same as
The size of shape list should be as same as
dimension size of input X.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
@ -94,15 +93,15 @@ Given:
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
and
and
offsets = [0, 1]
and
shape = [2, 2]
then we get
then we get
Out = [[1, 2],
[3, 4]]
@ -116,14 +115,14 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};

@ -22,33 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) should be not null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
"The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
"If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
PADDLE_ENFORCE_EQ(label_dims[1], 1,
"If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
ctx.ShareLoD("X", /*->*/ "Y");
ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Y");
}
};
@ -57,50 +54,45 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
"Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
"The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
"The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
"The 2nd dimension of Input(Y@Grad) should be 1.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
"When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
PADDLE_ENFORCE_EQ(label_dims[1], 1,
"When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->Resize(x->dims());
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
CrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "

@ -24,25 +24,25 @@ class DropoutOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<Tensor>("Out")->Resize(dims);
if (ctx.Attr<bool>("is_training")) {
ctx.Output<Tensor>("Mask")->Resize(dims);
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == 1) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx.ShareLoD("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
DropoutOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
@ -70,27 +70,26 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims, out_dims,
"Dimensions of Input(X) and Out@Grad must be the same.");
auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
auto mask_dims = ctx->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(x_dims, mask_dims,
"Dimensions of Input(X) and Mask must be the same.");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
x_grad->Resize(x_dims);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};

@ -202,21 +202,20 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected:
using Tensor = framework::Tensor;
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of elementwise op should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
"Input(Y) of elementwise op should not be null");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
"Output(Out) of elementwise op should not be null.");
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of elementwise op should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
ctx.ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
@ -234,7 +233,7 @@ must be small or equal to X's dimensions.
)DOC");
AddAttr<int>("axis",
R"DOC(
When the shape(Y) does not equal the shape(X),Y will be broadcasted
When the shape(Y) does not equal the shape(X),Y will be broadcasted
to match the shape of X and axis should be dimension index Y in X
)DOC")
.SetDefault(-1)
@ -244,7 +243,7 @@ to match the shape of X and axis should be dimension index Y in X
comment_ = R"DOC(
Limited elementwise {name} operator.The equation is: Out = {equation}.
1. The shape of Y should be same with X or
2. Y's shape is a subset of X.
2. Y's shape is a subset of X.
Y will be broadcasted to match the shape of X and axis should be dimension index Y in X.
example:
@ -284,27 +283,26 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
if (x_grad) {
x_grad->Resize(x_dims);
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (y_grad) {
y_grad->Resize(y_dims);
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};

@ -22,15 +22,13 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) of FillZerosLikeOp should not be null.");
ctx.Output<framework::Tensor>("Y")->Resize(
ctx.Input<framework::Tensor>("X")->dims());
ctx.ShareLoD("X", /*->*/ "Y");
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of FillZerosLikeOp should not be null.");
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
}
};

@ -23,19 +23,19 @@ class GatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of GatherOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"),
"Input(Index) of GatherOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of GatherOp should not be null.");
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GatherOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Index"),
"Input(Index) of GatherOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GatherOp should not be null.");
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
int batch_size = ctx->GetInputDim("Index")[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size;
ctx.Output<framework::Tensor>("Out")->Resize(output_dims);
ctx->SetOutputDim("Out", output_dims);
}
};
@ -44,23 +44,20 @@ class GatherGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
void InferShape(framework::InferShapeContextBase* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Gather Operator by selecting from the first axis,
Gather Operator by selecting from the first axis,
Out = X[Index]
)DOC");

@ -43,13 +43,10 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
"Output(Out) of GaussianRandomOp should not be null.");
auto* tensor = ctx.Output<framework::Tensor>("Out");
auto dims = Attr<std::vector<int>>("dims");
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GaussianRandomOp should not be null.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> temp;
temp.reserve(dims.size());
for (auto dim : dims) {
@ -57,7 +54,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set.");
tensor->Resize(framework::make_ddim(temp));
ctx->SetOutputDim("Out", framework::make_ddim(temp));
}
};

@ -22,27 +22,26 @@ class LookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of LookupTableOp should not be null.");
auto table_t = ctx.Input<Tensor>("W");
auto ids_t = ctx.Input<Tensor>("Ids");
auto output_t = ctx.Output<framework::Tensor>("Out");
output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
ctx.ShareLoD("Ids", /*->*/ "Out");
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LookupTableOp should not be null.");
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
}
};
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupTableOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LookupTableOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"An input represents embedding tensors,"
@ -66,11 +65,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &context) const override {
auto table = context.Input<Tensor>("W");
auto d_table =
context.Output<framework::Tensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
void InferShape(framework::InferShapeContextBase* ctx) const override {
auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
}
};

@ -22,37 +22,36 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"),
"Input(C_prev) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"),
"Output(C) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"),
"Output(H) of LSTM should not be null.");
auto *x = ctx.Input<framework::Tensor>("X");
auto *c_prev = ctx.Input<framework::Tensor>("C_prev");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C_prev"),
"Input(C_prev) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("C"),
"Output(C) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("H"),
"Output(H) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto c_prev_dims = ctx->GetInputDim("C_prev");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE(x_dims[0] == c_prev_dims[0],
"Batch size of inputs and states must be equal");
PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4,
PADDLE_ENFORCE(x_dims[1] == c_prev_dims[1] * 4,
"Dimension of FC should equal to prev state * 4");
int b_size = c_prev->dims()[0]; // batch size
int s_dim = c_prev->dims()[1]; // state dim
ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim});
ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim});
int b_size = c_prev_dims[0]; // batch size
int s_dim = c_prev_dims[1]; // state dim
ctx->SetOutputDim("C", {b_size, s_dim});
ctx->SetOutputDim("H", {b_size, s_dim});
}
};
template <typename AttrType>
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LstmUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "FC input before the non-linear activation.");
AddInput(
@ -63,11 +62,11 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(Lstm-Unit Operator
Equation:
Equation:
i, f, o, j = split(X)
C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
H = C * sigm(o)
)DOC");
AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
.SetDefault(0.0);
@ -79,15 +78,14 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")),
"Input(C@GRAD) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")),
"Input(H@GRAD) should not be null");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev"))
->Resize(ctx.Input<Tensor>("C_prev")->dims());
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")),
"Input(C@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")),
"Input(H@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("C_prev"),
ctx->GetInputDim("C_prev"));
}
};

@ -22,18 +22,18 @@ class MeanOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of MeanOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of MeanOp should not be null.");
ctx.Output<framework::Tensor>("Out")->Resize({1});
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MeanOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MeanOp should not be null.");
ctx->SetOutputDim("Out", {1});
}
};
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NotInGradient();
@ -47,9 +47,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
void InferShape(framework::InferShapeContextBase* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

@ -26,22 +26,22 @@ class MinusOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of MinusOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
"Input(Y) of MinusOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of MinusOp should not be null.");
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MinusOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of MinusOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MinusOp should not be null.");
auto *left_tensor = ctx.Input<framework::Tensor>("X");
auto *right_tensor = ctx.Input<framework::Tensor>("Y");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(
left_tensor->numel(), right_tensor->numel(),
x_dims, y_dims,
"Minus operator must take two tensor with same num of elements");
ctx.Output<framework::Tensor>("Out")->Resize(left_tensor->dims());
ctx.ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save