add is empty op (#5639)
parent
9cf6036533
commit
186581d2cc
@ -0,0 +1,67 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
constexpr char kInput[] = "X";
|
||||
constexpr char kOutput[] = "Out";
|
||||
|
||||
class IsEmptyOp : public framework::OperatorBase {
|
||||
public:
|
||||
IsEmptyOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
void Run(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const override {
|
||||
// get input
|
||||
auto *var = scope.FindVar(Input(kInput));
|
||||
PADDLE_ENFORCE_NOT_NULL(var);
|
||||
auto &tensor = var->Get<framework::LoDTensor>();
|
||||
// get output
|
||||
auto *out = scope.FindVar(Output(kOutput));
|
||||
PADDLE_ENFORCE_NOT_NULL(out);
|
||||
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
|
||||
|
||||
out_tensor->Resize({1});
|
||||
out_tensor->mutable_data<bool>(platform::CPUPlace())[0] =
|
||||
framework::product(tensor.dims()) == 0;
|
||||
}
|
||||
};
|
||||
|
||||
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
IsEmptyOpProtoMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
|
||||
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
|
||||
AddComment(R"DOC(
|
||||
IsEmpty Operator which checks whether a tensor is empty.
|
||||
|
||||
It will just return product(tensor.ddims()) > 0;
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP_WITHOUT_GRADIENT(is_empty, paddle::operators::IsEmptyOp,
|
||||
paddle::operators::IsEmptyOpProtoMaker);
|
@ -0,0 +1,43 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from paddle.v2.framework.op import Operator
|
||||
import paddle.v2.framework.core as core
|
||||
|
||||
|
||||
def create_tensor(scope, name, np_data):
|
||||
tensor = scope.var(name).get_tensor()
|
||||
tensor.set_dims(np_data.shape)
|
||||
tensor.set(np_data, core.CPUPlace())
|
||||
return tensor
|
||||
|
||||
|
||||
class TestIsEmptyOp(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.scope = core.Scope()
|
||||
# create input variables
|
||||
np_data0 = np.array([0, 1, 2])
|
||||
create_tensor(self.scope, "X0", np_data0)
|
||||
|
||||
np_data1 = np.array([1])
|
||||
t = create_tensor(self.scope, "X1", np_data1)
|
||||
t.set_dims([0])
|
||||
|
||||
# create output variables
|
||||
self.scope.var("out")
|
||||
|
||||
def test_no_empty(self):
|
||||
self.one_case("X0", False)
|
||||
|
||||
def test_empty(self):
|
||||
self.one_case("X1", True)
|
||||
|
||||
def one_case(self, input, target):
|
||||
op = Operator(type="is_empty", X=input, Out="out")
|
||||
ctx = core.DeviceContext.create(core.CPUPlace())
|
||||
op.run(self.scope, ctx)
|
||||
out = self.scope.var("out").get_tensor()
|
||||
self.assertEqual(np.array(out)[0], target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue