|
|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
|
#include "paddle/fluid/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
@ -31,10 +32,31 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto max_norm = context.Attr<T>("max_norm");
|
|
|
|
|
auto* input = context.Input<Tensor>("X");
|
|
|
|
|
auto in_var = context.InputVar("X");
|
|
|
|
|
auto* output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
const Tensor* input = nullptr;
|
|
|
|
|
if (in_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
input = context.Input<Tensor>("X");
|
|
|
|
|
} else if (in_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto* x = context.Input<framework::SelectedRows>("X");
|
|
|
|
|
|
|
|
|
|
// merge ids in selected rows first
|
|
|
|
|
math::scatter::MergeAdd<DeviceContext, T> merge_func;
|
|
|
|
|
auto* merged_input = const_cast<framework::Scope&>(context.scope())
|
|
|
|
|
.Var()
|
|
|
|
|
->GetMutable<framework::SelectedRows>();
|
|
|
|
|
merge_func(context.template device_context<DeviceContext>(), *x,
|
|
|
|
|
merged_input);
|
|
|
|
|
input = &(merged_input->value());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unexpected branch, input variable type is %s",
|
|
|
|
|
in_var->Type().name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(input);
|
|
|
|
|
|
|
|
|
|
auto x = EigenVector<T>::Flatten(*input);
|
|
|
|
|
auto out = EigenVector<T>::Flatten(*output);
|
|
|
|
|
auto x_norm = x.square().sum().sqrt();
|
|
|
|
|
|