Fix GetExpectedKernelType in Concat op (#17459)

* fix concat op vartype check, test=develop
fix_ema
jerrywgz 6 years ago committed by GitHub
parent 58f7695ab2
commit c1aae8b8d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -80,8 +80,19 @@ class ConcatOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto vars = ctx.MultiInputVar("X");
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *var : vars) {
if (var->IsInitialized()) {
input_data_type = framework::GetDataTypeOfVar(var);
flag = 1;
break;
}
}
if (flag == 0) {
PADDLE_THROW("All Inputs of Concat OP are Empty!");
}
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {

Loading…
Cancel
Save