parent 3eb12bd100
commit c4d0305239
@@ -0,0 +1,116 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/tril_triu_op.h"
#include <memory>

namespace paddle {
namespace operators {

class TrilTriuOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of TrilTriuOp is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of TrilTriuOp is not found."));
    const auto& x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_GE(x_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "Input(X)'s rank must be at least 2 in TrilTriuOp."));
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Tensor, the input of tril_triu op.");
    AddOutput("Out",
              "Tensor, the output tensor, with the same shape and data type "
              "as input(X).");
    AddAttr<int>("diagonal", "int, the diagonal to consider.").SetDefault(0);
    AddAttr<bool>("lower",
                  "bool, whether to take the lower (tril) or upper (triu) "
                  "triangular part.");
    AddComment(R"DOC(
TrilTriu Operator.

The tril operator returns the lower triangular part of the matrix (2-D tensor)
or batch of matrices $input$. The lower triangular part of the matrix is defined
as the elements on and below the diagonal.
The triu operator returns the upper triangular part of a matrix (2-D tensor)
or batch of matrices $input$. The upper triangular part of the matrix is defined
as the elements on and above the diagonal.
The remaining elements of the output tensor are set to 0.

The attribute diagonal controls which diagonal to consider; its default value is 0.

)DOC");
  }
};

class TrilTriuGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::NotFound(
                          "Input(Out@GRAD) of TrilTriuGradOp is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                      platform::errors::NotFound(
                          "Output(X@GRAD) of TrilTriuGradOp is not found."));
    ctx->SetOutputDim(framework::GradVarName("X"),
                      ctx->GetInputDim(framework::GradVarName("Out")));
  }
};

template <typename T>
class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tril_triu_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
                  ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
                  ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
REGISTER_OP_CPU_KERNEL(
    tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    tril_triu_grad,
    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
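The DOC block above fixes the operator's semantics: keep the band on one side of the chosen diagonal, zero the rest. As a quick illustration (not part of the commit), NumPy's `np.tril`/`np.triu` follow the same diagonal convention as the `diagonal` attribute:

```python
import numpy as np

x = np.arange(1, 10).reshape(3, 3)

# tril with the default diagonal=0 keeps elements on and below the main diagonal.
print(np.tril(x))
# [[1 0 0]
#  [4 5 0]
#  [7 8 9]]

# triu with diagonal=1 keeps only elements strictly above the main diagonal.
print(np.triu(x, k=1))
# [[0 2 3]
#  [0 0 6]
#  [0 0 0]]
```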
@@ -0,0 +1,30 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/tril_triu_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    tril_triu,
    ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
    ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
    ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    tril_triu_grad,
    ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
    ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
    ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -0,0 +1,101 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

// Elementwise functor shared by the forward and backward kernels: given the
// flattened input, it zeroes every element on the wrong side of the chosen
// diagonal and copies the rest through.
template <typename T>
class TrilTriuCompute {
 public:
  HOSTDEVICE TrilTriuCompute(const T* in, const int diagonal, const bool lower,
                             const int64_t H, const int64_t W, T* out)
      : in_(in), diagonal_(diagonal), lower_(lower), H_(H), W_(W), out_(out) {}

  HOSTDEVICE void operator()(int64_t idx) {
    // Recover the (row, col) position inside the current H x W matrix from
    // the linear index; leading batch dimensions cancel out in the modulo.
    const int64_t row = (idx / W_) % H_;
    const int64_t col = idx % W_;
    // tril masks elements above the diagonal; triu masks those below it.
    const bool mask =
        lower_ ? (col - row > diagonal_) : (col - row < diagonal_);
    out_[idx] = mask ? static_cast<T>(0) : in_[idx];
  }

 private:
  const T* in_;
  const int diagonal_;
  const bool lower_;
  const int64_t H_;
  const int64_t W_;
  T* out_;
};

template <typename DeviceContext, typename T>
class TrilTriuOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* x = context.Input<framework::Tensor>("X");
    const auto* x_data = x->data<T>();
    auto* out = context.Output<framework::Tensor>("Out");
    auto* out_data = out->mutable_data<T>(context.GetPlace());

    const int diagonal = context.Attr<int>("diagonal");
    const bool lower = context.Attr<bool>("lower");

    const auto& dims = x->dims();
    const auto H = dims[dims.size() - 2];
    const auto W = dims[dims.size() - 1];

    platform::ForRange<DeviceContext> for_range(
        context.template device_context<DeviceContext>(),
        static_cast<size_t>(x->numel()));

    paddle::operators::TrilTriuCompute<T> tril_triu_computer(
        x_data, diagonal, lower, H, W, out_data);
    for_range(tril_triu_computer);
  }
};

template <typename DeviceContext, typename T>
class TrilTriuGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* d_out =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    const auto* dout_data = d_out->data<T>();
    auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dx_data = d_x->mutable_data<T>(context.GetPlace());

    const int diagonal = context.Attr<int>("diagonal");
    const bool lower = context.Attr<bool>("lower");

    const auto& dims = d_out->dims();
    const auto H = dims[dims.size() - 2];
    const auto W = dims[dims.size() - 1];

    platform::ForRange<DeviceContext> for_range(
        context.template device_context<DeviceContext>(),
        static_cast<size_t>(d_out->numel()));

    paddle::operators::TrilTriuCompute<T> tril_triu_grad_computer(
        dout_data, diagonal, lower, H, W, dx_data);
    for_range(tril_triu_grad_computer);
  }
};

}  // namespace operators
}  // namespace paddle
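Both kernels reduce to one elementwise pass of TrilTriuCompute over the flattened tensor, and since the op only selects or zeroes elements, the backward pass is the same mask applied to the upstream gradient. A minimal NumPy sketch of the masking rule (the function name is illustrative, not part of the commit):

```python
import numpy as np

def tril_triu_reference(x, diagonal, lower):
    """Reference for the rule in TrilTriuCompute (last two dims are H x W)."""
    H, W = x.shape[-2], x.shape[-1]
    flat = x.reshape(-1)
    out = np.empty_like(flat)
    for idx in range(flat.size):
        row = (idx // W) % H  # row inside the current matrix of the batch
        col = idx % W         # column inside the current matrix
        masked = (col - row > diagonal) if lower else (col - row < diagonal)
        out[idx] = 0 if masked else flat[idx]
    return out.reshape(x.shape)

# np.tril/np.triu apply to the last two axes, matching the kernel's batching.
x = np.random.rand(2, 3, 4)
assert np.allclose(tril_triu_reference(x, 0, True), np.tril(x))
assert np.allclose(tril_triu_reference(x, 1, False), np.triu(x, 1))
```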
@@ -0,0 +1,139 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.tensor as tensor


class TrilTriuOpDefaultTest(OpTest):
    """Base class for the tril/triu op test cases."""

    def setUp(self):
        self.initTestCase()
        self.real_np_op = getattr(np, self.real_op_type)

        self.op_type = "tril_triu"
        self.inputs = {'X': self.X}
        self.attrs = {
            'diagonal': self.diagonal,
            'lower': self.real_op_type == 'tril',
        }
        self.outputs = {
            'Out': self.real_np_op(self.X, self.diagonal)
            if self.diagonal else self.real_np_op(self.X)
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')

    def initTestCase(self):
        self.real_op_type = np.random.choice(['triu', 'tril'])
        self.diagonal = None
        self.X = np.arange(1, 101, dtype="float64").reshape([10, -1])


def case_generator(op_type, Xshape, diagonal, expected):
    """
    Generate test cases from the shape of X, the diagonal, and the op type.
    If `expected` is 'success', an OpTest case expected to pass is registered.
    Otherwise, an API case that checks the expected failure is registered.
    """
    cls_name = "{0}_{1}_shape_{2}_diag_{3}".format(expected, op_type, Xshape,
                                                   diagonal)
    errmsg = {
        "diagonal: TypeError":
        "diagonal in {} must be a python Int".format(op_type),
        "input: ValueError":
        "input shape in {} must be at least 2-D".format(op_type),
    }

    class FailureCase(unittest.TestCase):
        def test_failure(self):
            data = fluid.data(shape=Xshape, dtype='float64', name=cls_name)
            with self.assertRaisesRegexp(
                    eval(expected.split(':')[-1]), errmsg[expected]):
                getattr(tensor, op_type)(input=data, diagonal=diagonal)

    class SuccessCase(TrilTriuOpDefaultTest):
        def initTestCase(self):
            self.real_op_type = op_type
            self.diagonal = diagonal
            self.X = np.random.random(Xshape).astype("float64")

    CLASS = locals()['SuccessCase' if expected == "success" else 'FailureCase']
    CLASS.__name__ = cls_name
    globals()[cls_name] = CLASS


### NOTE: the meaningful diagonal range is [1 - min(H, W), max(H, W) - 1].
### Test diagonals exactly at the border, just beyond it on either side,
### negative/positive integers within range, and zero.
### (A small NumPy check of the border behaviour follows this test file.)
cases = {
    'success': {
        (2, 2, 3, 4, 5): [-100, -3, -1, 0, 2, 4, 100],  # normal shape
        (10, 10, 1, 1): [-100, -1, 0, 1, 100],  # small matrices
    },
    'diagonal: TypeError': {
        (20, 20): ['2020', [20], {20: 20}, (20, 20), 20.20],
        # str, list, dict, tuple, float
    },
    'input: ValueError': {
        (2020, ): [None],
    },
}
for _op_type in ['tril', 'triu']:
    for _expected, _params in cases.items():
        for _Xshape, _diaglist in _params.items():
            list(
                map(lambda _diagonal: case_generator(_op_type, _Xshape, _diagonal, _expected),
                    _diaglist))


class TestTrilTriuOpAPI(unittest.TestCase):
    """Test the Python API path, including a -1 (inferred) dimension."""

    def test_api(self):
        data = np.random.random([1, 9, 9, 4]).astype('float32')
        x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x')
        tril_out, triu_out = tensor.tril(x), tensor.triu(x)

        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        exe = fluid.Executor(place)
        tril_out, triu_out = exe.run(
            fluid.default_main_program(),
            feed={"x": data},
            fetch_list=[tril_out, triu_out], )
        self.assertTrue(np.allclose(tril_out, np.tril(data)))
        self.assertTrue(np.allclose(triu_out, np.triu(data)))


if __name__ == '__main__':
    unittest.main()
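To make the border cases in the NOTE above concrete: for an H x W matrix, col - row spans [1 - H, W - 1], so a diagonal at or beyond either border keeps everything or nothing. A quick standalone check (illustrative, not part of the test file):

```python
import numpy as np

m = np.ones((3, 5))                 # H=3, W=5: col - row spans [-2, 4]
assert (np.tril(m, 4) == m).all()   # diagonal at the upper border: keep all
assert (np.tril(m, -3) == 0).all()  # diagonal below the lower border: keep none
```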