|
|
|
@ -40,7 +40,7 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
LoDTensor* out) {
|
|
|
|
|
auto x_dims = x.dims();
|
|
|
|
|
size_t element_len = framework::product(x_dims) / x_dims[0];
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* x_data = x.data<T>();
|
|
|
|
|
T* out_data = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto out_starts = out->lod().back();
|
|
|
|
|
|
|
|
|
@ -92,12 +92,12 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
* */
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& ctx, const LoDTensor& x,
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
|
|
|
|
|
const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
|
|
|
|
|
auto out_last_level = out.lod().back();
|
|
|
|
|
const T* d_out_data = d_out.data<T>();
|
|
|
|
|
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
size_t element_len = d_out.numel() / d_out.dims()[0];
|
|
|
|
|
const T* d_out_data = dout.data<T>();
|
|
|
|
|
T* d_x_data = dx->mutable_data<T>(context.GetPlace());
|
|
|
|
|
size_t element_len = dout.numel() / dout.dims()[0];
|
|
|
|
|
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
|
|
|
|
|
size_t repeat = out_last_level[i + 1] - out_last_level[i];
|
|
|
|
|
Eigen::TensorMap<
|
|
|
|
@ -117,13 +117,15 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
// Backward-pass kernel: collects the tensors it needs from the execution
// context and delegates the actual computation to the device-specific
// SequenceExpandGradFunctor.
class SequenceExpandGradKernel : public framework::OpKernel<T> {
 public:
  // Reads the inputs "X", "Out" and the gradient variable of "Out" from
  // `context`, and writes the gradient variable of "X".
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<LoDTensor>("X");
    auto* out = context.Input<LoDTensor>("Out");
    // Declared exactly once; a second declaration in this scope would be a
    // redefinition error.
    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));

    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
    // The gradient of X is given the same LoD as X.
    d_x->set_lod(x->lod());

    // The functor is a class template, so it has to be instantiated before it
    // can be called. It takes its tensor inputs by const reference (hence the
    // dereferences) and the output as a pointer.
    SequenceExpandGradFunctor<DeviceContext, T> functor;
    functor(context.template device_context<DeviceContext>(), *x, *out, *d_out,
            d_x);
  }
};
|
|
|
|
|
|
|
|
|
|