Add Diag Op(#17027)
parent
8a2caacdbc
commit
1bfff02047
@ -0,0 +1,60 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/diag_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class DiagOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Diagonal"),
|
||||
"Input(Diagonal) of DiagOp should not be null.");
|
||||
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of DiagOp should not be null.");
|
||||
|
||||
auto s_dims = ctx->GetInputDim("Diagonal");
|
||||
PADDLE_ENFORCE(s_dims.size() == 1,
|
||||
"The rank of Input(Diagonal) should only be 1.");
|
||||
|
||||
ctx->SetOutputDim("Out", {s_dims[0], s_dims[0]});
|
||||
}
|
||||
};
|
||||
|
||||
class DiagOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Diagonal",
|
||||
"Diagonal values of square matrix. It is a tensor with rank 1.");
|
||||
AddOutput("Out", "A square matrix.");
|
||||
AddComment(R"DOC(
|
||||
Return a square matrix with specified diagonal values.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(diag, ops::DiagOp, ops::DiagOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
diag, ops::DiagKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::DiagKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::DiagKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::DiagKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
@ -0,0 +1,23 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/diag_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
diag, ops::DiagKernel<paddle::platform::CUDADeviceContext, int>,
|
||||
ops::DiagKernel<paddle::platform::CUDADeviceContext, int64_t>,
|
||||
ops::DiagKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::DiagKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,59 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/platform/for_range.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
struct DiagFunctor {
|
||||
DiagFunctor(const T* diagonal, int64_t numel, T* output)
|
||||
: diagonal_(diagonal), numel_(numel), output_(output) {}
|
||||
|
||||
HOSTDEVICE void operator()(size_t idx) const {
|
||||
output_[idx * numel_ + idx] = diagonal_[idx];
|
||||
}
|
||||
|
||||
const T* diagonal_;
|
||||
int64_t numel_;
|
||||
T* output_;
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class DiagKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* diagonal = context.Input<framework::Tensor>("Diagonal");
|
||||
auto* diag_data = diagonal->data<T>();
|
||||
auto numel = diagonal->numel();
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
T* out_data = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
math::SetConstant<DeviceContext, T> set_zero;
|
||||
auto& dev_ctx = context.template device_context<DeviceContext>();
|
||||
set_zero(dev_ctx, out, static_cast<T>(0));
|
||||
|
||||
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
|
||||
DiagFunctor<T> functor(diag_data, numel, out_data);
|
||||
for_range(functor);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestDiagOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "diag"
|
||||
self.init_config()
|
||||
self.inputs = {'Diagonal': self.case}
|
||||
|
||||
self.outputs = {'Out': np.diag(self.inputs['Diagonal'])}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def init_config(self):
|
||||
self.case = np.arange(3, 6)
|
||||
|
||||
|
||||
class TestDiagOpCase1(TestDiagOp):
|
||||
def init_config(self):
|
||||
self.case = np.array([3], dtype='int32')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue