|
|
|
@ -19,7 +19,8 @@ namespace operators {
|
|
|
|
|
namespace math {
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* All tensors' dimension should be the same.
|
|
|
|
|
* All tensors should have the same number of dimensions, and the values of
|
|
|
|
|
* each dimension are the same, except the axis dimension.
|
|
|
|
|
*/
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ConcatFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
@ -27,12 +28,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const std::vector<framework::Tensor>& input, const int axis,
|
|
|
|
|
framework::Tensor* output) {
|
|
|
|
|
// assume the max size of input is less than 8 and see the performance
|
|
|
|
|
// save origin dim
|
|
|
|
|
// TODO(zcd): Add input data validity checking
|
|
|
|
|
int num = input.size();
|
|
|
|
|
std::vector<paddle::framework::DDim> origin_dim(num);
|
|
|
|
|
|
|
|
|
|
// get the matrix size
|
|
|
|
|
int rows = 1;
|
|
|
|
|
auto dim_0 = input[0].dims();
|
|
|
|
|
for (int i = 0; i < axis; ++i) {
|
|
|
|
@ -40,7 +38,6 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
}
|
|
|
|
|
int out_rows = rows, out_cols = 0;
|
|
|
|
|
|
|
|
|
|
// get input's cols
|
|
|
|
|
std::vector<int64_t> input_cols(input.size());
|
|
|
|
|
for (int i = 0; i < num; ++i) {
|
|
|
|
|
int t_cols = input[i].numel() / rows;
|
|
|
|
@ -64,18 +61,19 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* All tensors should have the same number of dimensions, and the values of
|
|
|
|
|
* each dimension are the same, except the axis dimension.
|
|
|
|
|
*/
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input, const int axis,
|
|
|
|
|
std::vector<framework::Tensor>& outputs) {
|
|
|
|
|
// assume the max size of input is less than 8 and see the performance
|
|
|
|
|
// save origin dim
|
|
|
|
|
// TODO(zcd): Add input data validity checking
|
|
|
|
|
int num = outputs.size();
|
|
|
|
|
std::vector<paddle::framework::DDim> origin_dim(num);
|
|
|
|
|
|
|
|
|
|
// get the matrix size
|
|
|
|
|
int input_rows = 1;
|
|
|
|
|
auto dim_0 = outputs[0].dims();
|
|
|
|
|
for (int i = 0; i < axis; ++i) {
|
|
|
|
@ -83,7 +81,6 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
}
|
|
|
|
|
int input_cols = 0;
|
|
|
|
|
|
|
|
|
|
// get outputs' cols
|
|
|
|
|
std::vector<int64_t> output_cols(outputs.size());
|
|
|
|
|
for (int i = 0; i < num; ++i) {
|
|
|
|
|
int t_cols = outputs[i].numel() / input_rows;
|
|
|
|
|