|
|
|
@ -16,12 +16,15 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
|
#include "paddle/fluid/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
|
|
|
@ -31,9 +34,40 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto max_norm = context.Attr<T>("max_norm");
|
|
|
|
|
auto* input = context.Input<Tensor>("X");
|
|
|
|
|
auto* output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto in_var = context.InputVar("X");
|
|
|
|
|
|
|
|
|
|
Tensor* output = nullptr;
|
|
|
|
|
const Tensor* input = nullptr;
|
|
|
|
|
if (in_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
input = context.Input<Tensor>("X");
|
|
|
|
|
|
|
|
|
|
output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
} else if (in_var->IsType<SelectedRows>()) {
|
|
|
|
|
auto* x = context.Input<SelectedRows>("X");
|
|
|
|
|
|
|
|
|
|
// merge ids in selected rows first
|
|
|
|
|
math::scatter::MergeAdd<DeviceContext, T> merge_func;
|
|
|
|
|
SelectedRows* merged_input =
|
|
|
|
|
const_cast<framework::Scope&>(context.scope())
|
|
|
|
|
.Var()
|
|
|
|
|
->GetMutable<SelectedRows>();
|
|
|
|
|
merge_func(context.template device_context<DeviceContext>(), *x,
|
|
|
|
|
merged_input);
|
|
|
|
|
input = &(merged_input->value());
|
|
|
|
|
|
|
|
|
|
SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
|
|
|
|
|
output_selected_rows->set_rows(merged_input->rows());
|
|
|
|
|
output_selected_rows->set_height(merged_input->height());
|
|
|
|
|
output = output_selected_rows->mutable_value();
|
|
|
|
|
output->Resize(merged_input->value().dims());
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unexpected branch, input variable type is %s",
|
|
|
|
|
in_var->Type().name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(input);
|
|
|
|
|
|
|
|
|
|
auto x = EigenVector<T>::Flatten(*input);
|
|
|
|
|
auto out = EigenVector<T>::Flatten(*output);
|
|
|
|
|