parent 17ec3ab23e
commit 3e1676fa9a
paddle/fluid/operators/meshgrid_op.cc
@@ -0,0 +1,157 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/meshgrid_op.h"

#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class MeshgridOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_GE(
        ctx->Inputs("X").size(), 1UL,
        platform::errors::InvalidArgument("Input(X) should not be empty."));
    PADDLE_ENFORCE_GE(
        ctx->Outputs("Out").size(), 1UL,
        platform::errors::InvalidArgument("Output(Out) should not be empty."));

    auto inputs_dims = ctx->GetInputsDim("X");
    const size_t inputs_num = inputs_dims.size();
    auto outs_names = ctx->Outputs("Out");
    const size_t outputs_num = outs_names.size();

    // Every output has one dimension per input; dimension i is the length
    // of the i-th 1-D input.
    auto out_shape = std::vector<int>(inputs_num);
    for (size_t i = 0; i < inputs_num; i++) {
      out_shape[i] = inputs_dims[i][0];
    }
    auto out_dims = framework::make_ddim(out_shape);
    std::vector<framework::DDim> outs_dims(outputs_num, out_dims);
    ctx->SetOutputsDim("Out", outs_dims);
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Pick the kernel data type from the first initialized, non-empty input.
    auto inputs = ctx.MultiInput<Tensor>("X");
    auto input_data_type = framework::proto::VarType::Type(0);
    bool found = false;
    for (auto* input : inputs) {
      if (input->IsInitialized() && input->numel() > 0) {
        input_data_type = input->type();
        found = true;
        break;
      }
    }
    if (!found) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "All inputs of the meshgrid op are empty."));
    }

    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class MeshgridOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor, default Tensor<float>).").AsDuplicable();
    AddOutput("Out", "(Tensor, default Tensor<float>).").AsDuplicable();

    AddComment(R"DOC(
Meshgrid Operator.

Takes N tensors, each of which is either a scalar or a 1-dimensional vector,
and creates N N-dimensional grids.

Args:
    tensors (list of tensor): if the k input tensors have shapes (N1,), (N2,), ..., (Nk,),
        then each output tensor has shape (N1, N2, ..., Nk).

Example::

    >>> x = fluid.data(name='x', shape=[10], dtype='float64')
    >>> y = fluid.data(name='y', shape=[20], dtype='float64')
    >>> grid_x, grid_y = fluid.layers.meshgrid([x, y])
    >>> grid_x.shape
    (10, 20)
    >>> grid_y.shape
    (10, 20)
)DOC");
  }
};

class MeshgridGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Out")).size(), 1,
                      platform::errors::InvalidArgument(
                          "Number of Inputs(Out@Grad) must be larger than 1."));
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

template <typename T>
class MeshgridGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("meshgrid_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(meshgrid, ops::MeshgridOp, ops::MeshgridOpMaker,
                  ops::MeshgridGradOpMaker<paddle::framework::OpDesc>,
                  ops::MeshgridGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(meshgrid_grad, ops::MeshgridGradOp);
REGISTER_OP_CPU_KERNEL(
    meshgrid, ops::MeshgridKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MeshgridKernel<paddle::platform::CPUDeviceContext, double>,
    ops::MeshgridKernel<paddle::platform::CPUDeviceContext, int>,
    ops::MeshgridKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(
    meshgrid_grad,
    ops::MeshgridGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MeshgridGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::MeshgridGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::MeshgridGradKernel<paddle::platform::CPUDeviceContext, double>);
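
Aside: the broadcast rule these kernels implement matches the NumPy reference
computation used by the unit test at the end of this diff. A minimal sketch
(the helper name meshgrid_ref is hypothetical, not part of the patch):

    import numpy as np

    def meshgrid_ref(*arrays):
        # Input i is viewed as (1, ..., Ni, ..., 1), then broadcast so that
        # every output has the full grid shape (N1, ..., Nk).
        shape = [a.size for a in arrays]
        outs = []
        for i, a in enumerate(arrays):
            view = [1] * len(arrays)
            view[i] = shape[i]
            outs.append(np.broadcast_to(a.reshape(view), shape))
        return outs

    # Two 1-D inputs of lengths 10 and 20 give two (10, 20) grids.
    gx, gy = meshgrid_ref(np.arange(10.0), np.arange(20.0))
    assert gx.shape == (10, 20) and gy.shape == (10, 20)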

paddle/fluid/operators/meshgrid_op.cu
@@ -0,0 +1,29 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/meshgrid_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    meshgrid, ops::MeshgridKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MeshgridKernel<paddle::platform::CUDADeviceContext, double>,
    ops::MeshgridKernel<paddle::platform::CUDADeviceContext, int>,
    ops::MeshgridKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::MeshgridKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
    meshgrid_grad,
    ops::MeshgridGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MeshgridGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::MeshgridGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::MeshgridGradKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/meshgrid_op.h
@@ -0,0 +1,198 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"

#define MAX_RANK_SUPPORTED 6

#define MESHGRID_TEMPLATE(z, n, data) \
  case n + 1: {                       \
    MeshgridForward<n + 1>(context);  \
    break;                            \
  }
#define REP_MESHGRID_TEMPLATE(n) BOOST_PP_REPEAT(n, MESHGRID_TEMPLATE, ~)
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))

#define MESHGRID_GRAD_CASE(n)     \
  case n: {                       \
    MeshgridBackward<n>(context); \
    break;                        \
  }
#define MESHGRID_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), MESHGRID_GRAD_CASE(n), )
#define REP_MESHGRID_GRAD_TEMPLATE(n) \
  BOOST_PP_REPEAT(n, MESHGRID_GRAD_TEMPLATE, ~)
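// Note: REP_MESHGRID_TEMPLATE(MAX_RANK_SUPPORTED) stamps out
//   case 1: { MeshgridForward<1>(context); break; }
//   ...
//   case 6: { MeshgridForward<6>(context); break; }
// turning the runtime tensor count into a compile-time Rank. The grad
// repetition emits case n directly (COND(n) is always true for n in
// 0..5), i.e. cases 0 through 5.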

namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class MeshgridKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto ins = context.MultiInput<framework::Tensor>("X");
    auto rank = ins.size();
    switch (rank) {
      REP_MESHGRID_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only 1 to 6 input tensors are supported."));
    }
  }

 protected:
  template <int Rank>
  void MeshgridForward(const framework::ExecutionContext& context) const {
    auto ins = context.MultiInput<framework::Tensor>("X");
    auto outs = context.MultiOutput<framework::Tensor>("Out");
    PADDLE_ENFORCE_GT(ins.size(), 1UL,
                      platform::errors::InvalidArgument(
                          "Expected at least 2 input tensors."));

    int64_t size = ins.size();
    std::vector<int64_t> shape(size);

    // A scalar input contributes a dimension of length 1; a 1-D input
    // contributes its length. Anything else is rejected.
    for (int64_t i = 0; i < size; i++) {
      switch (ins[i]->dims().size()) {
        case 0:
          shape[i] = 1;
          break;
        case 1:
          shape[i] = ins[i]->dims()[0];
          break;
        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Expected a scalar or 1-D tensor in the input list, but got "
              "tensor %d.",
              i));
      }
    }

    for (int64_t i = 0; i < size; i++) {
      // View input i as (1, ..., shape[i], ..., 1), then broadcast it along
      // every other axis to the full grid shape.
      std::vector<int64_t> view_shape(size, 1);
      view_shape[i] = shape[i];

      framework::Tensor reshape_ins_tensor;
      TensorCopy(*ins[i], context.GetPlace(), context.device_context(),
                 &reshape_ins_tensor);
      framework::DDim out_dims_reshape = framework::make_ddim(view_shape);
      reshape_ins_tensor.Resize(out_dims_reshape);
      framework::DDim out_dims = framework::make_ddim(shape);

      Eigen::DSizes<int, Rank> bcast_dims;
      for (int64_t j = 0; j < size; j++) {
        bcast_dims[j] = shape[j];
      }
      bcast_dims[i] = 1;

      outs[i]->Resize(out_dims);
      auto x = EigenTensor<T, Rank>::From(reshape_ins_tensor);
      outs[i]->mutable_data<T>(context.GetPlace());
      auto y = EigenTensor<T, Rank>::From(*outs[i]);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      y.device(place) = x.broadcast(bcast_dims);
    }
  }
};

template <typename DeviceContext, typename T>
class MeshgridGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto out_grad =
        context.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
    int n = out_grad.size();
    switch (n) {
      REP_MESHGRID_GRAD_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only 1 to 6 input tensors are supported."));
    }
  }

 protected:
  template <int Rank>
  void MeshgridBackward(const framework::ExecutionContext& context) const {
    auto out_grad =
        context.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
    auto ins = context.MultiInput<framework::Tensor>("X");
    auto outs =
        context.MultiOutput<framework::Tensor>(framework::GradVarName("X"));

    int n = out_grad.size();
    auto out_dims = out_grad[0]->dims();

    for (int i = 0; i < n; i++) {
      outs[i]->mutable_data<T>(context.GetPlace());
      auto out_grad_tmp = EigenVector<T>::Flatten(*out_grad[i]);
      auto in_grad = EigenVector<T>::Flatten(*outs[i]);

      // The gradient of input i is out_grad[i] summed over every axis
      // j != i. Each axis j is expanded into an (even, odd) pair of dims:
      // (1, out_dims[j]) when j == i, (out_dims[j], 1) otherwise. Reducing
      // the even positions then sums exactly the axes j != i.
      std::vector<int> reduce_dims_vec;
      std::vector<int> reshape_dims_vec;
      for (int j = 0; j < n; j++) {
        reduce_dims_vec.push_back(reshape_dims_vec.size());
        if (j == i) {
          reshape_dims_vec.push_back(1);
          reshape_dims_vec.push_back(out_dims[j]);
        } else {
          reshape_dims_vec.push_back(out_dims[j]);
          reshape_dims_vec.push_back(1);
        }
      }

      Eigen::DSizes<int, Rank> reduce_dims;
      for (int k = 0; k < n; k++) {
        reduce_dims[k] = reduce_dims_vec[k];
      }

      Eigen::DSizes<int, Rank * 2> reshape_dims;
      for (int k = 0; k < n * 2; k++) {
        reshape_dims[k] = reshape_dims_vec[k];
      }

      auto tensor_reduce_tmp =
          out_grad_tmp.reshape(reshape_dims).sum(reduce_dims);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      in_grad.device(place) = tensor_reduce_tmp.reshape(in_grad.dimensions());
    }
  }
};

}  // namespace operators
}  // namespace paddle
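
Aside: the Eigen reshape-and-reduce above is equivalent to summing each
output gradient over all axes other than its own, since the forward pass
broadcast input i along exactly those axes. A minimal NumPy sketch (the
helper name meshgrid_grad_ref is hypothetical):

    import numpy as np

    def meshgrid_grad_ref(out_grads):
        # d(input i) = sum of out_grads[i] over every axis j != i.
        n = len(out_grads)
        return [og.sum(axis=tuple(j for j in range(n) if j != i))
                for i, og in enumerate(out_grads)]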

python/paddle/fluid/tests/unittests/test_meshgrid_op.py
@@ -0,0 +1,118 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
import paddle
from paddle.fluid import compiler, Program, program_guard, core


class TestMeshgridOp(OpTest):
    def setUp(self):
        self.op_type = "meshgrid"
        self.dtype = self.get_dtype()
        ins, outs = self.init_test_data()
        self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
        self.outputs = {
            'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
        }

    def get_dtype(self):
        return "float64"

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['x0'], ['out0'])
        self.check_grad(['x1'], ['out1'])

    def init_test_data(self):
        self.shape = self.get_x_shape()
        ins = []
        outs = []
        for i in range(len(self.shape)):
            ins.append(np.random.random((self.shape[i], )).astype(self.dtype))

        for i in range(len(self.shape)):
            out_reshape = [1] * len(self.shape)
            out_reshape[i] = self.shape[i]
            out_temp = np.reshape(ins[i], out_reshape)
            outs.append(np.broadcast_to(out_temp, self.shape))
        return ins, outs

    def get_x_shape(self):
        return [100, 200]


class TestMeshgridOp2(TestMeshgridOp):
    def get_x_shape(self):
        return [100, 300]


class TestMeshgridOp3(unittest.TestCase):
    def test_api(self):
        x = fluid.data(shape=[100], dtype='int32', name='x')
        y = fluid.data(shape=[200], dtype='int32', name='y')

        input_1 = np.random.randint(0, 100, [100, ]).astype('int32')
        input_2 = np.random.randint(0, 100, [200, ]).astype('int32')

        out_1 = np.reshape(input_1, [100, 1])
        out_1 = np.broadcast_to(out_1, [100, 200])
        out_2 = np.reshape(input_2, [1, 200])
        out_2 = np.broadcast_to(out_2, [100, 200])

        exe = fluid.Executor(place=fluid.CPUPlace())
        grid_x, grid_y = paddle.tensor.meshgrid([x, y])
        res_1, res_2 = exe.run(fluid.default_main_program(),
                               feed={'x': input_1,
                                     'y': input_2},
                               fetch_list=[grid_x, grid_y])

        assert np.array_equal(res_1, out_1)
        assert np.array_equal(res_2, out_2)


class TestMeshgridOp4(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):

            def test_input_type():
                x = fluid.data(shape=[200], dtype='float32', name='x2')
                paddle.tensor.meshgrid(x)

            self.assertRaises(TypeError, test_input_type)


class TestMeshgridOp5(unittest.TestCase):
    def test_api_with_dygraph(self):
        input_3 = np.random.randint(0, 100, [100, ]).astype('int32')
        input_4 = np.random.randint(0, 100, [200, ]).astype('int32')

        with fluid.dygraph.guard():
            tensor_3 = fluid.dygraph.to_variable(input_3)
            tensor_4 = fluid.dygraph.to_variable(input_4)
            res_3, res_4 = paddle.tensor.meshgrid([tensor_3, tensor_4])

        assert np.array_equal(res_3.shape, [100, 200])
        assert np.array_equal(res_4.shape, [100, 200])


if __name__ == '__main__':
    unittest.main()