follow comments and refine code

shanyi15-patch-2
chengduoZH 7 years ago
parent 00e596edbe
commit 82bd82c186

@@ -33,6 +33,7 @@ class ConcatKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
// TODO(zcd): Sometimes direct copies will be faster
std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j];
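The hunk ends inside the loop that dereferences each input into a framework::Tensor. Going by the ConcatFunctor signature shown later in this diff (device context, vector of input tensors, axis, output pointer), the kernel presumably finishes roughly like the sketch below; the device_context call and the point where axis is read are assumptions, not lines from this commit.

// Hedged sketch of the assumed continuation -- not part of this diff.
// `axis` and `out` are taken to be the kernel's attribute and output,
// obtained earlier in Compute().
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);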
@@ -51,6 +52,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
// TODO(zcd): Sometimes direct copies will be faster
std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
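As with the forward kernel, the gradient hunk stops right after allocating each output. Given the ConcatGradFunctor signature further down (device context, the upstream gradient tensor, axis, vector of outputs), a plausible continuation is sketched below; `out_grad` is a hypothetical name for the gradient of the concat output, which this hunk does not show.

// Hedged sketch of the assumed continuation -- not part of this diff.
// `out_grad` (the gradient of Out) is a hypothetical name; only `outputs`
// and `axis` appear in the hunk above.
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::ConcatGradFunctor<DeviceContext, T> concat_grad_functor;
concat_grad_functor(dev_ctx, *out_grad, static_cast<int>(axis), outputs);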

@@ -19,7 +19,8 @@ namespace operators {
namespace math {
/*
* All tensors' dimension should be the same.
* All tensors' dimensions should be the same, and the values of
* each dimension should be the same, except the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
@@ -27,12 +28,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// assume the max size of input is less than 8 and see the performance
// save origin dim
// TODO(zcd): Add input data validity checking
int num = input.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// get the matrix size
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
@@ -40,7 +38,6 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
}
int out_rows = rows, out_cols = 0;
// get input's cols
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
@@ -64,18 +61,19 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
}
};
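The CPU specialization above reduces concatenation to 2-D copies: rows is the product of the dimensions before axis, each input contributes numel / rows columns, and the output is filled block by block. Below is a minimal self-contained sketch of that flattening in plain C++, with std::vector standing in for framework::Tensor (the names here are illustrative, not Paddle APIs).

#include <algorithm>
#include <cstdio>
#include <vector>

// Concatenate flat row-major buffers whose shapes agree on every dimension
// except the concat axis. rows = product of dims before axis,
// cols[i] = numel of input i / rows.
std::vector<float> ConcatRows(const std::vector<std::vector<float>>& inputs,
                              const std::vector<int>& cols, int rows) {
  int out_cols = 0;
  for (int c : cols) out_cols += c;
  std::vector<float> output(static_cast<size_t>(rows) * out_cols);
  int col_offset = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    for (int r = 0; r < rows; ++r) {
      // Copy one row slice of input i into its column block of the output.
      std::copy(inputs[i].begin() + r * cols[i],
                inputs[i].begin() + (r + 1) * cols[i],
                output.begin() + r * out_cols + col_offset);
    }
    col_offset += cols[i];
  }
  return output;
}

int main() {
  // Matches the axis = 0 example from the header comment below: rows = 1,
  // Input[0] = [[1,2],[3,4]] flattens to 4 cols, Input[1] = [[5,6]] to 2 cols.
  auto out = ConcatRows({{1, 2, 3, 4}, {5, 6}}, {4, 2}, 1);
  for (float v : out) std::printf("%g ", v);  // prints: 1 2 3 4 5 6
  std::printf("\n");
  return 0;
}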
/*
* All tensors' dimensions should be the same, and the values of
* each dimension should be the same, except the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// assume the max size of input is less than 8 and see the performance
// save origin dim
// TODO(zcd): Add input data validity checking
int num = outputs.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// get the matrix size
int input_rows = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
@@ -83,7 +81,6 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
}
int input_cols = 0;
// get outputs' cols
std::vector<int64_t> output_cols(outputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = outputs[i].numel() / input_rows;
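The gradient functor runs the same flattening backwards: the incoming tensor is viewed as (input_rows x total_cols) and each output receives its own column block. A matching plain-C++ sketch of that inverse step follows, again with illustrative names only.

#include <algorithm>
#include <vector>

// Split a flat row-major buffer of shape (rows x total_cols) into pieces
// whose column widths are given by `cols` -- the inverse of ConcatRows above.
std::vector<std::vector<float>> SplitRows(const std::vector<float>& input,
                                          const std::vector<int>& cols,
                                          int rows) {
  int total_cols = 0;
  for (int c : cols) total_cols += c;
  std::vector<std::vector<float>> outputs(cols.size());
  int col_offset = 0;
  for (size_t i = 0; i < cols.size(); ++i) {
    outputs[i].resize(static_cast<size_t>(rows) * cols[i]);
    for (int r = 0; r < rows; ++r) {
      // Copy output i's column block out of row r.
      std::copy(input.begin() + r * total_cols + col_offset,
                input.begin() + r * total_cols + col_offset + cols[i],
                outputs[i].begin() + r * cols[i]);
    }
    col_offset += cols[i];
  }
  return outputs;
}

Feeding the result of ConcatRows back through SplitRows with the same cols and rows recovers the original buffers.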

File diff suppressed because it is too large

@@ -20,7 +20,16 @@ namespace operators {
namespace math {
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template <typename DeviceContext, typename T>
class ConcatFunctor {
@@ -30,6 +39,18 @@ class ConcatFunctor {
framework::Tensor* output);
};
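For a concat along a non-zero axis (not shown in the header comment above; added here for illustration), only the axis dimension may differ: with Input[0] = [[1,2],[3,4]] (2x2), Input[1] = [[5],[6]] (2x1) and axis = 1, the result is Output = [[1,2,5],[3,4,6]] (2x3).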
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template <typename DeviceContext, typename T>
class ConcatGradFunctor {
public:
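The split performed by ConcatGradFunctor is the exact inverse: feeding Output = [[1,2,5],[3,4,6]] back in with axis = 1 and output shapes 2x2 and 2x1 recovers the two tensors from the axis = 1 example above (a worked example added here, not part of the header).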
