You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							110 lines
						
					
					
						
							3.4 KiB
						
					
					
				
			
		
		
	
	
							110 lines
						
					
					
						
							3.4 KiB
						
					
					
				| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | |
| 
 | |
|    Licensed under the Apache License, Version 2.0 (the "License");
 | |
|    you may not use this file except in compliance with the License.
 | |
|    You may obtain a copy of the License at
 | |
| 
 | |
|    http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
|    Unless required by applicable law or agreed to in writing, software
 | |
|    distributed under the License is distributed on an "AS IS" BASIS,
 | |
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|    See the License for the specific language governing permissions and
 | |
|    limitations under the License. */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "paddle/framework/grad_op_desc_maker.h"
 | |
| #include "paddle/framework/op_info.h"
 | |
| #include "paddle/framework/op_proto_maker.h"
 | |
| #include "paddle/framework/operator.h"
 | |
| 
 | |
| namespace paddle {
 | |
| namespace framework {
 | |
| namespace details {
 | |
| 
 | |
| enum OpInfoFillType {
 | |
|   kOperator = 0,
 | |
|   kOpProtoAndCheckerMaker = 1,
 | |
|   kGradOpDescMaker = 2
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| struct OpInfoFillTypeID {
 | |
|   static constexpr OpInfoFillType ID() {
 | |
|     return std::is_base_of<OperatorBase, T>::value
 | |
|                ? kOperator
 | |
|                : (std::is_base_of<OpProtoAndCheckerMaker, T>::value
 | |
|                       ? kOpProtoAndCheckerMaker
 | |
|                       : (std::is_base_of<GradOpDescMakerBase, T>::value
 | |
|                              ? kGradOpDescMaker
 | |
|                              : static_cast<OpInfoFillType>(-1)));
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
 | |
| struct OpInfoFiller;
 | |
| 
 | |
| template <size_t I, bool at_end, typename... ARGS>
 | |
| class OperatorRegistrarRecursive;
 | |
| 
 | |
| template <size_t I, typename... ARGS>
 | |
| class OperatorRegistrarRecursive<I, false, ARGS...> {
 | |
|  public:
 | |
|   using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
 | |
|   OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
 | |
|     OpInfoFiller<T> fill;
 | |
|     fill(op_type, info);
 | |
|     constexpr auto size = sizeof...(ARGS);
 | |
|     OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
 | |
|                                                                   info);
 | |
|     (void)(reg);
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <size_t I, typename... ARGS>
 | |
| class OperatorRegistrarRecursive<I, true, ARGS...> {
 | |
|  public:
 | |
|   OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| struct OpInfoFiller<T, kOperator> {
 | |
|   void operator()(const char* op_type, OpInfo* info) const {
 | |
|     info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
 | |
|                         const VariableNameMap& outputs,
 | |
|                         const AttributeMap& attrs) {
 | |
|       return new T(type, inputs, outputs, attrs);
 | |
|     };
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
 | |
|   void operator()(const char* op_type, OpInfo* info) const {
 | |
|     info->proto_ = new OpProto;
 | |
|     info->checker_ = new OpAttrChecker();
 | |
|     auto maker = T(info->proto_, info->checker_);
 | |
|     maker.Validate();
 | |
|     info->proto_->set_type(op_type);
 | |
|     PADDLE_ENFORCE(
 | |
|         info->proto_->IsInitialized(),
 | |
|         "Fail to initialize %s's OpProto, because %s is not initialized",
 | |
|         op_type, info->proto_->InitializationErrorString());
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| struct OpInfoFiller<T, kGradOpDescMaker> {
 | |
|   void operator()(const char* op_type, OpInfo* info) const {
 | |
|     info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
 | |
|       T maker(fwd_op);
 | |
|       return maker();
 | |
|     };
 | |
|   }
 | |
| };
 | |
| }  // namespace details
 | |
| 
 | |
| }  // namespace framework
 | |
| }  // namespace paddle
 |