parent
							
								
									e81f0228df
								
							
						
					
					
						commit
						e2d849b989
					
				| @ -0,0 +1,57 @@ | ||||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include "paddle/fluid/operators/seed_op.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| using Tensor = framework::Tensor; | ||||
| class SeedOp : public framework::OperatorWithKernel { | ||||
|  public: | ||||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||||
| 
 | ||||
|   void InferShape(framework::InferShapeContext* ctx) const override { | ||||
|     ctx->SetOutputDim("Out", {1}); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   framework::OpKernelType GetExpectedKernelType( | ||||
|       const framework::ExecutionContext& ctx) const override { | ||||
|     return framework::OpKernelType(framework::proto::VarType::INT32, | ||||
|                                    platform::CPUPlace()); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class SeedOpMaker : public framework::OpProtoAndCheckerMaker { | ||||
|  public: | ||||
|   void Make() override { | ||||
|     AddOutput("Out", "The output of seed op."); | ||||
|     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); | ||||
|     AddComment(R"DOC( | ||||
| Seed Operator. | ||||
| )DOC"); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| REGISTER_OPERATOR( | ||||
|     seed, ops::SeedOp, ops::SeedOpMaker, | ||||
|     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||||
|     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); | ||||
| REGISTER_OP_CPU_KERNEL( | ||||
|     seed, ops::CPUSeedKernel<paddle::platform::CPUDeviceContext, int>); | ||||
| @ -0,0 +1,44 @@ | ||||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include "paddle/fluid/framework/op_registry.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| using Tensor = framework::Tensor; | ||||
| 
 | ||||
| template <typename DeviceContext, typename T> | ||||
| class CPUSeedKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& context) const override { | ||||
|     auto* out = context.Output<Tensor>("Out"); | ||||
|     auto* out_data = out->mutable_data<T>(context.GetPlace()); | ||||
|     int user_seed = context.Attr<int>("seed"); | ||||
| 
 | ||||
|     // NOTE: fixed seed should only be used in unittest or for debug.
 | ||||
|     // Guarantee to use random seed in training.
 | ||||
|     std::random_device rnd; | ||||
|     int seed; | ||||
|     if (user_seed != 0) { | ||||
|       seed = user_seed; | ||||
|     } else { | ||||
|       seed = rnd(); | ||||
|     } | ||||
|     out_data[0] = seed; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,46 @@ | ||||
| #   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import unittest | ||||
| import numpy as np | ||||
| from op_test import OpTest | ||||
| import paddle.fluid as fluid | ||||
| 
 | ||||
| 
 | ||||
| class TestSeedOpFixSeed(OpTest): | ||||
|     def setUp(self): | ||||
|         self.op_type = "seed" | ||||
|         self.inputs = {} | ||||
|         self.attrs = {"seed": 123} | ||||
|         self.outputs = {"Out": np.asarray((123)).astype('int32')} | ||||
| 
 | ||||
|     def test_check_output(self): | ||||
|         self.check_output() | ||||
| 
 | ||||
| 
 | ||||
| class TestSeedOpDiffSeed(OpTest): | ||||
|     def setUp(self): | ||||
|         self.op_type = "seed" | ||||
|         self.inputs = {} | ||||
|         self.attrs = {"seed": 0} | ||||
|         self.outputs = {"Out": np.asarray((123)).astype('int32')} | ||||
| 
 | ||||
|     def test_check_output(self): | ||||
|         self.check_output(no_check_set=["Out"]) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
					Loading…
					
					
				
		Reference in new issue