Implement a new C++ operator where and API tensor.where (#23220)
parent
9b82e4c183
commit
c068512f34
@ -0,0 +1,159 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/where_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class WhereOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
|
||||
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
|
||||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Where");
|
||||
|
||||
auto cond_dims = ctx->GetInputDim("Condition");
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto y_dims = ctx->GetInputDim("Y");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
cond_dims, x_dims,
|
||||
platform::errors::InvalidArgument(
|
||||
"The dims of Inputs(Condition) and Inputs(X) should be same. "
|
||||
"But received Condition's shape is [%s], X's shape is [%s]",
|
||||
cond_dims, x_dims));
|
||||
PADDLE_ENFORCE_EQ(x_dims, y_dims,
|
||||
platform::errors::InvalidArgument(
|
||||
"The dims of Inputs(X) and Inputs(Y) should be same. "
|
||||
"But received X's shape is [%s], Y's shape is [%s]",
|
||||
x_dims, y_dims));
|
||||
|
||||
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class WhereGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
|
||||
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
|
||||
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
||||
framework::GradVarName("Out"), "Where");
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto y_dims = ctx->GetInputDim("Y");
|
||||
|
||||
auto x_grad_name = framework::GradVarName("X");
|
||||
auto y_grad_name = framework::GradVarName("Y");
|
||||
|
||||
if (ctx->HasOutput(x_grad_name)) {
|
||||
ctx->SetOutputDim(x_grad_name, x_dims);
|
||||
}
|
||||
if (ctx->HasOutput(y_grad_name)) {
|
||||
ctx->SetOutputDim(y_grad_name, y_dims);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
||||
ctx, framework::GradVarName("Out")),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Condition",
|
||||
"(Tensor) A bool tensor whose rank is at least 1. When Condition "
|
||||
"is True, yield x, otherwise yield y");
|
||||
AddInput("X",
|
||||
"(Tensor), The first input tensor of where op. When the "
|
||||
"corresponding position of the condition is true, the output "
|
||||
"takes the element of X.");
|
||||
AddInput("Y",
|
||||
"(Tensor), The second input tensor of where op. When the "
|
||||
"corresponding position of condition is false, the output takes "
|
||||
"the element of Y.");
|
||||
AddOutput("Out", "(Tensor), The output tensor of mul op.");
|
||||
AddComment(R"DOC(
|
||||
Where Operator.
|
||||
Return a tensor of elements selected from either $X$ or $Y$, depending on condition.
|
||||
The equation is:
|
||||
$$
|
||||
Out_i =
|
||||
\begin{cases}
|
||||
\X_i, \quad \text{if} \ cond_i is True \\
|
||||
\Y_i, \quad \text{if} \ cond_i is False \\
|
||||
\end{cases}
|
||||
$$
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class WhereOpGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> grad) const override {
|
||||
grad->SetType("where_grad");
|
||||
grad->SetInput("Condition", this->Input("Condition"));
|
||||
grad->SetInput("X", this->Input("X"));
|
||||
grad->SetInput("Y", this->Input("Y"));
|
||||
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
||||
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
|
||||
}
|
||||
};
|
||||
|
||||
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInference, "X",
|
||||
"Y");
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker,
|
||||
ops::WhereOpGradMaker<paddle::framework::OpDesc>,
|
||||
ops::WhereOpGradMaker<paddle::imperative::OpBase>);
|
||||
|
||||
REGISTER_OPERATOR(where_grad, ops::WhereGradOp,
|
||||
ops::WhereGradNoNeedBufferVarsInference);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::WhereKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::WhereKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::WhereKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
where_grad, ops::WhereGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||
@ -0,0 +1,122 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/where_op.h"
|
||||
#include "paddle/fluid/platform/gpu_launch_param_config.h"
|
||||
|
||||
namespace platform = paddle::platform;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
|
||||
const T* y, T* out) {
|
||||
int idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
for (; idx < N; idx += blockDim.x * gridDim.x) {
|
||||
out[idx] = cond[idx] ? x[idx] : y[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void WhereGradCUDAKernel(const int N, const T* out, const bool* cond,
|
||||
T* x, T* y) {
|
||||
int idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
for (; idx < N; idx += blockDim.x * gridDim.x) {
|
||||
if (x != nullptr) {
|
||||
x[idx] = out[idx] * (cond[idx] ? 1. : 0.);
|
||||
}
|
||||
if (y != nullptr) {
|
||||
y[idx] = out[idx] * (cond[idx] ? 0. : 1.);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class WhereKernel<platform::CUDADeviceContext, T>
|
||||
: public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
platform::is_gpu_place(context.GetPlace()), true,
|
||||
platform::errors::PermissionDenied("It must use CUDAPlace."));
|
||||
auto* condition = context.Input<framework::Tensor>("Condition");
|
||||
auto* X = context.Input<framework::Tensor>("X");
|
||||
auto* Y = context.Input<framework::Tensor>("Y");
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
auto numel = condition->numel();
|
||||
|
||||
// TODO(GaaoWei8): Input of where can be broadcast
|
||||
const bool* cond_data = condition->data<bool>();
|
||||
const T* x_data = X->data<T>();
|
||||
const T* y_data = Y->data<T>();
|
||||
T* out_data = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto stream = context.cuda_device_context().stream();
|
||||
auto& dev_ctx =
|
||||
context.template device_context<platform::CUDADeviceContext>();
|
||||
auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
|
||||
WhereCUDAKernel<
|
||||
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
|
||||
numel, cond_data, x_data, y_data, out_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class WhereGradKernel<platform::CUDADeviceContext, T>
|
||||
: public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
platform::is_gpu_place(context.GetPlace()), true,
|
||||
platform::errors::PermissionDenied("It must use CUDAPlace."));
|
||||
|
||||
auto* condition = context.Input<framework::Tensor>("Condition");
|
||||
const bool* cond_data = condition->data<bool>();
|
||||
auto numel = condition->numel();
|
||||
|
||||
auto* dout_t =
|
||||
context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
auto* dx_t = context.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
auto* dy_t = context.Output<framework::Tensor>(framework::GradVarName("Y"));
|
||||
auto* dout = dout_t->data<T>();
|
||||
T* dx =
|
||||
(dx_t != nullptr) ? dx_t->mutable_data<T>(context.GetPlace()) : nullptr;
|
||||
T* dy =
|
||||
(dy_t != nullptr) ? dy_t->mutable_data<T>(context.GetPlace()) : nullptr;
|
||||
|
||||
auto stream = context.cuda_device_context().stream();
|
||||
auto& dev_ctx =
|
||||
context.template device_context<platform::CUDADeviceContext>();
|
||||
auto config = GetGpuLaunchConfig1D(dev_ctx, condition->numel());
|
||||
WhereGradCUDAKernel<
|
||||
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
|
||||
numel, dout, cond_data, dx, dy);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
where, paddle::operators::WhereKernel<platform::CUDADeviceContext, float>,
|
||||
paddle::operators::WhereKernel<platform::CUDADeviceContext, double>,
|
||||
paddle::operators::WhereKernel<platform::CUDADeviceContext, int>,
|
||||
paddle::operators::WhereKernel<platform::CUDADeviceContext, int64_t>);
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
where_grad,
|
||||
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, float>,
|
||||
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, double>,
|
||||
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int>,
|
||||
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int64_t>);
|
||||
@ -0,0 +1,73 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class WhereKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* condition = context.Input<framework::Tensor>("Condition");
|
||||
auto* X = context.Input<framework::Tensor>("X");
|
||||
auto* Y = context.Input<framework::Tensor>("Y");
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
|
||||
const bool* cond_data = condition->data<bool>();
|
||||
const T* x_data = X->data<T>();
|
||||
const T* y_data = Y->data<T>();
|
||||
T* out_data = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto x_numel = X->numel();
|
||||
for (int i = 0; i < x_numel; i++) {
|
||||
out_data[i] = cond_data[i] ? x_data[i] : y_data[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class WhereGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* condition = context.Input<framework::LoDTensor>("Condition");
|
||||
const auto* cond_data = condition->data<bool>();
|
||||
auto numel = condition->numel();
|
||||
|
||||
auto* dout_t =
|
||||
context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
auto* dx_t = context.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
auto* dy_t = context.Output<framework::Tensor>(framework::GradVarName("Y"));
|
||||
|
||||
auto* dout = dout_t->data<T>();
|
||||
if (dx_t != nullptr) {
|
||||
auto* dx = dx_t->mutable_data<T>(context.GetPlace());
|
||||
for (int i = 0; i < numel; i++) {
|
||||
dx[i] = dout[i] * (cond_data[i] ? 1. : 0.);
|
||||
}
|
||||
}
|
||||
if (dy_t != nullptr) {
|
||||
auto* dy = dy_t->mutable_data<T>(context.GetPlace());
|
||||
for (int i = 0; i < numel; i++) {
|
||||
dy[i] = dout[i] * (cond_data[i] ? 0. : 1.);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,173 @@
|
||||
#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.tensor as tensor
|
||||
import paddle.fluid.core as core
|
||||
from op_test import OpTest
|
||||
from paddle.fluid import compiler, Program, program_guard
|
||||
from paddle.fluid.op import Operator
|
||||
from paddle.fluid.backward import append_backward
|
||||
|
||||
|
||||
class TestWhereOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "where"
|
||||
self.init_config()
|
||||
self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
|
||||
self.outputs = {'Out': np.where(self.cond, self.x, self.y)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X', 'Y'], 'Out')
|
||||
|
||||
def init_config(self):
|
||||
self.x = np.random.uniform(-3, 5, (100)).astype("float64")
|
||||
self.y = np.random.uniform(-3, 5, (100)).astype("float64")
|
||||
self.cond = np.zeros((100)).astype("bool")
|
||||
|
||||
|
||||
class TestWhereOp2(TestWhereOp):
|
||||
def init_config(self):
|
||||
self.x = np.random.uniform(-5, 5, (60, 2)).astype("float64")
|
||||
self.y = np.random.uniform(-5, 5, (60, 2)).astype("float64")
|
||||
self.cond = np.ones((60, 2)).astype("bool")
|
||||
|
||||
|
||||
class TestWhereOp3(TestWhereOp):
|
||||
def init_config(self):
|
||||
self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
|
||||
self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
|
||||
self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)
|
||||
|
||||
|
||||
class TestWhereAPI(unittest.TestCase):
|
||||
def test_api(self, use_cuda=False):
|
||||
main_program = Program()
|
||||
with fluid.program_guard(main_program):
|
||||
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
|
||||
y = fluid.layers.data(name='y', shape=[4], dtype='float32')
|
||||
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
|
||||
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
|
||||
cond_i = np.array([False, False, True, True]).astype("bool")
|
||||
result = tensor.where(x > 1, X=x, Y=y)
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
out = exe.run(fluid.default_main_program(),
|
||||
feed={'x': x_i,
|
||||
'y': y_i},
|
||||
fetch_list=[result])
|
||||
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
|
||||
|
||||
def test_grad(self, use_cuda=False):
|
||||
main_program = Program()
|
||||
for x_stop_gradient, y_stop_gradient in [[False, False], [True, False],
|
||||
[False, True]]:
|
||||
with fluid.program_guard(main_program):
|
||||
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
|
||||
y = fluid.layers.data(name='y', shape=[4], dtype='float32')
|
||||
x.stop_gradient = x_stop_gradient
|
||||
y.stop_gradient = y_stop_gradient
|
||||
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
|
||||
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
|
||||
cond_i = np.array([False, False, True, True]).astype("bool")
|
||||
result = tensor.where(x > 1, X=x, Y=y)
|
||||
x_mean = layers.mean(x)
|
||||
append_backward(x_mean)
|
||||
y_mean = layers.mean(y)
|
||||
append_backward(y_mean)
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
out = exe.run(fluid.default_main_program(),
|
||||
feed={'x': x_i,
|
||||
'y': y_i},
|
||||
fetch_list=[result, x.grad_name, y.grad_name])
|
||||
x_grad = [0.25] * 4
|
||||
y_grad = [0.25] * 4
|
||||
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
|
||||
assert np.array_equal(out[1], x_grad)
|
||||
assert np.array_equal(out[2], y_grad)
|
||||
|
||||
def test_api_broadcast(self, use_cuda=False):
|
||||
main_program = Program()
|
||||
with fluid.program_guard(main_program):
|
||||
x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32')
|
||||
y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32')
|
||||
x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
|
||||
y_i = np.array(
|
||||
[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32")
|
||||
cond_i = np.array([[False, False, True, True],
|
||||
[False, False, True, True]]).astype("bool")
|
||||
result = tensor.where(x > 1, X=x, Y=y)
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
out = exe.run(fluid.default_main_program(),
|
||||
feed={'x': x_i,
|
||||
'y': y_i},
|
||||
fetch_list=[result])
|
||||
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
|
||||
|
||||
def test_fw_bw(self):
|
||||
if core.is_compiled_with_cuda():
|
||||
self.test_api(use_cuda=True)
|
||||
self.test_api_broadcast(use_cuda=True)
|
||||
self.test_grad(use_cuda=True)
|
||||
|
||||
|
||||
class TestWhereDygraphAPI(unittest.TestCase):
|
||||
def test_api(self):
|
||||
with fluid.dygraph.guard():
|
||||
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
|
||||
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
|
||||
cond_i = np.array([False, False, True, True]).astype("bool")
|
||||
x = fluid.dygraph.to_variable(x_i)
|
||||
y = fluid.dygraph.to_variable(y_i)
|
||||
cond = fluid.dygraph.to_variable(cond_i)
|
||||
out = tensor.where(cond, x, y)
|
||||
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
|
||||
|
||||
|
||||
class TestWhereOpError(unittest.TestCase):
|
||||
def test_errors(self):
|
||||
with program_guard(Program(), Program()):
|
||||
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
|
||||
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
|
||||
cond_i = np.array([False, False, True, True]).astype("bool")
|
||||
|
||||
def test_Variable():
|
||||
tensor.where(cond_i, x_i, y_i)
|
||||
|
||||
self.assertRaises(TypeError, test_Variable)
|
||||
|
||||
def test_type():
|
||||
x = fluid.layers.data(name='x', shape=[4], dtype='bool')
|
||||
y = fluid.layers.data(name='y', shape=[4], dtype='float16')
|
||||
cond = fluid.layers.data(name='cond', shape=[4], dtype='int32')
|
||||
tensor.where(cond, x, y)
|
||||
|
||||
self.assertRaises(TypeError, test_type)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue