// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"

#define MAX_RANK_SUPPORTED 6
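
// Eigen broadcast/reshape expressions need the tensor rank as a compile-time
// template argument, but the number of meshgrid inputs is only known at
// runtime. BOOST_PP_REPEAT expands MESHGRID_TEMPLATE for
// n = 0 .. MAX_RANK_SUPPORTED - 1, generating switch cases 1 .. 6 that each
// dispatch to the MeshgridForward<Rank> instantiation of matching rank.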
#define MESHGRID_TEMPLATE(z, n, data) \
  case n + 1: {                       \
    MeshgridForward<n + 1>(context);  \
    break;                            \
  }
#define REP_MESHGRID_TEMPLATE(n) BOOST_PP_REPEAT(n, MESHGRID_TEMPLATE, ~)
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
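
// Same trick for the gradient kernel. BOOST_PP_IF emits a case only when
// COND(n) holds; n >= n % MAX_RANK_SUPPORTED is true for every repetition
// index n, so the switch gains cases 0 .. MAX_RANK_SUPPORTED - 1, each
// dispatching to the matching MeshgridBackward<Rank> instantiation.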
#define MESHGRID_GRAD_CASE(n)     \
  case n: {                       \
    MeshgridBackward<n>(context); \
    break;                        \
  }
#define MESHGRID_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), MESHGRID_GRAD_CASE(n), )
#define REP_MESHGRID_GRAD_TEMPLATE(n) \
  BOOST_PP_REPEAT(n, MESHGRID_GRAD_TEMPLATE, ~)

namespace paddle {
namespace operators {
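
// Forward kernel: given k inputs that are scalars or 1-D tensors with
// lengths N_0 .. N_{k-1}, produce k outputs of shape (N_0, ..., N_{k-1});
// output i holds input i laid out along axis i and broadcast along every
// other axis. E.g. for lengths (3, 4): out0[a][b] = x0[a] and
// out1[a][b] = x1[b], the usual "ij"-indexed coordinate grid.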
template <typename DeviceContext, typename T>
class MeshgridKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto ins = context.MultiInput<framework::Tensor>("X");
    auto rank = ins.size();
    switch (rank) {
      REP_MESHGRID_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only support tensor numbers between 1 and 6."));
    }
  }

 protected:
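  // Rank equals the number of input tensors and is fixed at compile time by
  // the dispatch switch above, so Eigen::DSizes<int, Rank> can size the
  // broadcast descriptor statically.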
  template <int Rank>
  void MeshgridForward(const framework::ExecutionContext& context) const {
    auto ins = context.MultiInput<framework::Tensor>("X");
    auto outs = context.MultiOutput<framework::Tensor>("Out");
    PADDLE_ENFORCE_EQ(
        ins.size() > 1, true,
        platform::errors::InvalidArgument("Expect at least 2 input tensors."));

    int64_t size = ins.size();
    std::vector<int64_t> shape(size);
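
    // Determine the extent of each grid axis: a 0-D (scalar) input
    // contributes 1, a 1-D input contributes its length; anything of higher
    // rank is rejected.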
    for (int64_t i = 0; i < size; i++) {
      switch (ins[i]->dims().size()) {
        case 0:
          shape[i] = 1;
          break;
        case 1:
          shape[i] = ins[i]->dims()[0];
          break;
        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Expected scalar or 1D tensor in the tensor list but got tensor "
              "%d.",
              i));
      }
    }
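
    // For each input: copy it, view the copy as shape (1, ..., N_i, ..., 1)
    // with N_i on axis i, then let Eigen broadcast it along the remaining
    // axes to fill the full (N_0, ..., N_{k-1}) grid.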
    for (int64_t i = 0; i < size; i++) {
      std::vector<int64_t> view_shape(size, 1);
      view_shape[i] = shape[i];

      framework::Tensor reshape_ins_tensor;
      TensorCopy(*ins[i], context.GetPlace(), context.device_context(),
                 &reshape_ins_tensor);
      framework::DDim out_dims_reshape = framework::make_ddim(view_shape);
      reshape_ins_tensor.Resize(out_dims_reshape);
      framework::DDim out_dims = framework::make_ddim(shape);
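
      // bcast_dims[j] is the replication count along axis j for Eigen's
      // broadcast: every axis of the (1, ..., N_i, ..., 1) view is tiled up
      // to its full extent except axis i, which already holds its data.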
      Eigen::DSizes<int, Rank> bcast_dims;
      for (int64_t j = 0; j < size; j++) {
        bcast_dims[j] = shape[j];
      }
      bcast_dims[i] = 1;

      outs[i]->Resize(out_dims);
      auto x = framework::EigenTensor<T, Rank>::From(reshape_ins_tensor);
      outs[i]->mutable_data<T>(context.GetPlace());
      auto y = framework::EigenTensor<T, Rank>::From(*outs[i]);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      y.device(place) = x.broadcast(bcast_dims);
    }
  }
};
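
// Gradient kernel: the forward pass broadcasts input i along every axis
// except axis i, so the gradient w.r.t. input i is the incoming gradient
// out_grad[i] summed over all of those broadcast axes.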
template <typename DeviceContext, typename T>
class MeshgridGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto out_grad =
        context.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
    int n = out_grad.size();
    switch (n) {
      REP_MESHGRID_GRAD_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only support tensor numbers between 1 and 6."));
    }
  }

 protected:
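  // With Rank grid axes of extents (N_0, ..., N_{Rank-1}), the gradient for
  // input i must reduce out_grad[i] over every axis except axis i. The
  // flattened gradient is viewed with 2 * Rank dims, turning each original
  // axis j into a (reduce, keep) pair: (N_j, 1) for j != i and (1, N_j) for
  // j == i. Summing the even "reduce" positions then leaves N_i elements.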
  template <int Rank>
  void MeshgridBackward(const framework::ExecutionContext& context) const {
    auto out_grad =
        context.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
    auto ins = context.MultiInput<framework::Tensor>("X");
    auto outs =
        context.MultiOutput<framework::Tensor>(framework::GradVarName("X"));

    int n = out_grad.size();
    auto out_dims = out_grad[0]->dims();

    for (int i = 0; i < n; i++) {
      outs[i]->mutable_data<T>(context.GetPlace());
      auto out_grad_tmp = framework::EigenVector<T>::Flatten(*out_grad[i]);
      auto in_grad = framework::EigenVector<T>::Flatten(*outs[i]);
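
      // Build the interleaved view described above. reduce_dims_vec records
      // the even positions (0, 2, 4, ...) to sum over; reshape_dims_vec is
      // the 2 * Rank shape. E.g. Rank = 2, i = 0, out_dims = (3, 4): view
      // out_grad[0] as (1, 3, 4, 1) and sum over axes {0, 2}, leaving the
      // 3 elements of axis 0's gradient.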
      std::vector<int> reduce_dims_vec;
      std::vector<int> reshape_dims_vec;
      for (int j = 0; j < n; j++) {
        reduce_dims_vec.push_back(reshape_dims_vec.size());
        if (j == i) {
          reshape_dims_vec.push_back(1);
          reshape_dims_vec.push_back(out_dims[j]);
        } else {
          reshape_dims_vec.push_back(out_dims[j]);
          reshape_dims_vec.push_back(1);
        }
      }

      Eigen::DSizes<int, Rank> reduce_dims;
      for (int k = 0; k < n; k++) {
        reduce_dims[k] = reduce_dims_vec[k];
      }

      Eigen::DSizes<int, Rank * 2> reshape_dims;
      for (int k = 0; k < n * 2; k++) {
        reshape_dims[k] = reshape_dims_vec[k];
      }

      auto tensor_reduce_tmp =
          out_grad_tmp.reshape(reshape_dims).sum(reduce_dims);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      in_grad.device(place) = tensor_reduce_tmp.reshape(in_grad.dimensions());
    }
  }
};

}  // namespace operators
}  // namespace paddle