|
|
|
@ -33,12 +33,14 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto max_norm = context.Attr<T>("max_norm");
|
|
|
|
|
auto in_var = context.InputVar("X");
|
|
|
|
|
auto* output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
Tensor* output = nullptr;
|
|
|
|
|
const Tensor* input = nullptr;
|
|
|
|
|
if (in_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
input = context.Input<Tensor>("X");
|
|
|
|
|
|
|
|
|
|
output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
} else if (in_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto* x = context.Input<framework::SelectedRows>("X");
|
|
|
|
|
|
|
|
|
@ -50,6 +52,11 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
|
merge_func(context.template device_context<DeviceContext>(), *x,
|
|
|
|
|
merged_input);
|
|
|
|
|
input = &(merged_input->value());
|
|
|
|
|
|
|
|
|
|
auto* output_selected_rows = context.Output<SelectedRows>("Out");
|
|
|
|
|
output_selected_rows->set_rows(merged_input.rows());
|
|
|
|
|
output = output_selected_rows->mutable_data();
|
|
|
|
|
output->Resize(framework::make_ddim(merged_input.value().dims()));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unexpected branch, input variable type is %s",
|
|
|
|
|
in_var->Type().name());
|
|
|
|
|