|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
|
#include "paddle/fluid/platform/transform.h"
|
|
|
|
|
|
|
|
|
@ -23,6 +24,7 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
|
|
|
@ -41,22 +43,24 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
} else if (in_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto* x = context.Input<framework::SelectedRows>("X");
|
|
|
|
|
} else if (in_var->IsType<SelectedRows>()) {
|
|
|
|
|
auto* x = context.Input<SelectedRows>("X");
|
|
|
|
|
|
|
|
|
|
// merge ids in selected rows first
|
|
|
|
|
math::scatter::MergeAdd<DeviceContext, T> merge_func;
|
|
|
|
|
auto* merged_input = const_cast<framework::Scope&>(context.scope())
|
|
|
|
|
.Var()
|
|
|
|
|
->GetMutable<framework::SelectedRows>();
|
|
|
|
|
SelectedRows* merged_input =
|
|
|
|
|
const_cast<framework::Scope&>(context.scope())
|
|
|
|
|
.Var()
|
|
|
|
|
->GetMutable<SelectedRows>();
|
|
|
|
|
merge_func(context.template device_context<DeviceContext>(), *x,
|
|
|
|
|
merged_input);
|
|
|
|
|
input = &(merged_input->value());
|
|
|
|
|
|
|
|
|
|
auto* output_selected_rows = context.Output<SelectedRows>("Out");
|
|
|
|
|
output_selected_rows->set_rows(merged_input.rows());
|
|
|
|
|
output = output_selected_rows->mutable_data();
|
|
|
|
|
output->Resize(framework::make_ddim(merged_input.value().dims()));
|
|
|
|
|
SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
|
|
|
|
|
output_selected_rows->set_rows(merged_input->rows());
|
|
|
|
|
output_selected_rows->set_height(merged_input->height());
|
|
|
|
|
output = output_selected_rows->mutable_value();
|
|
|
|
|
output->Resize(merged_input->value().dims());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unexpected branch, input variable type is %s",
|
|
|
|
|
in_var->Type().name());
|
|
|
|
|