【NPU】Support npu op gelu and gelu_grad (#31530)
	
		
	
				
					
				
			* Support npu op gelu and gelu_grad * Support npu op gelu and gelu_grad
			revert-31562-mean
							parent
							
								
									5d29a27c2e
								
							
						
					
					
						commit
						382fc31f89
					
				| @ -0,0 +1,89 @@ | ||||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "paddle/fluid/operators/gelu_op.h" | ||||
| #include "paddle/fluid/operators/npu_op_runner.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| using Tensor = framework::Tensor; | ||||
| 
 | ||||
| template <typename DeviceContext, typename T> | ||||
| class GeluNPUKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& ctx) const override { | ||||
|     auto* x = ctx.Input<Tensor>("X"); | ||||
| 
 | ||||
|     auto* out = ctx.Output<Tensor>("Out"); | ||||
| 
 | ||||
|     auto place = ctx.GetPlace(); | ||||
| 
 | ||||
|     out->mutable_data<T>(place); | ||||
| 
 | ||||
|     auto stream = | ||||
|         ctx.template device_context<paddle::platform::NPUDeviceContext>() | ||||
|             .stream(); | ||||
| 
 | ||||
|     auto runner = NpuOpRunner("Gelu", {*x}, {*out}, {}); | ||||
|     runner.Run(stream); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename DeviceContext, typename T> | ||||
| class GeluGradNPUKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& ctx) const override { | ||||
|     auto* x = ctx.Input<Tensor>("X"); | ||||
|     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||||
| 
 | ||||
|     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||||
| 
 | ||||
|     auto place = ctx.GetPlace(); | ||||
| 
 | ||||
|     dx->mutable_data<T>(place); | ||||
| 
 | ||||
|     auto stream = | ||||
|         ctx.template device_context<paddle::platform::NPUDeviceContext>() | ||||
|             .stream(); | ||||
| 
 | ||||
|     Tensor out(x->type()); | ||||
|     out.mutable_data<T>(x->dims(), place); | ||||
|     auto out_runner = NpuOpRunner("Gelu", {*x}, {out}, {}); | ||||
|     out_runner.Run(stream); | ||||
| 
 | ||||
|     auto dx_runner = NpuOpRunner("GeluGrad", {*dout, *x, out}, {*dx}, {}); | ||||
|     dx_runner.Run(stream); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| 
 | ||||
| REGISTER_OP_NPU_KERNEL( | ||||
|     gelu, | ||||
|     ops::GeluNPUKernel<paddle::platform::NPUDeviceContext, float>, | ||||
|     ops::GeluNPUKernel<paddle::platform::NPUDeviceContext, | ||||
|     paddle::platform::float16>); | ||||
| 
 | ||||
| REGISTER_OP_NPU_KERNEL( | ||||
|     gelu_grad, | ||||
|     ops::GeluGradNPUKernel<paddle::platform::NPUDeviceContext, float>, | ||||
|     ops::GeluGradNPUKernel<paddle::platform::NPUDeviceContext, | ||||
|     paddle::platform::float16>); | ||||
| @ -0,0 +1,169 @@ | ||||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
#ifndef _WIN32
#include <unistd.h>
#endif

#include <chrono>  // NOLINT
#include <cstdio>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
| 
 | ||||
| namespace f = paddle::framework; | ||||
| namespace p = paddle::platform; | ||||
| namespace m = paddle::operators::math; | ||||
| 
 | ||||
| USE_OP(gelu); | ||||
| USE_OP_DEVICE_KERNEL(gelu, NPU); | ||||
| 
 | ||||
| template <typename T> | ||||
| void Compare(f::Scope* scope, const p::DeviceContext& ctx) { | ||||
|   // init
 | ||||
|   auto x = scope->Var("X"); | ||||
|   auto tensor_x = x->GetMutable<f::LoDTensor>(); | ||||
| 
 | ||||
|   std::vector<T> init_x; | ||||
|   for (int64_t i = 0; i < 10 * 10; ++i) { | ||||
|     init_x.push_back(static_cast<T>(1.0)); | ||||
|   } | ||||
| 
 | ||||
|   TensorFromVector(init_x, ctx, tensor_x); | ||||
|   tensor_x->Resize({10, 10}); | ||||
| 
 | ||||
|   auto out = scope->Var("Out"); | ||||
|   auto tensor_out = out->GetMutable<f::LoDTensor>(); | ||||
| 
 | ||||
|   f::AttributeMap attrs; | ||||
| 
 | ||||
|   ctx.Wait(); | ||||
| 
 | ||||
|   // run
 | ||||
|   auto place = ctx.GetPlace(); | ||||
| 
 | ||||
|   auto op = f::OpRegistry::CreateOp("gelu", {{"X", {"X"}}}, | ||||
|                                     {{"Out", {"Out"}}}, attrs); | ||||
|   op->Run(*scope, place); | ||||
| 
 | ||||
|   ctx.Wait(); | ||||
| 
 | ||||
|   // eval time
 | ||||
|   struct timeval start, end; | ||||
|   gettimeofday(&start, NULL); | ||||
| 
 | ||||
|   for (int i = 0; i < 100; i++) { | ||||
|     op->Run(*scope, place); | ||||
|   } | ||||
| 
 | ||||
|   ctx.Wait(); | ||||
| 
 | ||||
|   gettimeofday(&end, NULL); | ||||
|   int micros = (((end.tv_sec - start.tv_sec) * 1000000) + | ||||
|                   end.tv_usec) - (start.tv_usec); | ||||
|   printf("used time: %d\n", micros / 100); | ||||
| 
 | ||||
|   // eval value
 | ||||
|   std::vector<T> out_vec; | ||||
|   TensorToVector(*tensor_out, ctx, &out_vec); | ||||
| 
 | ||||
|   float expected = 0.841192; | ||||
|   for (uint32_t i = 0; i < out_vec.size(); i++) { | ||||
|     EXPECT_FLOAT_EQ(out_vec[i], static_cast<T>(expected)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) { | ||||
|   auto dout = scope->Var("DOut"); | ||||
|   auto tensor_dout = dout->GetMutable<f::LoDTensor>(); | ||||
| 
 | ||||
|   auto x = scope->Var("X"); | ||||
|   auto tensor_x = x->GetMutable<f::LoDTensor>(); | ||||
| 
 | ||||
|   std::vector<T> init_dout; | ||||
|   for (int64_t i = 0; i < 10 * 10; ++i) { | ||||
|     init_dout.push_back(static_cast<T>(1.0)); | ||||
|   } | ||||
| 
 | ||||
|   std::vector<T> init_x; | ||||
|   for (int64_t i = 0; i < 10 * 10; ++i) { | ||||
|     init_x.push_back(static_cast<T>(1.0)); | ||||
|   } | ||||
| 
 | ||||
|   TensorFromVector(init_dout, ctx, tensor_dout); | ||||
|   tensor_dout->Resize({10, 10}); | ||||
|   TensorFromVector(init_x, ctx, tensor_x); | ||||
|   tensor_x->Resize({10, 10}); | ||||
| 
 | ||||
|   auto dx = scope->Var("DX"); | ||||
|   auto tensor_dx = dx->GetMutable<f::LoDTensor>(); | ||||
| 
 | ||||
|   f::AttributeMap attrs; | ||||
| 
 | ||||
|   ctx.Wait(); | ||||
| 
 | ||||
|   // run
 | ||||
|   auto place = ctx.GetPlace(); | ||||
| 
 | ||||
|   auto op = f::OpRegistry::CreateOp("gelu_grad", | ||||
|     {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}}, | ||||
|     {{"X@GRAD", {"DX"}}}, attrs); | ||||
|   op->Run(*scope, place); | ||||
| 
 | ||||
|   ctx.Wait(); | ||||
| 
 | ||||
|   // eval time
 | ||||
|   struct timeval start, end; | ||||
|   gettimeofday(&start, NULL); | ||||
| 
 | ||||
|   for (int i = 0; i < 100; i++) { | ||||
|     op->Run(*scope, place); | ||||
|   } | ||||
| 
 | ||||
|   ctx.Wait(); | ||||
| 
 | ||||
|   gettimeofday(&end, NULL); | ||||
|   int micros = (((end.tv_sec - start.tv_sec) * 1000000) + | ||||
|                   end.tv_usec) - (start.tv_usec); | ||||
|   printf("used time: %d\n", micros / 100); | ||||
| 
 | ||||
|   // eval value
 | ||||
|   std::vector<T> dx_vec; | ||||
|   TensorToVector(*tensor_dx, ctx, &dx_vec); | ||||
| 
 | ||||
|   float expected = 1.082964; | ||||
|   for (uint32_t i = 0; i < dx_vec.size(); i++) { | ||||
|     EXPECT_FLOAT_EQ(dx_vec[i], static_cast<T>(expected)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| TEST(gelu, NPU_fp32) { | ||||
|     f::Scope scope; | ||||
|     p::NPUDeviceContext ctx(p::NPUPlace(0)); | ||||
|     Compare<float>(&scope, ctx); | ||||
| } | ||||
| 
 | ||||
| TEST(gelu_grad, NPU) { | ||||
|     f::Scope scope; | ||||
|     p::NPUDeviceContext ctx(p::NPUPlace(0)); | ||||
|     CompareGrad<float>(&scope, ctx); | ||||
| } | ||||
| 
 | ||||
| @ -0,0 +1,160 @@ | ||||
| #  Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #    http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import numpy as np | ||||
| from scipy import special | ||||
| import unittest | ||||
| import sys | ||||
| sys.path.append("..") | ||||
| from op_test import OpTest | ||||
| import paddle | ||||
| import paddle.fluid as fluid | ||||
| 
 | ||||
# Seed shared by every test in this module (numpy and program random seeds).
SEED = 2021

# The tests below build static-graph programs.
paddle.enable_static()
| 
 | ||||
| 
 | ||||
def np_gelu(x):
    """Reference gelu in its erf form: 0.5 * x * (1 + erf(x / sqrt(2))).

    Args:
        x: numpy array (or scalar) of input values.

    Returns:
        The element-wise gelu of ``x``.
    """
    erf_term = special.erf(x / np.sqrt(2))
    return 0.5 * x * (1 + erf_term)
| 
 | ||||
| 
 | ||||
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestGelu(OpTest):
    """Output check of the gelu op on NPUPlace(0) with float32 inputs.

    The input is an 11x17 array drawn uniformly from [1, 2) with a fixed
    seed; the expected output comes from ``np_gelu`` and is compared with an
    absolute tolerance of 1e-3.
    """

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def setUp(self):
        self.set_npu()
        self.op_type = "gelu"
        self.place = paddle.NPUPlace(0)
        self.init_dtype()

        np.random.seed(SEED)
        x_np = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        expected = np_gelu(x_np)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x_np)}
        self.attrs = {}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output_with_place(
            self.place, check_dygraph=False, atol=1e-3)

    # TODO(ascendrc): Add grad test
    # def test_check_grad(self):
    #     if self.dtype == np.float16:
    #         return
    #     self.check_grad(['X'], 'Out')
    #
| 
 | ||||
| 
 | ||||
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestGeluFp16(OpTest):
    """Output check of the gelu op on NPUPlace(0) with float16 inputs.

    The input is a 3x4 array drawn uniformly from [1, 2) with a fixed seed;
    the expected output comes from ``np_gelu`` and is compared with an
    absolute tolerance of 1e-3. The class is flagged as not needing a
    gradient check.
    """

    def set_npu(self):
        self.__class__.use_npu = True
        self.__class__.no_need_check_grad = True

    def init_dtype(self):
        self.dtype = np.float16

    def setUp(self):
        self.set_npu()
        self.op_type = "gelu"
        self.place = paddle.NPUPlace(0)
        self.init_dtype()

        np.random.seed(SEED)
        x_np = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
        expected = np_gelu(x_np)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x_np)}
        self.attrs = {}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output_with_place(
            self.place, check_dygraph=False, atol=1e-3)
| 
 | ||||
| 
 | ||||
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestGeluNet(unittest.TestCase):
    """Trains a small network containing gelu on CPU and on NPU.

    The network is multiply -> gelu -> fc(128) -> fc(2, softmax) ->
    cross-entropy -> mean, optimized with SGD for 100 steps on fixed data.
    The final prediction and loss of both runs must agree within 1e-3.
    """

    def _test(self, run_npu=True):
        # Build the network in fresh programs with fixed seeds so that the
        # CPU and the NPU run start from the same state.
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            product = paddle.multiply(a, b)
            activated = fluid.layers.gelu(product)

            hidden = fluid.layers.fc(input=activated, size=128)
            prediction = fluid.layers.fc(input=hidden, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            optimizer = fluid.optimizer.SGD(learning_rate=0.01)
            optimizer.minimize(loss)

        place = paddle.NPUPlace(0) if run_npu else paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        feed = {"a": a_np, "b": b_np, "label": label_np}

        print("Start run on {}".format(place))
        for epoch in range(100):
            pred_res, loss_res = exe.run(
                main_prog, feed=feed, fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        # Results of the last training step.
        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(run_npu=False)
        npu_pred, npu_loss = self._test(run_npu=True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred, atol=1e-3))
        self.assertTrue(np.allclose(npu_loss, cpu_loss, atol=1e-3))
| 
 | ||||
| 
 | ||||
# Run every test case of this module when executed as a script.
if __name__ == "__main__":
    unittest.main()
					Loading…
					
					
				
		Reference in new issue