|
|
|
@ -32,7 +32,7 @@ namespace ge {
|
|
|
|
|
namespace {
|
|
|
|
|
const size_t kConcatV2InputNum = 3;
|
|
|
|
|
const int kSupportEmptyTensorRank = 1;
|
|
|
|
|
const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT};
|
|
|
|
|
const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT, DT_INT64};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void GetOutputData(std::vector<T> &y_data, int64_t loop, size_t &input_size,
|
|
|
|
@ -88,6 +88,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge:
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> y_data_int32_t;
|
|
|
|
|
std::vector<float> y_data_float;
|
|
|
|
|
std::vector<int64_t> y_data_int64_t;
|
|
|
|
|
|
|
|
|
|
// Index 0 can always gets a GeTensorDesc object from any OpDescPtr.
|
|
|
|
|
auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0);
|
|
|
|
@ -106,6 +107,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge:
|
|
|
|
|
switch (data_type) {
|
|
|
|
|
SET_OUTPUT(DT_INT32, int32_t)
|
|
|
|
|
SET_OUTPUT(DT_FLOAT, float)
|
|
|
|
|
SET_OUTPUT(DT_INT64, int64_t)
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|