commit
						91855659d7
					
				@ -0,0 +1,19 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
   Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
   you may not use this file except in compliance with the License.
 | 
				
			||||
   You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
   http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
   Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
   distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
   See the License for the specific language governing permissions and
 | 
				
			||||
   limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/operators/clip_by_norm_op.h"
 | 
				
			||||
 | 
				
			||||
namespace ops = paddle::operators;
 | 
				
			||||
REGISTER_OP_GPU_KERNEL(
 | 
				
			||||
    clip_by_norm, ops::ClipByNormKernel<paddle::platform::GPUPlace, float>);
 | 
				
			||||
@ -0,0 +1,52 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
   Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
   you may not use this file except in compliance with the License.
 | 
				
			||||
   You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
   http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
   Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
   distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
   See the License for the specific language governing permissions and
 | 
				
			||||
   limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include "paddle/framework/eigen.h"
 | 
				
			||||
#include "paddle/framework/op_registry.h"
 | 
				
			||||
#include "paddle/platform/transform.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
using Tensor = framework::Tensor;
 | 
				
			||||
template <typename T, int MajorType = Eigen::RowMajor,
 | 
				
			||||
          typename IndexType = Eigen::DenseIndex>
 | 
				
			||||
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 | 
				
			||||
 | 
				
			||||
template <typename Place, typename T>
 | 
				
			||||
class ClipByNormKernel : public framework::OpKernel<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  void Compute(const framework::ExecutionContext& context) const override {
 | 
				
			||||
    auto max_norm = context.Attr<T>("max_norm");
 | 
				
			||||
    auto* input = context.Input<Tensor>("X");
 | 
				
			||||
    auto* output = context.Output<Tensor>("Out");
 | 
				
			||||
    output->mutable_data<T>(context.GetPlace());
 | 
				
			||||
 | 
				
			||||
    auto x = EigenVector<T>::Flatten(*input);
 | 
				
			||||
    auto out = EigenVector<T>::Flatten(*output);
 | 
				
			||||
    auto x_norm = x.square().sum().sqrt();
 | 
				
			||||
    auto place = context.GetEigenDevice<Place>();
 | 
				
			||||
 | 
				
			||||
    auto temp = (x_norm <= max_norm).template cast<T>().eval();
 | 
				
			||||
    auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm;
 | 
				
			||||
    Eigen::array<int, 1> one_dim{{1}};
 | 
				
			||||
    Eigen::DSizes<int, 1> m_dsize(input->numel());
 | 
				
			||||
    out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,50 @@
 | 
				
			||||
import unittest
 | 
				
			||||
import numpy as np
 | 
				
			||||
from op_test import OpTest
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestClipByNormOp(OpTest):
 | 
				
			||||
    def setUp(self):
 | 
				
			||||
        self.max_relative_error = 0.006
 | 
				
			||||
        self.initTestCase()
 | 
				
			||||
        input = np.random.random(self.shape).astype("float32")
 | 
				
			||||
        input[np.abs(input) < self.max_relative_error] = 0.5
 | 
				
			||||
        self.op_type = "clip_by_norm"
 | 
				
			||||
        self.inputs = {'X': input, }
 | 
				
			||||
        self.attrs = {}
 | 
				
			||||
        self.attrs['max_norm'] = self.max_norm
 | 
				
			||||
        norm = np.sqrt(np.sum(np.square(input)))
 | 
				
			||||
        if norm > self.max_norm:
 | 
				
			||||
            output = self.max_norm * input / norm
 | 
				
			||||
        else:
 | 
				
			||||
            output = input
 | 
				
			||||
        self.outputs = {'Out': output}
 | 
				
			||||
 | 
				
			||||
    def test_check_output(self):
 | 
				
			||||
        self.check_output()
 | 
				
			||||
 | 
				
			||||
    def initTestCase(self):
 | 
				
			||||
        self.shape = (100, )
 | 
				
			||||
        self.max_norm = 1.0
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestCase1(TestClipByNormOp):
 | 
				
			||||
    def initTestCase(self):
 | 
				
			||||
        self.shape = (100, )
 | 
				
			||||
        self.max_norm = 1e20
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestCase2(TestClipByNormOp):
 | 
				
			||||
    def initTestCase(self):
 | 
				
			||||
        self.shape = (16, 16)
 | 
				
			||||
        self.max_norm = 0.1
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestCase3(TestClipByNormOp):
 | 
				
			||||
    def initTestCase(self):
 | 
				
			||||
        self.shape = (4, 8, 16)
 | 
				
			||||
        self.max_norm = 1.0
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
if __name__ == '__main__':
 | 
				
			||||
    unittest.main()
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue