|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/memory/memcpy.h"
|
|
|
|
|
#include "unsupported/Eigen/CXX11/Tensor"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -32,52 +32,53 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* out = context.Output<LoDTensor>("Out");
|
|
|
|
|
int ref_level = context.Attr<int>("ref_level");
|
|
|
|
|
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto& x_lod = x->lod();
|
|
|
|
|
auto& y_lod = y->lod();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(ref_level, 0,
|
|
|
|
|
"Value of attribute `ref_level` should be greater or "
|
|
|
|
|
"equal to 0.");
|
|
|
|
|
PADDLE_ENFORCE_GT(y_lod.size(), 0,
|
|
|
|
|
"Level number of `Y`'s lod should be greater than 0.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LT(ref_level, y_lod.size(),
|
|
|
|
|
"Value of attribute `ref_level` should be smaller than "
|
|
|
|
|
"level number of Y's lod.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ref_level == -1 || (ref_level >= 0 && ref_level < y_lod.size()),
|
|
|
|
|
"Invlid `ref_level`, which should be either equal to -1 "
|
|
|
|
|
"or in [0, %d)",
|
|
|
|
|
y_lod.size());
|
|
|
|
|
|
|
|
|
|
if (y_lod[ref_level].size() < 1) {
|
|
|
|
|
if (ref_level == -1) ref_level = y_lod.size() - 1;
|
|
|
|
|
|
|
|
|
|
if (y_lod[ref_level].size() <= 1) {
|
|
|
|
|
framework::TensorCopy(*x, context.GetPlace(), out);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (x_lod.size() == 0) {
|
|
|
|
|
int out_start = 0;
|
|
|
|
|
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
|
|
|
|
|
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
|
|
|
|
|
auto x_sub_tensor = x->Slice(i - 1, i);
|
|
|
|
|
for (size_t j = 0; j < repeat_num; ++j) {
|
|
|
|
|
auto out_sub_tensor = out->Slice(out_start, out_start + 1);
|
|
|
|
|
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
|
|
|
|
|
&out_sub_tensor);
|
|
|
|
|
out_start++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto& out_lod = *out->mutable_lod();
|
|
|
|
|
auto& out_lod = *out->mutable_lod();
|
|
|
|
|
if (x_lod.size() == 1) {
|
|
|
|
|
out_lod.resize(1);
|
|
|
|
|
out_lod[0].resize(1);
|
|
|
|
|
out_lod[0][0] = 0;
|
|
|
|
|
int out_idx = 0;
|
|
|
|
|
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
|
|
|
|
|
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
|
|
|
|
|
int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
|
|
|
|
|
auto x_sub_tensor = x->Slice(x_lod[0][i], x_lod[0][i - 1]);
|
|
|
|
|
for (size_t j = 0; j < repeat_num; ++j) {
|
|
|
|
|
auto out_sub_tensor =
|
|
|
|
|
out->Slice(out_lod[0][out_idx], out_lod[0][out_idx] + x_seq_len);
|
|
|
|
|
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
|
|
|
|
|
&out_sub_tensor);
|
|
|
|
|
out_lod[0].push_back(out_lod[0][out_idx] + x_seq_len);
|
|
|
|
|
out_idx++;
|
|
|
|
|
out_lod[0] = {0};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int out_offset = 0;
|
|
|
|
|
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
|
|
|
|
|
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
|
|
|
|
|
int x_start = i - 1;
|
|
|
|
|
int x_end = i;
|
|
|
|
|
if (x_lod.size() == 1) {
|
|
|
|
|
x_start = x_lod[0][i - 1];
|
|
|
|
|
x_end = x_lod[0][i];
|
|
|
|
|
}
|
|
|
|
|
int x_seq_len = x_end - x_start;
|
|
|
|
|
auto x_sub_tensor = x->Slice(x_start, x_end);
|
|
|
|
|
for (size_t j = 0; j < repeat_num; ++j) {
|
|
|
|
|
int out_start = out_offset;
|
|
|
|
|
if (x_lod.size() == 1) {
|
|
|
|
|
out_start = out_lod[0][out_offset];
|
|
|
|
|
out_lod[0].push_back(x_seq_len);
|
|
|
|
|
}
|
|
|
|
|
auto out_sub_tensor = out->Slice(out_start, out_start + x_seq_len);
|
|
|
|
|
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
|
|
|
|
|
&out_sub_tensor);
|
|
|
|
|
out_offset++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -99,27 +100,49 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class SequenceExpandGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* g_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* x = context.Input<LoDTensor>("X");
|
|
|
|
|
auto* out = context.Input<LoDTensor>("Out");
|
|
|
|
|
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
|
|
|
|
|
auto out_last_level = out->lod().back();
|
|
|
|
|
d_x->set_lod(x->lod());
|
|
|
|
|
const T* d_out_data = d_out->data<T>();
|
|
|
|
|
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
size_t element_len = d_out->numel() / d_out->dims()[0];
|
|
|
|
|
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
|
|
|
|
|
size_t repeat = out_last_level[i + 1] - out_last_level[i];
|
|
|
|
|
Eigen::TensorMap<
|
|
|
|
|
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
|
|
|
|
|
d_out_t(d_out_data, static_cast<int>(repeat), element_len);
|
|
|
|
|
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
|
|
|
|
|
d_x_t(d_x_data, static_cast<int>(element_len));
|
|
|
|
|
auto place =
|
|
|
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
d_x_t.device(*place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
|
|
|
|
|
d_out_data += (repeat * element_len);
|
|
|
|
|
d_x_data += element_len;
|
|
|
|
|
auto* y = context.Input<LoDTensor>("Y");
|
|
|
|
|
auto* g_x = context.Output<LoDTensor>(framework::GradVarName("X"));
|
|
|
|
|
int ref_level = context.Attr<int>("ref_level");
|
|
|
|
|
|
|
|
|
|
g_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
g_x->set_lod(x->lod());
|
|
|
|
|
|
|
|
|
|
auto& x_lod = x->lod();
|
|
|
|
|
auto& y_lod = y->lod();
|
|
|
|
|
|
|
|
|
|
if (ref_level == -1) ref_level = y_lod.size() - 1;
|
|
|
|
|
|
|
|
|
|
// just copy the gradient
|
|
|
|
|
if (y_lod[ref_level].size() <= 1) {
|
|
|
|
|
framework::TensorCopy(*g_out, context.GetPlace(), g_x);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
|
|
int g_out_offset = 0;
|
|
|
|
|
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
|
|
|
|
|
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
|
|
|
|
|
if (repeat_num > 0) {
|
|
|
|
|
int x_start = i - 1;
|
|
|
|
|
int x_end = i;
|
|
|
|
|
if (x_lod.size() == 1) {
|
|
|
|
|
x_start = x_lod[0][i - 1];
|
|
|
|
|
x_end = x_lod[0][i];
|
|
|
|
|
}
|
|
|
|
|
int x_seq_len = x_end - x_start;
|
|
|
|
|
auto column = x_seq_len * x->dims()[1];
|
|
|
|
|
auto g_x_sub = g_x->Slice(x_start, x_end);
|
|
|
|
|
g_x_sub = framework::ReshapeToMatrix(g_x_sub, column);
|
|
|
|
|
int g_out_end = g_out_offset + repeat_num * x_seq_len;
|
|
|
|
|
auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
|
|
|
|
|
g_out_sub = framework::ReshapeToMatrix(g_out_sub, column);
|
|
|
|
|
math::ColwiseSum<DeviceContext, T> col_sum;
|
|
|
|
|
col_sum(dev_ctx, g_out_sub, &g_x_sub);
|
|
|
|
|
g_out_offset += repeat_num * x_seq_len;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|