|
|
@ -80,8 +80,19 @@ class ConcatOp : public framework::OperatorWithKernel {
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
auto input_data_type =
|
|
|
|
auto vars = ctx.MultiInputVar("X");
|
|
|
|
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]);
|
|
|
|
auto input_data_type = framework::proto::VarType::Type(0);
|
|
|
|
|
|
|
|
bool flag = 0;
|
|
|
|
|
|
|
|
for (auto *var : vars) {
|
|
|
|
|
|
|
|
if (var->IsInitialized()) {
|
|
|
|
|
|
|
|
input_data_type = framework::GetDataTypeOfVar(var);
|
|
|
|
|
|
|
|
flag = 1;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (flag == 0) {
|
|
|
|
|
|
|
|
PADDLE_THROW("All Inputs of Concat OP are Empty!");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|