Add Where Op(#16793)
parent
1bfff02047
commit
d4b67e1692
@ -0,0 +1,58 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/where_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class WhereOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Condition"),
|
||||
"Input(Condition) of WhereOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->GetInputDim("Condition").size() >= 1,
|
||||
"Input(Condition) should have number of dimension at least 1");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(OUt) of WhereOp should not be null.");
|
||||
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto output_type = framework::proto::VarType::INT64;
|
||||
return framework::OpKernelType(output_type, ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Condition", "A bool tensor whose rank is at least 1");
|
||||
AddOutput("Out", "An int64 tensor of rank 2");
|
||||
AddComment(R"DOC(
|
||||
Return a int64 tensor with rank 2, specifying the coordinate of true element in `Condition`.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(where, ops::WhereOp, ops::WhereOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(where, ops::CPUWhereKernel<int64_t>);
|
@ -0,0 +1,81 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/where_op.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/for_range.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
|
||||
|
||||
template <typename T>
|
||||
class CUDAWhereKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* condition = context.Input<framework::Tensor>("Condition");
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
|
||||
// TODO(zhoukunsheng): Should optimize to ensure GPU is faster than CPU.
|
||||
framework::Tensor cond_cpu;
|
||||
framework::TensorCopy(*condition, platform::CPUPlace(), &cond_cpu);
|
||||
|
||||
const bool* cond_data = cond_cpu.data<bool>();
|
||||
int64_t numel = cond_cpu.numel();
|
||||
auto dims = cond_cpu.dims();
|
||||
int rank = dims.size();
|
||||
|
||||
thrust::host_vector<int> h_true_index;
|
||||
for (int64_t i = 0; i < numel; i++) {
|
||||
if (cond_data[i]) {
|
||||
h_true_index.push_back(i);
|
||||
}
|
||||
}
|
||||
thrust::device_vector<int> d_true_index = h_true_index;
|
||||
int* ptr_true_index = thrust::raw_pointer_cast(d_true_index.data());
|
||||
|
||||
size_t true_num = h_true_index.size();
|
||||
|
||||
out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
|
||||
auto out_ptr = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
if (true_num == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
thrust::host_vector<int> h_stride(rank, 0);
|
||||
h_stride[rank - 1] = 1;
|
||||
for (int i = rank - 2; i >= 0; i--) {
|
||||
h_stride[i] = h_stride[i + 1] * dims[i + 1];
|
||||
}
|
||||
thrust::device_vector<int> d_stride = h_stride;
|
||||
int* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
|
||||
|
||||
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
|
||||
WhereFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
|
||||
out_ptr);
|
||||
platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
|
||||
for_range(functor);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(where, ops::CUDAWhereKernel<int64_t>);
|
@ -0,0 +1,95 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/platform/for_range.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
struct WhereFunctor {
|
||||
WhereFunctor(const T& true_index, int true_num, const T& stride, int rank,
|
||||
int64_t* out)
|
||||
: true_index_(true_index),
|
||||
true_num_(true_num),
|
||||
stride_(stride),
|
||||
rank_(rank),
|
||||
out_ptr_(out) {}
|
||||
|
||||
HOSTDEVICE void operator()(size_t idx) const {
|
||||
int index = true_index_[idx];
|
||||
for (int j = 0; j < rank_; j++) {
|
||||
out_ptr_[idx * rank_ + j] = index / stride_[j];
|
||||
index -= out_ptr_[idx * rank_ + j] * stride_[j];
|
||||
}
|
||||
}
|
||||
|
||||
const T true_index_;
|
||||
int true_num_;
|
||||
const T stride_;
|
||||
int rank_;
|
||||
int64_t* out_ptr_;
|
||||
};
|
||||
|
||||
using CPUDeviceContext = paddle::platform::CPUDeviceContext;
|
||||
|
||||
template <typename T>
|
||||
class CPUWhereKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* condition = context.Input<framework::Tensor>("Condition");
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
|
||||
const bool* cond_data = condition->data<bool>();
|
||||
auto numel = condition->numel();
|
||||
auto dims = condition->dims();
|
||||
const int rank = dims.size();
|
||||
|
||||
std::vector<int> true_index;
|
||||
for (auto i = 0; i < numel; i++) {
|
||||
if (cond_data[i]) {
|
||||
true_index.push_back(i);
|
||||
}
|
||||
}
|
||||
auto true_num = true_index.size();
|
||||
|
||||
out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
|
||||
auto out_ptr = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
if (true_num == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> stride(rank);
|
||||
stride[rank - 1] = 1;
|
||||
for (int i = rank - 2; i >= 0; i--) {
|
||||
stride[i] = stride[i + 1] * dims[i + 1];
|
||||
}
|
||||
|
||||
auto& dev_ctx = context.template device_context<CPUDeviceContext>();
|
||||
WhereFunctor<int*> functor(true_index.data(), true_num, stride.data(), rank,
|
||||
out_ptr);
|
||||
platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
|
||||
for_range(functor);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.op import Operator
|
||||
|
||||
|
||||
class TestWhereOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "where"
|
||||
self.init_config()
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def init_config(self):
|
||||
self.inputs = {'Condition': np.array([True, False, True]), }
|
||||
|
||||
self.outputs = {'Out': np.array([[0], [2]], dtype='int64')}
|
||||
|
||||
|
||||
class TestAllFalse(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.op_type = "where"
|
||||
self.init_config()
|
||||
|
||||
def check_with_place(self, place):
|
||||
scope = core.Scope()
|
||||
condition = scope.var('Condition').get_tensor()
|
||||
condition.set(self.cond_data, place)
|
||||
|
||||
out = scope.var("Out").get_tensor()
|
||||
out.set(np.full(self.shape, 0).astype('int64'), place)
|
||||
|
||||
op = Operator("where", Condition="Condition", Out="Out")
|
||||
op.run(scope, place)
|
||||
|
||||
out_array = np.array(out)
|
||||
self.assertTrue((out_array == self.out_data).all())
|
||||
|
||||
def init_config(self):
|
||||
self.cond_data = np.array([False, False, False])
|
||||
self.shape = (3, 1)
|
||||
self.out_data = np.array([], dtype='int64')
|
||||
|
||||
def test_all_false(self):
|
||||
self.check_with_place(core.CPUPlace())
|
||||
|
||||
if core.is_compiled_with_cuda():
|
||||
self.check_with_place(core.CUDAPlace(0))
|
||||
|
||||
|
||||
class TestRank2(TestWhereOp):
|
||||
def init_config(self):
|
||||
self.inputs = {'Condition': np.array([[True, False], [False, True]]), }
|
||||
|
||||
self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
|
||||
|
||||
|
||||
class TestRank3(TestWhereOp):
|
||||
def init_config(self):
|
||||
self.inputs = {
|
||||
'Condition': np.array([[[True, False], [False, True]],
|
||||
[[False, True], [True, False]],
|
||||
[[False, False], [False, True]]]),
|
||||
}
|
||||
|
||||
self.outputs = {
|
||||
'Out': np.array(
|
||||
[[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]],
|
||||
dtype='int64')
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue