modified: src/ge/host_kernels/concat_v2_kernel.cc

pull/314/head
zhaoxinxin 4 years ago
parent c60bfbe20b
commit 8751d19dd8

@ -32,7 +32,7 @@ namespace ge {
namespace {
const size_t kConcatV2InputNum = 3;
const int kSupportEmptyTensorRank = 1;
const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT};
const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT, DT_INT64};
template <typename T>
void GetOutputData(std::vector<T> &y_data, int64_t loop, size_t &input_size,
@ -88,6 +88,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge:
std::vector<int32_t> y_data_int32_t;
std::vector<float> y_data_float;
std::vector<int64_t> y_data_int64_t;
// Index 0 can always gets a GeTensorDesc object from any OpDescPtr.
auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0);
@ -106,6 +107,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge:
switch (data_type) {
SET_OUTPUT(DT_INT32, int32_t)
SET_OUTPUT(DT_FLOAT, float)
SET_OUTPUT(DT_INT64, int64_t)
default:
break;
}

Loading…
Cancel
Save