You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
2.7 KiB
85 lines
2.7 KiB
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#pragma once
|
|
|
|
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

#include "paddle/extension.h"
|
|
|
|
// Returns the product of the first `axis` entries of `shape`, i.e. the row
// count obtained when the shape is flattened into a 2-D [rows, cols] view
// split at `axis`.
//
// Parameters:
//   shape - dimension sizes.
//   axis  - number of leading dimensions folded into the row count.
//           Values <= 0 give 1 (empty product). Values larger than
//           shape.size() are clamped to shape.size() so the loop never
//           indexes past the end of `shape`.
//
// Marked `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit causes
// multiple-definition link errors.
inline int64_t GetRows(std::vector<int64_t> shape, int64_t axis) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  const int64_t end = axis < rank ? axis : rank;

  int64_t rows = 1;
  for (int64_t i = 0; i < end; ++i) {
    rows *= shape[i];
  }
  return rows;
}
|
|
|
|
// Computes, for every tensor in `ins`, its column count when viewed as a 2-D
// [rows, cols] matrix: cols_i = ins[i].size() / rows.
//
// Parameters:
//   ins  - input tensors.
//   rows - shared row count used as the divisor (see GetRows).
//   cols - in/out accumulator; each tensor's column count is ADDED to *cols.
//          It is not reset here, so the caller must initialise it.
//
// Returns a vector with one column count per entry of `ins`, in order.
//
// When `rows` is 0 the column count is reported as 0 rather than dividing by
// zero.
//
// Marked `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit causes
// multiple-definition link errors.
inline std::vector<int64_t> GetCols(const std::vector<paddle::Tensor>& ins,
                                    int64_t rows,
                                    int64_t* cols) {
  std::vector<int64_t> cols_vec(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
    int64_t t_cols = 0;
    if (rows != 0) {
      t_cols = ins[i].size() / rows;
    }
    *cols += t_cols;
    cols_vec[i] = t_cols;
  }
  return cols_vec;
}
|
|
|
|
template <typename data_t>
|
|
void ConcatCpuKernel(const std::vector<paddle::Tensor>& ins,
|
|
paddle::Tensor* out,
|
|
int64_t axis) {
|
|
size_t num = ins.size();
|
|
int64_t out_rows = GetRows(ins[0].shape(), axis);
|
|
int64_t out_cols = 0;
|
|
auto ins_cols = GetCols(ins, out_rows, &out_cols);
|
|
|
|
auto* out_data = out->mutable_data<data_t>();
|
|
int64_t col_idx = 0;
|
|
for (size_t i = 0; i < num; ++i) {
|
|
int64_t col_len = ins_cols[i];
|
|
auto* in_data = ins[i].data<data_t>();
|
|
for (int j = 0; j < out_rows; ++j) {
|
|
std::memcpy(out_data + j * out_cols + col_idx,
|
|
in_data + j * col_len,
|
|
sizeof(data_t) * col_len);
|
|
}
|
|
col_idx += col_len;
|
|
}
|
|
}
|
|
|
|
template <typename data_t>
|
|
void SplitCpuKernel(const paddle::Tensor& in,
|
|
const std::vector<paddle::Tensor>& ref_ins,
|
|
std::vector<paddle::Tensor>* outs,
|
|
int64_t axis) {
|
|
size_t num = outs->size();
|
|
int64_t in_rows = GetRows(ref_ins[0].shape(), axis);
|
|
int64_t in_cols = 0;
|
|
auto out_cols = GetCols(ref_ins, in_rows, &in_cols);
|
|
|
|
for (size_t i = 0; i < in_rows; ++i) {
|
|
auto* in_data = in.data<data_t>() + i * in_cols;
|
|
int64_t col_idx = 0;
|
|
for (size_t j = 0; j < num; ++j) {
|
|
int64_t col_len = out_cols[j];
|
|
auto* out_data = outs->at(j).mutable_data<data_t>() + i * col_len;
|
|
std::memcpy(out_data, in_data + col_idx, sizeof(data_t) * col_len);
|
|
col_idx += col_len;
|
|
}
|
|
}
|
|
}
|