|
|
|
|
@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/concat_op.h"
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include <paddle/fluid/platform/mkldnn_helper.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
using framework::Tensor;
|
|
|
|
|
@ -47,9 +50,19 @@ class ConcatOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (j == axis) {
|
|
|
|
|
out_dims[axis] += ins[i][j];
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
|
|
|
|
|
"Input tensors should have the same "
|
|
|
|
|
"elements except the specify axis.");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
// check all shape in run time
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
|
|
|
|
|
"Input tensors should have the same "
|
|
|
|
|
"elements except the specify axis.");
|
|
|
|
|
} else {
|
|
|
|
|
// not check -1 in compile time
|
|
|
|
|
if (out_dims[j] != -1 && ins[i][j] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
|
|
|
|
|
"Input tensors should have the same "
|
|
|
|
|
"elements except the specify axis.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@ -59,6 +72,22 @@ class ConcatOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto input_data_type =
|
|
|
|
|
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
|
|
|
|
|
framework::DataLayout::kMKLDNN,
|
|
|
|
|
framework::LibraryType::kMKLDNN);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
@ -66,6 +95,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
|
|
|
|
|
AddOutput("Out", "Output tensor of concat operator.");
|
|
|
|
|
AddAttr<bool>(
|
|
|
|
|
"use_mkldnn",
|
|
|
|
|
"(bool, default false) Indicates if MKL-DNN kernel will be used")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<int>("axis",
|
|
|
|
|
"The axis along which the input tensors will be concatenated.")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
|