|
|
|
@ -15,7 +15,7 @@ limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/framework/tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -57,7 +57,7 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class ConcatGradFunctor {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const DeviceContext& context, const framework::Tensor& input,
|
|
|
|
|
const std::vector<const framework::Tensor*>& ref_inputs,
|
|
|
|
|
const std::vector<const framework::LoDTensor*>& ref_inputs,
|
|
|
|
|
const int axis, std::vector<framework::Tensor*>* outputs);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|