refine GetExpectedKernelType in conat op, test=develop (#17934)

dependabot/pip/python/requests-2.20.0
jerrywgz 6 years ago committed by GitHub
parent 3ece61f71e
commit aab4d12c0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::Tensor;
using Tensor = framework::Tensor;
class ConcatOp : public framework::OperatorWithKernel {
public:
@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto vars = ctx.MultiInputVar("X");
auto inputs = ctx.MultiInput<Tensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *var : vars) {
if (var->IsInitialized()) {
input_data_type = framework::GetDataTypeOfVar(var);
for (auto *input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
input_data_type = input->type();
flag = 1;
break;
}

Loading…
Cancel
Save