You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							95 lines
						
					
					
						
							2.9 KiB
						
					
					
				
			
		
		
	
	
							95 lines
						
					
					
						
							2.9 KiB
						
					
					
				| // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| #pragma once
 | |
| #include <vector>
 | |
| #include "paddle/fluid/framework/op_registry.h"
 | |
| 
 | |
| namespace paddle {
 | |
| namespace operators {
 | |
| 
 | |
| using Tensor = framework::Tensor;
 | |
| using LoDTensor = framework::LoDTensor;
 | |
| using DDim = framework::DDim;
 | |
| 
 | |
| template <typename DeviceContext, typename T>
 | |
| class MaskedSelectKernel : public framework::OpKernel<T> {
 | |
|  public:
 | |
|   void Compute(const framework::ExecutionContext& context) const override {
 | |
|     auto input = context.Input<framework::Tensor>("X");
 | |
|     auto mask = context.Input<framework::Tensor>("Mask");
 | |
|     auto out = context.Output<framework::Tensor>("Y");
 | |
|     auto* mask_data = mask->data<bool>();
 | |
|     auto input_data = input->data<T>();
 | |
| 
 | |
|     auto mask_size = mask->numel();
 | |
| 
 | |
|     auto input_dim = input->dims();
 | |
|     auto mask_dim = mask->dims();
 | |
|     PADDLE_ENFORCE_EQ(
 | |
|         input_dim, mask_dim,
 | |
|         platform::errors::InvalidArgument(
 | |
|             "The dim size of input and mask in OP(masked_selected) "
 | |
|             "must be equal, but got input dim:(%ld), mask dim: "
 | |
|             "(%ld). Please check input "
 | |
|             "value.",
 | |
|             input_dim, mask_dim));
 | |
| 
 | |
|     int out_size = 0;
 | |
|     for (int i = 0; i < mask_size; i++) {
 | |
|       if (mask_data[i]) out_size++;
 | |
|     }
 | |
| 
 | |
|     framework::DDim out_dim{out_size};
 | |
|     out->Resize(out_dim);
 | |
|     auto out_data = out->mutable_data<T>(context.GetPlace());
 | |
| 
 | |
|     int index = 0;
 | |
|     for (int i = 0; i < mask_size; i++) {
 | |
|       if (mask_data[i]) {
 | |
|         out_data[index] = input_data[i];
 | |
|         index++;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename DeviceContext, typename T>
 | |
| class MaskedSelectGradKernel : public framework::OpKernel<T> {
 | |
|  public:
 | |
|   void Compute(const framework::ExecutionContext& context) const override {
 | |
|     auto out = context.Output<framework::Tensor>(framework::GradVarName("X"));
 | |
|     auto mask = context.Input<framework::Tensor>("Mask");
 | |
|     auto input = context.Input<framework::Tensor>(framework::GradVarName("Y"));
 | |
| 
 | |
|     auto* mask_data = mask->data<bool>();
 | |
|     auto* input_data = input->data<T>();
 | |
|     auto* out_data = out->mutable_data<T>(context.GetPlace());
 | |
|     int mask_size = mask->numel();
 | |
| 
 | |
|     int index = 0;
 | |
|     for (int i = 0; i < mask_size; i++) {
 | |
|       if (mask_data[i]) {
 | |
|         out_data[i] = input_data[index];
 | |
|         index++;
 | |
|       } else {
 | |
|         out_data[i] = 0;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| };
 | |
| 
 | |
| }  // namespace operators
 | |
| }  // namespace paddle
 |