|
|
|
@ -23,7 +23,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
using framework::Tensor;
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
class ConcatOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto vars = ctx.MultiInputVar("X");
|
|
|
|
|
auto inputs = ctx.MultiInput<Tensor>("X");
|
|
|
|
|
auto input_data_type = framework::proto::VarType::Type(0);
|
|
|
|
|
bool flag = 0;
|
|
|
|
|
for (auto *var : vars) {
|
|
|
|
|
if (var->IsInitialized()) {
|
|
|
|
|
input_data_type = framework::GetDataTypeOfVar(var);
|
|
|
|
|
for (auto *input : inputs) {
|
|
|
|
|
if (input->IsInitialized() && input->numel() > 0) {
|
|
|
|
|
input_data_type = input->type();
|
|
|
|
|
flag = 1;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|