add ones_like op (#17388)
parent
67b48d7fe7
commit
d3b3443d10
@ -0,0 +1,64 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fill_any_like_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class FillAnyLikeOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of FillAnyLikeOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of FillAnyLikeOp should not be null.");
|
||||
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "The input of fill-zeros-like op.");
|
||||
AddOutput("Out", "The variable will be filled up with specified value.");
|
||||
AddAttr<float>("value", "The filled value").SetDefault(0.0);
|
||||
AddComment(R"DOC(
|
||||
FillAnyLike Operator.
|
||||
|
||||
Fill up a variable with Attr(value).
|
||||
The output will have the same shape and dtype as the input.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(fill_any_like, ops::FillAnyLikeOp,
|
||||
ops::FillAnyLikeOpMaker);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
fill_any_like,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext,
|
||||
paddle::platform::float16>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, bool>);
|
@ -0,0 +1,27 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/fill_any_like_op.h"
|
||||
#include "paddle/fluid/platform/float16.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
fill_any_like,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int32_t>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext,
|
||||
paddle::platform::float16>,
|
||||
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, bool>);
|
@ -0,0 +1,60 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class FillAnyLikeKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
using CommonType = typename std::common_type<
|
||||
float,
|
||||
typename std::conditional<std::is_same<T, platform::float16>::value,
|
||||
float, T>::type>::type;
|
||||
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
// TODO(fangzeyang): Once context.Attribute supports double dtype, this
|
||||
// kernel should be updated to support double dtype, too.
|
||||
float value = context.Attr<float>("value");
|
||||
|
||||
auto common_type_value = static_cast<CommonType>(value);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
(common_type_value >=
|
||||
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
|
||||
(common_type_value <=
|
||||
static_cast<CommonType>(std::numeric_limits<T>::max())),
|
||||
"filled value is out of range for targeted type in fill_any_like "
|
||||
"kernel");
|
||||
|
||||
PADDLE_ENFORCE(!std::isnan(value), "filled value is NaN");
|
||||
|
||||
math::SetConstant<DeviceContext, T> setter;
|
||||
setter(context.template device_context<DeviceContext>(), out,
|
||||
static_cast<T>(value));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,81 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid.core as core
|
||||
import paddle.compat as cpt
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestFillAnyLikeOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "fill_any_like"
|
||||
self.dtype = np.int32
|
||||
self.value = 0.0
|
||||
self.init()
|
||||
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
|
||||
self.attrs = {'value': self.value}
|
||||
self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
|
||||
def init(self):
|
||||
self.dtype = np.float32
|
||||
self.value = 0.0
|
||||
|
||||
|
||||
class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
|
||||
def init(self):
|
||||
self.value = 1.0
|
||||
|
||||
|
||||
class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
|
||||
def init(self):
|
||||
self.value = 1e-10
|
||||
|
||||
|
||||
class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
|
||||
def init(self):
|
||||
self.value = 1e-100
|
||||
|
||||
|
||||
class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp):
|
||||
def init(self):
|
||||
self.value = 1e100
|
||||
|
||||
def test_check_output(self):
|
||||
exception = None
|
||||
try:
|
||||
self.check_output()
|
||||
except core.EnforceNotMet as ex:
|
||||
exception = ex
|
||||
self.assertIsNotNone(exception)
|
||||
|
||||
|
||||
class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
|
||||
def init(self):
|
||||
self.dtype = np.float16
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue