open dygraph op test, test=develop (#19787)

* open dygraph op test, test=develop

* modify to_variable, test=develop

* modify input and output for dygraph, test=develop

* modify input and output for dygraph(fix bug), test=develop

* fix input processing of dygraph op test, test=develop

* fix bug, test=develop

* fix op test, test=develop

* fix forward bug for dygraph, test=develop

* fix mkldnn op test for forward, test=develop

* update nn.py for dygraph, test=develop

* fix crop_tensor_op, test=develop

* fix elementwise_mul_op, test=develop

* fix fill_op, test=develop

* fix some mkldnn op, test=develop

* open backward op test for dygraph, test=develop

* delete log, test=develop

* close backward op test for dygraph, test=develop

* fix bug for edit_distance_op and test_lstm_cudnn_op, test=develop

* fix optest backward bug for dygraph, test=develop

* fix optest backward bug for dygraph, test=develop

* close backward op test for dygraph, test=develop

* close backward op test for dygraph, test=develop

* open dygraph op test, test=develop

* fix op test for dygraph, fix GradOpDescMaker, test=develop

* fix bug for linear_chain_crf_op.h, test=develop

* remove log, test=develop

* remove log, test=develop

* remove log for op_test.py, test=develop

* remove log for op_test.py, test=develop

* fix bug for var_conv_2d_op, change PADDLE_ENFORCE, test=develop

* fix PADDLE_ENFORCE_EQ for hierarchical_sigmoid_op.cc, test=develop

* fix bug for test_increment_ngraph_op.py, test=develop

* fix lod for op test in dygraph, test=develop

* refactor op_test.py to reduce redundant code, test=develop

* fix lod optest, modify InputVar/OutputVar to HasInput/HasOutput, test=develop

* remove debug log, test=develop

* remove redundant code in base.py, test=develop

* fix some error in optest, test=develop

* fix ClearNoNeedBufferInputs function's bug for LoDTensor, test=develop

* refactor op_test.py, test=develop

* remove redundant writing, test=develop

* fix error(get tensor of the grad variable), test=develop

* fix test_concat_mkldnn test_conv2d_mkldnn, test=develop

* fix optest.py for get tensor of LoDTensor, test=develop

* fix optest.py for get tensor of LoDTensor, test=develop

* fix optest.py for get tensor of LoDTensor, test=develop

* fix some redundant code, test=develop

* resolve conflict and rewrite paddle error message, test=develop

@@ -253,12 +253,14 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
}
bool HasInput(const std::string& name) const override {
return inputs_.count(name) > 0;
auto it = inputs_.find(name);
return (it != inputs_.end() && it->second.size() > 0);
}
bool HasOutput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(outputs_);
return outputs_->count(name) > 0;
auto it = outputs_->find(name);
return (it != outputs_->end() && it->second.size() > 0);
}
const std::vector<std::string>& Input(
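
Note: the stricter checks above matter because an input or output slot can be present in the map while holding an empty variable list (for example an optional input that the caller did not feed). A minimal standalone sketch of the difference, not Paddle code:

// why count() alone is not enough: the key can exist with an empty value
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::vector<std::string>> inputs;
  inputs["Bias"];  // slot is declared, but no variable is bound to it

  bool by_count = inputs.count("Bias") > 0;                      // true: key exists
  auto it = inputs.find("Bias");
  bool by_size = (it != inputs.end() && it->second.size() > 0);  // false: empty slot
  return (by_count && !by_size) ? 0 : 1;                         // returns 0 here
}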

@@ -53,6 +53,7 @@ static void ClearNoNeedBufferInputs(OpBase* op) {
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod());
each_var.reset(new_var);
}
}
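
Note: the added set_lod call above is the fix referenced by the "fix ClearNoNeedBufferInputs function's bug for LoDTensor" commit: when a no-need-buffer input is replaced by a meta-only placeholder, the LoD (sequence) information must be carried over along with the shape, otherwise grad ops that read the LoD see an empty one. A hedged sketch of the idea, assuming only the LoDTensor methods already used in the hunk above (Resize, dims, set_lod, lod):

#include "paddle/fluid/framework/lod_tensor.h"

// Build a buffer-free placeholder that still carries shape and LoD metadata.
paddle::framework::LoDTensor MakeMetaOnlyCopy(
    const paddle::framework::LoDTensor& old_tensor) {
  paddle::framework::LoDTensor new_tensor;
  new_tensor.Resize(old_tensor.dims());
  new_tensor.set_lod(old_tensor.lod());  // dropping this line loses the LoD
  return new_tensor;
}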

@@ -61,16 +61,30 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::NotFound(
"Input(Label) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("PreOut"), true,
platform::errors::NotFound(
"Output(PreOut) of HierarchicalSigmoidOp is not found."));
auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
if (with_prefetch) {
PADDLE_ENFORCE(ctx->HasOutput("W_Out"),
"Output(W_Out) should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("W_Out"), true,
platform::errors::NotFound(
"Output(W_Out) of HierarchicalSigmoidOp is not found."));
}
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
@@ -202,16 +216,30 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::NotFound(
"Input(Label) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Input(Out@Grad) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("PreOut"), true,
platform::errors::NotFound(
"Input(Preout) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("W")), true,
platform::errors::NotFound(
"Output(W@Grad of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound(
"Output(X@Grad of HierarchicalSigmoidGradOp is not found."));
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
@@ -235,10 +263,10 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
auto bias_grad_var_name_vec = ctx->Output(framework::GradVarName("Bias"));
auto has_bias_grad_var = ctx->HasOutput(framework::GradVarName("Bias"));
std::string bias_grad_var_name;
bool hasBias = false;
if (bias_grad_var_name_vec.size()) {
if (has_bias_grad_var) {
hasBias = true;
bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
}
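
Note: the hunks above (and most of the C++ changes in this diff) follow a single rewrite: PADDLE_ENFORCE(cond, "msg") becomes PADDLE_ENFORCE_EQ(cond, true, platform::errors::NotFound(...)), so a failed check reports a typed error that names the operator. A minimal sketch of the pattern for a hypothetical FooOp (FooOp is invented for illustration; the macro and error type are the ones used in this diff, and the headers are those already included by the operator files above):

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Hypothetical operator, shown only to illustrate the enforce rewrite.
class FooOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // old style: PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of FooOp is not found."));
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }
};

}  // namespace operators
}  // namespace paddle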

@@ -29,12 +29,15 @@ class MinusOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MinusOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of MinusOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MinusOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MinusOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MinusOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MinusOp is not found."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
@@ -71,27 +74,57 @@ or not. But the output only shares the LoD information with input `X`.
}
};
class MinusGradMaker : public framework::GradOpDescMakerBase {
class MinusGradDescMaker : public framework::GradOpDescMakerBase {
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
std::vector<std::unique_ptr<framework::OpDesc>> ops;
auto x_g = InputGrad("X");
auto x_g = this->InputGrad("X");
if (!x_g.empty()) {
auto *x_g_op = new framework::OpDesc();
x_g_op->SetType("scale");
x_g_op->SetInput("X", OutputGrad("Out"));
x_g_op->SetInput("X", this->OutputGrad("Out"));
x_g_op->SetOutput("Out", x_g);
x_g_op->SetAttr("scale", 1.0f);
ops.emplace_back(x_g_op);
}
auto y_g = InputGrad("Y");
auto y_g = this->InputGrad("Y");
if (!y_g.empty()) {
auto *y_g_op = new framework::OpDesc();
y_g_op->SetType("scale");
y_g_op->SetInput("X", OutputGrad("Out"));
y_g_op->SetInput("X", this->OutputGrad("Out"));
y_g_op->SetOutput("Out", y_g);
y_g_op->SetAttr("scale", -1.0f);
ops.emplace_back(y_g_op);
}
return ops;
}
};
class MinusGradMaker : public imperative::GradOpBaseMakerBase {
public:
using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::unique_ptr<imperative::OpBase>> operator()() const override {
std::vector<std::unique_ptr<imperative::OpBase>> ops;
auto x_g = this->InputGrad("X");
if (!x_g.empty()) {
auto *x_g_op = new imperative::OpBase();
x_g_op->SetType("scale");
x_g_op->SetInput("X", this->OutputGrad("Out"));
x_g_op->SetOutput("Out", x_g);
x_g_op->SetAttr("scale", 1.0f);
ops.emplace_back(x_g_op);
}
auto y_g = this->InputGrad("Y");
if (!y_g.empty()) {
auto *y_g_op = new imperative::OpBase();
y_g_op->SetType("scale");
y_g_op->SetInput("X", this->OutputGrad("Out"));
y_g_op->SetOutput("Out", y_g);
y_g_op->SetAttr("scale", -1.0f);
ops.emplace_back(y_g_op);
@@ -105,6 +138,7 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradMaker);
REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker,
ops::MinusGradDescMaker, ops::MinusGradMaker);
REGISTER_OP_CPU_KERNEL(
minus, ops::MinusKernel<paddle::platform::CPUDeviceContext, float>);

@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/nce_op.h"
#include <memory>
#include <string>
#include <vector>
@@ -212,6 +213,33 @@ By default this operator uses a uniform distribution for sampling.
}
};
template <typename T>
class NCEGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *op = new T();
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Label", this->Input("Label"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Cost", this->Output("Cost"));
op->SetInput("SampleLogits", this->Output("SampleLogits"));
op->SetInput("SampleLabels", this->Output("SampleLabels"));
op->SetInput("SampleWeight", this->Input("SampleWeight"));
op->SetInput("CustomDistProbs", this->Input("CustomDistProbs"));
op->SetInput("CustomDistAlias", this->Input("CustomDistAlias"));
op->SetInput("CustomDistAliasProbs", this->Input("CustomDistAliasProbs"));
op->SetInput(framework::GradVarName("Cost"), this->OutputGrad("Cost"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
class NCEOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -277,11 +305,9 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
nce, ops::NCEOp,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ops::NCEOpMaker);
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
ops::NCEGradOpMaker<paddle::framework::OpDesc>,
ops::NCEGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>);
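
Note: NCEGradOpMaker above shows the other recurring pattern in this diff: one grad-op maker templated on framework::SingleGradOpMaker<T> is instantiated for both the static graph (framework::OpDesc) and dygraph (paddle::imperative::OpBase), replacing DefaultGradOpMaker so the grad op captures only the inputs and outputs it actually needs. A hedged sketch for a hypothetical FooOp (names are invented; the base class and its helpers are the ones used above, and <memory> plus the usual op_registry.h include are assumed):

template <typename T>
class FooGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  std::unique_ptr<T> Apply() const override {
    auto* op = new T();
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));              // forward input the grad kernel reads
    op->SetInput(framework::GradVarName("Out"),
                 this->OutputGrad("Out"));            // gradient of the forward output
    op->SetOutput(framework::GradVarName("X"),
                  this->InputGrad("X"));              // gradient to be produced
    op->SetAttrMap(this->Attrs());
    return std::unique_ptr<T>(op);
  }
};

// One registration then serves both graph modes, e.g.:
// REGISTER_OPERATOR(foo, ops::FooOp, ops::FooOpMaker,
//                   ops::FooGradOpMaker<paddle::framework::OpDesc>,
//                   ops::FooGradOpMaker<paddle::imperative::OpBase>);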

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unpool_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
@@ -82,14 +83,15 @@ class UnpoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Indices"),
"Input(Indices) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnpoolOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnpoolOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"), true,
platform::errors::NotFound("Input(Indices) of UnpoolOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of UnpoolOp is not found."));
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Indices");
std::string unpooling_type =
@@ -97,8 +99,11 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput must be of 4-dimensional.");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"Unpooling intput(X) must be of 4-dimensional, but "
"received X's dimension is %d.",
in_x_dims.size()));
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
@@ -114,6 +119,23 @@ class UnpoolOp : public framework::OperatorWithKernel {
}
};
template <typename T>
class UnpoolOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Input("Indices"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
@@ -126,9 +148,12 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnpoolOpGradOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound(
"Input(X@GRAD) of UnpoolOpGradOp is not found."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
@@ -136,10 +161,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
unpool, ops::UnpoolOp, ops::Unpool2dOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker,
ops::UnpoolOpGradMaker<paddle::framework::OpDesc>,
ops::UnpoolOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unpool_grad, ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/var_conv_2d_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -57,18 +58,24 @@ void VarConv2dOpMaker::Make() {
}
void VarConv2dOP::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"W(Input) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROW"),
"Input(ROW) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("COLUMN"),
"Input(COLUMN) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Col"),
"Col(Output) of VarConv2dOP should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("X(Input) of VarConv2dOP is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("W"), true,
platform::errors::NotFound("W(Input) of VarConv2dOP is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("ROW"), true,
platform::errors::NotFound("Input(ROW) of VarConv2dOP is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("COLUMN"), true,
platform::errors::NotFound("Input(COLUMN) of VarConv2dOP is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Out(Output) of VarConv2dOP is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Col"), true,
platform::errors::NotFound("Col(Output) of VarConv2dOP is not found."));
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
@@ -91,7 +98,10 @@ void VarConv2dOP::InferShape(framework::InferShapeContext* ctx) const {
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
const auto& x_lod = x_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
PADDLE_ENFORCE_EQ(
!x_lod.empty(), true,
platform::errors::InvalidArgument("The Input(X) Tensor of VarConv2dOP "
"does not contain LoD information."));
PADDLE_ENFORCE_GE(x_lod.size(), 1, "The Input(X)'s lod info is corrupted.");
PADDLE_ENFORCE_EQ(
@@ -101,12 +111,18 @@ void VarConv2dOP::InferShape(framework::InferShapeContext* ctx) const {
framework::Variable* row_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("ROW")[0]);
const auto& row_lod = row_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!row_lod.empty(), "The Input(ROW) must hold lod info.");
PADDLE_ENFORCE_EQ(!row_lod.empty(), true,
platform::errors::InvalidArgument(
"The Input(ROW) Tensor of VarConv2dOP does not "
"contain LoD information."));
framework::Variable* col_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("COLUMN")[0]);
const auto& col_lod = col_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!col_lod.empty(), "The Input(COLUMN) must hold lod info.");
PADDLE_ENFORCE_EQ(!col_lod.empty(), true,
platform::errors::InvalidArgument(
"The Input(COLUMN) Tensor of VarConv2dOP does not "
"contain LoD information."));
} else {
std::vector<int64_t> out_dims_vec{-1};
out_dims_vec.push_back(1);
@@ -280,13 +296,40 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class VarConv2dGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("W", this->Input("W"));
op->SetInput("ROW", this->Input("ROW"));
op->SetInput("COLUMN", this->Input("COLUMN"));
op->SetInput("Col", this->Output("Col"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
void VarConv2dOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of SequencePadGradOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of SequencePadGradOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Input(Out@GRAD) of SequencePadGradOp is not found."));
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
@@ -416,10 +459,9 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plt = paddle::platform;
namespace frm = paddle::framework;
REGISTER_OPERATOR(
var_conv_2d, ops::VarConv2dOP, ops::VarConv2dOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(var_conv_2d, ops::VarConv2dOP, ops::VarConv2dOpMaker,
ops::VarConv2dGradMaker<paddle::framework::OpDesc>,
ops::VarConv2dGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(var_conv_2d_grad, ops::VarConv2dOpGrad);
REGISTER_OP_CPU_KERNEL(var_conv_2d,

@@ -28,6 +28,17 @@ class TestMKLDNNReluDim2(TestRelu):
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNLeakyReluDim2(TestLeakyRelu):
def setUp(self):
@@ -35,6 +46,17 @@ class TestMKLDNNLeakyReluDim2(TestLeakyRelu):
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNTanhDim2(TestTanh):
def setUp(self):
@@ -42,6 +64,17 @@ class TestMKLDNNTanhDim2(TestTanh):
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNSqrtDim2(TestSqrt):
def setUp(self):
@@ -49,12 +82,34 @@ class TestMKLDNNSqrtDim2(TestSqrt):
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNAbsDim2(TestAbs):
def setUp(self):
super(TestMKLDNNAbsDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNReluDim4(TestRelu):
def setUp(self):
@@ -69,6 +124,17 @@ class TestMKLDNNReluDim4(TestRelu):
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNLeakyReluDim4(TestLeakyRelu):
def setUp(self):
@@ -83,6 +149,17 @@ class TestMKLDNNLeakyReluDim4(TestLeakyRelu):
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNTanhDim4(TestTanh):
def setUp(self):
@@ -94,6 +171,17 @@ class TestMKLDNNTanhDim4(TestTanh):
self.outputs = {'Out': np.tanh(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNSqrtDim4(TestSqrt):
def setUp(self):
@@ -105,6 +193,17 @@ class TestMKLDNNSqrtDim4(TestSqrt):
self.outputs = {'Out': np.sqrt(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNAbsDim4(TestAbs):
def setUp(self):
@@ -117,6 +216,17 @@ class TestMKLDNNAbsDim4(TestAbs):
self.outputs = {'Out': np.abs(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
if self.dtype == np.float16:
return
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
# Check if primitives already exist in backward
class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):

@@ -36,7 +36,8 @@ class TestConcatOp(OpTest):
self.outputs = {'Out': self.output}
def test_check_output(self):
self.check_output()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
#--------------------test concat s8 in with axis 0--------------------

@@ -24,6 +24,10 @@ class TestMKLDNNConcatOp(TestConcatOp):
self.attrs["use_mkldnn"] = True
self._cpu_only = True
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def test_check_grad(self):
pass
@@ -37,6 +41,10 @@ class TestMKLDNNConcatOp2(TestConcatOp2):
self.attrs["use_mkldnn"] = True
self._cpu_only = True
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def test_check_grad(self):
pass
@@ -50,6 +58,10 @@ class TestMKLDNNConcatOp3(TestConcatOp3):
self.attrs["use_mkldnn"] = True
self._cpu_only = True
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def test_check_grad(self):
pass

@@ -146,7 +146,9 @@ class TestConv2dInt8Op(TestConv2dOp):
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace(), atol=0)
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
core.CPUPlace(), atol=0, check_dygraph=False)
def test_check_grad(self):
pass

@@ -44,7 +44,8 @@ class TestDeQuantizeOp(OpTest):
self.attrs = {'Scale': self.scale, }
def test_check_output(self):
self.check_output()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def set_scale(self):
pass

@@ -53,7 +53,8 @@ class TestFCMKLDNNOp(OpTest):
}
def test_check_output(self):
self.check_output()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad_normal(self):
pass

@@ -25,7 +25,13 @@ class TestLRNMKLDNNOp(TestLRNOp):
return attrs
def test_check_output(self):
self.check_output(atol=0.002)
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(atol=0.002, check_dygraph=False)
def test_check_grad_normal(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'], 'Out', max_relative_error=0.01, check_dygraph=False)
class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
@@ -37,7 +43,8 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
def test_check_grad_normal(self):
def check_raise_is_test():
try:
self.check_grad(['X'], 'Out', max_relative_error=0.01)
self.check_grad(
['X'], 'Out', max_relative_error=0.01, check_dygraph=False)
except Exception as e:
t = \
"is_test attribute should be set to False in training phase."

@@ -73,7 +73,9 @@ class TestMKLDNNMulOpS8S8(OpTest):
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace(), atol=0)
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
core.CPUPlace(), atol=0, check_dygraph=False)
def test_check_grad_normal(self):
pass

@@ -43,7 +43,9 @@ class TestPool2dMKLDNNInt8_Op(TestPool2D_Op):
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace(), atol=1e-5)
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
core.CPUPlace(), atol=1e-5, check_dygraph=False)
def test_check_grad(self):
pass

@@ -47,7 +47,8 @@ class TestQuantizeOp(OpTest):
}
def test_check_output(self):
self.check_output()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def set_scale(self):
pass

@@ -52,7 +52,8 @@ class TestReQuantizeOp(OpTest):
self.attrs = {'Scale_in': self.scale_in, 'Scale_out': self.scale_out}
def test_check_output(self):
self.check_output()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def set_scale(self):
pass

@@ -22,7 +22,62 @@ from paddle.fluid.tests.unittests.test_softmax_op import *
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def get_x_shape(self):
return [10, 10]
def get_axis(self):
return -1
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
self.use_mkldnn = False
self.dtype = np.float32
self.init_kernel_type()
self.shape = self.get_x_shape()
self.axis = self.get_axis()
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.apply_along_axis(stable_softmax, self.axis, x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {
'axis': self.axis,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
}
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5, check_dygraph=False)
else:
self.check_output(check_dygraph=False)
def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.use_cudnn or self.dtype == np.float16:
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ["X"],
"Out",
max_relative_error=0.01,
check_dygraph=False)
else:
self.check_grad(
["X"], "Out", max_relative_error=0.01, check_dygraph=False)
def init_kernel_type(self):
self.use_mkldnn = True

@@ -17,11 +17,32 @@ from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_sum_op import TestSumOp
import numpy as np
class TestMKLDNN(TestSumOp):
def init_kernel_type(self):
def setUp(self):
self.op_type = "sum"
self.init_kernel_type()
self.use_mkldnn = True
x0 = np.random.random((3, 4)).astype(self.dtype)
x1 = np.random.random((3, 4)).astype(self.dtype)
x2 = np.random.random((3, 4)).astype(self.dtype)
self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
y = x0 + x1 + x2
self.outputs = {'Out': y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
def init_kernel_type(self):
self.dtype = np.float32
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=False)
def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(['x0'], 'Out', check_dygraph=False)
if __name__ == '__main__':

@@ -48,8 +48,9 @@ class TestTransposeOp(OpTest):
self.op_type = "transpose2"
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
core.CPUPlace(), 1e-5, no_check_set=['XShape'])
core.CPUPlace(), 1e-5, no_check_set=['XShape'], check_dygraph=False)
def initTestCase(self):
self.shape = (2, 3, 4, 5)

@@ -17,14 +17,40 @@ from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_transpose_op import TestTransposeOp
import numpy as np
class TestTransposeMKLDNN(TestTransposeOp):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': self.use_mkldnn,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float32"),
'Out': self.inputs['X'].transpose(self.axis)
}
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = True
return
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(no_check_set=['XShape'], check_dygraph=False)
def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(['X'], 'Out', check_dygraph=False)
def initTestCase(self):
self.shape = (3, 4)
self.axis = (1, 0)
class TestCase0MKLDNN(TestTransposeMKLDNN):
def initTestCase(self):

@@ -38,7 +38,7 @@ class TestNGRAPHIncrementOp(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_dygraph=False)
if __name__ == "__main__":

File diff suppressed because it is too large.

@@ -319,34 +319,44 @@ class TestConv2dOp(OpTest):
def test_check_output(self):
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
place, atol=1e-5, check_dygraph=(self.use_mkldnn == False))
def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
place, {'Input', 'Filter'}, 'Output', max_relative_error=0.02)
place, {'Input', 'Filter'},
'Output',
max_relative_error=0.02,
check_dygraph=(self.use_mkldnn == False))
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
no_grad_set=set(['Filter']),
check_dygraph=(self.use_mkldnn == False))
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
no_grad_set=set(['Input']),
check_dygraph=(self.use_mkldnn == False))
def init_test_case(self):
self.pad = [0, 0]
@@ -739,17 +749,24 @@ class TestConv2dOp_v2(OpTest):
self.use_cuda)
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
self.check_output_with_place(
place, atol=1e-5, check_dygraph=(self.use_mkldnn == False))
def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_grad_with_place(
place, {'Input', 'Filter'}, 'Output', max_relative_error=0.02)
place, {'Input', 'Filter'},
'Output',
max_relative_error=0.02,
check_dygraph=(self.use_mkldnn == False))
def test_check_grad_no_filter(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
@@ -757,9 +774,11 @@ class TestConv2dOp_v2(OpTest):
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
no_grad_set=set(['Filter']),
check_dygraph=(self.use_mkldnn == False))
def test_check_grad_no_input(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
@@ -767,7 +786,8 @@ class TestConv2dOp_v2(OpTest):
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
no_grad_set=set(['Input']),
check_dygraph=(self.use_mkldnn == False))
def init_test_case(self):
self.pad = [0, 0]

Some files were not shown because too many files have changed in this diff.
