// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stack_op.h"
#include <memory>
#include <vector>
namespace plat = paddle::platform;
namespace ops = paddle::operators;
namespace paddle {
namespace operators {

class StackOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
                      platform::errors::InvalidArgument(
                          "Number of Inputs(X) must be larger than 0, but"
                          " received value is:%d.",
                          ctx->Inputs("X").size()));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
                      platform::errors::InvalidArgument(
                          "Output(Y) of stack_op should not be null."));
    auto input_dims = ctx->GetInputsDim("X");
    for (size_t i = 1; i < input_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
                        platform::errors::InvalidArgument(
                            "Dims of all Inputs(X) must be the same, but"
                            " received input %d dim is:%d not equal to input 0"
                            " dim:%d.",
                            i, input_dims[i], input_dims[0]));
    }
    // Only the LoD of X[0] is shared with Y.
    ctx->ShareLoD("X", /*->*/ "Y");
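    // Stacking adds one dimension, so Output(Y) has rank + 1 dims and valid
    // axes lie in [-(rank + 1), rank + 1).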
    int axis = ctx->Attrs().Get<int>("axis");
    int rank = input_dims[0].size();
    PADDLE_ENFORCE_GE(
        axis, -(rank + 1),
        platform::errors::InvalidArgument(
            "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, "
            "but received axis is:%d.",
            rank, axis));
    PADDLE_ENFORCE_LT(
        axis, rank + 1,
        platform::errors::InvalidArgument(
            "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, "
            "but received axis is:%d.",
            rank, axis));
    if (axis < 0) axis += (rank + 1);
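    // The output shape is the common input shape with the number of stacked
    // inputs inserted at position `axis`.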
    auto vec = framework::vectorize<int>(input_dims[0]);
    vec.insert(vec.begin() + axis, input_dims.size());
    ctx->SetOutputDim("Y", framework::make_ddim(vec));
  }
};

class StackOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input of stack op.").AsDuplicable();
    AddOutput("Y", "The output of stack op.");
    AddAttr<int>("axis",
                 "The axis along which all of the Inputs(X) should be stacked.")
        .SetDefault(0);
    AddComment(R"DOC(
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
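For example, stacking 4 tensors of shape [2, 3] along axis 0 produces a tensor of shape [4, 2, 3].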
)DOC");
  }
};

class StackOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Y")), true,
        platform::errors::InvalidArgument("Input(Y@Grad) does not exist."));
    int axis = ctx->Attrs().Get<int>("axis");
    auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
    int rank = dy_dim.size();
    PADDLE_ENFORCE_GE(
        axis, -rank,
        platform::errors::InvalidArgument(
            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
            "but received axis is:%d.",
            rank, axis));
    PADDLE_ENFORCE_LT(
        axis, rank,
        platform::errors::InvalidArgument(
            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
            "but received axis is:%d.",
            rank, axis));
    if (axis < 0) axis += rank;
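    // After normalization, dy_dim[axis] equals the number of stacked inputs,
    // so it must match the number of gradient outputs.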
    PADDLE_ENFORCE_EQ(
        ctx->Outputs(framework::GradVarName("X")).size(),
        static_cast<size_t>(dy_dim[axis]),
        platform::errors::InvalidArgument(
            "Number of Outputs(X@Grad) must be equal to dy dim at axis, but"
            " received outputs size is:%d, dy dims is:%d.",
            ctx->Outputs(framework::GradVarName("X")).size(),
            static_cast<size_t>(dy_dim[axis])));
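    // Each Input(X) gradient has dy's shape with the stacked axis removed.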
    auto vec = framework::vectorize<int>(dy_dim);
    vec.erase(vec.begin() + axis);
    ctx->SetOutputsDim(
        framework::GradVarName("X"),
        std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
  }
};

template <typename T>
class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
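  // Builds stack_grad: it consumes Y@Grad, emits one gradient per Input(X),
  // and reuses the forward op's attributes (notably `axis`).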
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("stack_grad");
    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
    op->SetAttrMap(this->Attrs());
  }
};
} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
                  ops::StackGradOpMaker<paddle::framework::OpDesc>,
                  ops::StackGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
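
// CPU kernels are registered for float, double, int, and int64_t elements.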
REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
                       ops::StackKernel<plat::CPUDeviceContext, double>,
                       ops::StackKernel<plat::CPUDeviceContext, int>,
                       ops::StackKernel<plat::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(stack_grad,
                       ops::StackGradKernel<plat::CPUDeviceContext, float>,
                       ops::StackGradKernel<plat::CPUDeviceContext, double>,
                       ops::StackGradKernel<plat::CPUDeviceContext, int>,
                       ops::StackGradKernel<plat::CPUDeviceContext, int64_t>);