@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/concat_op.h"
+
+#include <memory>
 #include <string>
 #include <vector>
 
 #ifdef PADDLE_WITH_MKLDNN
 #include <paddle/fluid/platform/mkldnn_helper.h>
 #endif
 
 namespace paddle {
 namespace operators {
 using framework::Tensor;
@@ -45,11 +49,29 @@ class ConcatOp : public framework::OperatorWithKernel {
     for (size_t i = 1; i < n; i++) {
       for (size_t j = 0; j < in_zero_dims_size; j++) {
         if (j == axis) {
-          out_dims[axis] += ins[i][j];
+          if (ctx->IsRuntime()) {
+            out_dims[axis] += ins[i][j];
+          } else {
+            if (out_dims[axis] == -1 || ins[i][j] == -1) {
+              out_dims[axis] = -1;
+            } else {
+              out_dims[axis] += ins[i][j];
+            }
+          }
         } else {
-          PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
-                            "Input tensors should have the same "
-                            "elements except the specify axis.");
+          if (ctx->IsRuntime()) {
+            // Check all shapes at run time, when every dimension is known.
+            PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
+                              "Input tensors should have the same "
+                              "elements except the specified axis.");
+          } else {
+            // Do not check unknown (-1) dimensions at compile time.
+            if (out_dims[j] > 0 && ins[i][j] > 0) {
+              PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
+                                "Input tensors should have the same "
+                                "elements except the specified axis.");
+            }
+          }
         }
       }
     }
@@ -59,6 +81,22 @@ class ConcatOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", out_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]);
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (platform::CanMKLDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -66,6 +104,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X", "Input tensors of concat operator.").AsDuplicable();
     AddOutput("Out", "Output tensor of concat operator.");
+    AddAttr<bool>(
+        "use_mkldnn",
+        "(bool, default false) Indicates if MKL-DNN kernel will be used")
+        .SetDefault(false);
     AddAttr<int>("axis",
                  "The axis along which the input tensors will be concatenated.")
         .SetDefault(0);
@@ -87,11 +129,7 @@ Examples:
 
 class ConcatOpGrad : public framework::OperatorWithKernel {
  public:
-  ConcatOpGrad(const std::string &type,
-               const framework::VariableNameMap &inputs,
-               const framework::VariableNameMap &outputs,
-               const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+  using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     auto in_x = "X";
@@ -109,6 +147,33 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
       }
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
+                                      "X");
+
+class ConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("concat_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
 };
 
 }  // namespace operators
@@ -116,9 +181,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<
-                      false> /* set false to disable empty grad */);
-REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad);
+                  ops::ConcatGradOpDescMaker);
+REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
+                  ops::ConcatOpGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(
     concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,