add is empty op (#5639)
parent
9cf6036533
commit
186581d2cc
@ -0,0 +1,67 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
#include "paddle/framework/operator.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
constexpr char kInput[] = "X";
|
||||||
|
constexpr char kOutput[] = "Out";
|
||||||
|
|
||||||
|
class IsEmptyOp : public framework::OperatorBase {
|
||||||
|
public:
|
||||||
|
IsEmptyOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||||
|
const framework::VariableNameMap &outputs,
|
||||||
|
const framework::AttributeMap &attrs)
|
||||||
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||||
|
|
||||||
|
void Run(const framework::Scope &scope,
|
||||||
|
const platform::DeviceContext &dev_ctx) const override {
|
||||||
|
// get input
|
||||||
|
auto *var = scope.FindVar(Input(kInput));
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(var);
|
||||||
|
auto &tensor = var->Get<framework::LoDTensor>();
|
||||||
|
// get output
|
||||||
|
auto *out = scope.FindVar(Output(kOutput));
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(out);
|
||||||
|
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
|
||||||
|
|
||||||
|
out_tensor->Resize({1});
|
||||||
|
out_tensor->mutable_data<bool>(platform::CPUPlace())[0] =
|
||||||
|
framework::product(tensor.dims()) == 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
IsEmptyOpProtoMaker(framework::OpProto *proto,
|
||||||
|
framework::OpAttrChecker *op_checker)
|
||||||
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
|
||||||
|
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
|
||||||
|
AddComment(R"DOC(
|
||||||
|
IsEmpty Operator which checks whether a tensor is empty.
|
||||||
|
|
||||||
|
It will just return product(tensor.ddims()) > 0;
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
REGISTER_OP_WITHOUT_GRADIENT(is_empty, paddle::operators::IsEmptyOp,
|
||||||
|
paddle::operators::IsEmptyOpProtoMaker);
|
@ -0,0 +1,43 @@
|
|||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
from paddle.v2.framework.op import Operator
|
||||||
|
import paddle.v2.framework.core as core
|
||||||
|
|
||||||
|
|
||||||
|
def create_tensor(scope, name, np_data):
|
||||||
|
tensor = scope.var(name).get_tensor()
|
||||||
|
tensor.set_dims(np_data.shape)
|
||||||
|
tensor.set(np_data, core.CPUPlace())
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsEmptyOp(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.scope = core.Scope()
|
||||||
|
# create input variables
|
||||||
|
np_data0 = np.array([0, 1, 2])
|
||||||
|
create_tensor(self.scope, "X0", np_data0)
|
||||||
|
|
||||||
|
np_data1 = np.array([1])
|
||||||
|
t = create_tensor(self.scope, "X1", np_data1)
|
||||||
|
t.set_dims([0])
|
||||||
|
|
||||||
|
# create output variables
|
||||||
|
self.scope.var("out")
|
||||||
|
|
||||||
|
def test_no_empty(self):
|
||||||
|
self.one_case("X0", False)
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
self.one_case("X1", True)
|
||||||
|
|
||||||
|
def one_case(self, input, target):
|
||||||
|
op = Operator(type="is_empty", X=input, Out="out")
|
||||||
|
ctx = core.DeviceContext.create(core.CPUPlace())
|
||||||
|
op.run(self.scope, ctx)
|
||||||
|
out = self.scope.var("out").get_tensor()
|
||||||
|
self.assertEqual(np.array(out)[0], target)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue