add diag_embed op (#23385)
* add diag_embed op, test=develop
* add TestCase of diag_embed API
* Modified diag_embed python API testcase from dygraph to static graph, test=develop
* delete useless log and trigger ci, test=develop
* modified float16 of diag_embed, test=develop
* modified en doc of diag_embed
* trigger ci, test=develop
* add fp16 in dtype check of python API, test=develop
* modified __init__ and fix a bug, test=develop
* fixed a test bug of test_bicubic_interp_op and test_trilinear_interp_op, test=develop
* modified to use one kernel on cpu and cuda, test=develop
parent 8e555ba650
commit 87d8dc3dc0
branch revert-22778-infer_var_type
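For reference, the new Python API introduced by this change can be driven from a static-graph program in the same way as the unit test added below. This is a minimal illustrative sketch, not part of the change itself; the variable names and random input are my own, while the paddle.nn.functional.diag_embed parameters (offset, dim1, dim2) match the op attributes registered in this commit.

import numpy as np
import paddle.fluid as fluid
import paddle.nn.functional as F

# Build a static-graph program that embeds each row of a [2, 3] input
# onto the main diagonal of a 3x3 plane, giving an output of shape [2, 3, 3].
data = fluid.data(name='data', shape=[2, 3], dtype='float32')
out = F.diag_embed(data)  # defaults: offset=0, dim1=-2, dim2=-1

exe = fluid.Executor(fluid.CPUPlace())
x = np.random.randn(2, 3).astype('float32')
result, = exe.run(fluid.default_main_program(),
                  feed={'data': x},
                  fetch_list=[out])
# result[i] is a 3x3 matrix with x[i] on its main diagonal, i.e.
# np.allclose(result, np.stack([np.diag(r) for r in x])) holds.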
@@ -0,0 +1,113 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/diag_embed_op.h"

namespace paddle {
namespace operators {

class DiagEmbedOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::NotFound("Input of DiagEmbedOp is not found."));

    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output of DiagEmbedOp is not found."));

    int offset = ctx->Attrs().Get<int>("offset");
    int dim1 = ctx->Attrs().Get<int>("dim1");
    int dim2 = ctx->Attrs().Get<int>("dim2");

    auto x_dims = ctx->GetInputDim("Input");

    // Negative dims index from the end of the output, which has one more
    // dimension than the input, hence the "+ 1".
    int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
    int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
    int offset_ = std::abs(offset);

    PADDLE_ENFORCE_LE(
        dim1_, x_dims.size(),
        platform::errors::OutOfRange(
            "Dim1 is out of range (expected to be in range of [%ld, "
            "%ld], but got %ld).",
            -(x_dims.size() + 1), x_dims.size(), dim1));
    PADDLE_ENFORCE_LE(
        dim2_, x_dims.size(),
        platform::errors::OutOfRange(
            "Dim2 is out of range (expected to be in range of [%ld, "
            "%ld], but got %ld).",
            -(x_dims.size() + 1), x_dims.size(), dim2));
    PADDLE_ENFORCE_NE(dim1_, dim2_,
                      platform::errors::InvalidArgument(
                          "diagonal dimensions should not be identical "
                          "%ld vs %ld.",
                          dim1, dim2));

    // Both new diagonal dims have length |offset| + last input dim.
    int new_dim_len = offset_ + x_dims[x_dims.size() - 1];
    auto sizes = vectorize(x_dims);
    sizes.pop_back();
    sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
    sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len);
    ctx->SetOutputDim("Out", framework::make_ddim(sizes));
  }
};

class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "The input tensor. Must be at least 1-dimensional.");
    AddOutput("Out",
              "A tensor whose certain 2D planes (specified by dim1 and dim2) "
              "are diagonal matrices.");

    AddAttr<int>(
        "offset",
        R"DOC((int, default 0), which diagonal to consider. Default: 0 (main diagonal).
)DOC")
        .SetDefault(0);
    AddAttr<int>(
        "dim1",
        R"DOC((int, default -2), first dimension with respect to which to take diagonal. Default: -2.
)DOC")
        .SetDefault(-2);
    AddAttr<int>(
        "dim2",
        R"DOC((int, default -1), second dimension with respect to which to take diagonal. Default: -1.
)DOC")
        .SetDefault(-1);

    AddComment(R"DOC(Creates a tensor whose diagonals of certain 2D planes
(specified by dim1 and dim2) are filled by input.
To facilitate creating batched diagonal matrices,
the 2D planes formed by the last two dimensions of the returned tensor
are chosen by default.
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OPERATOR(
    diag_embed, ops::DiagEmbedOp, ops::DiagEmbedOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
    diag_embed, ops::DiagEmbedKernel<paddle::platform::CPUDeviceContext, int>,
    ops::DiagEmbedKernel<paddle::platform::CPUDeviceContext, float>,
    ops::DiagEmbedKernel<paddle::platform::CPUDeviceContext, double>,
    ops::DiagEmbedKernel<paddle::platform::CPUDeviceContext, int64_t>);
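The InferShape logic above determines the output shape by dropping the last input dimension and inserting two new dimensions of length |offset| + last_dim at the normalized dim1 and dim2 positions. Below is an illustrative Python sketch of that rule, using a hypothetical helper name of my own, checked against the two configurations exercised by the unit tests later in this diff.

def infer_diag_embed_shape(x_shape, offset=0, dim1=-2, dim2=-1):
    # Mirrors DiagEmbedOp::InferShape: normalize negative dims against the
    # output rank (input rank + 1), then insert the new diagonal dims.
    rank = len(x_shape)
    d1 = dim1 + rank + 1 if dim1 < 0 else dim1
    d2 = dim2 + rank + 1 if dim2 < 0 else dim2
    new_len = abs(offset) + x_shape[-1]
    sizes = list(x_shape[:-1])
    sizes.insert(min(d1, d2), new_len)
    sizes.insert(max(d1, d2), new_len)
    return sizes

# Matches TestDiagEmbedOp: a [2, 3] input with defaults yields [2, 3, 3].
assert infer_diag_embed_shape([2, 3]) == [2, 3, 3]
# Matches TestDiagEmbedOpCase1: offset=-1, dim1=0, dim2=2 yields [4, 2, 4].
assert infer_diag_embed_shape([2, 3], offset=-1, dim1=0, dim2=2) == [4, 2, 4]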
@@ -0,0 +1,26 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_embed_op.h"

namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    diag_embed, ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, int>,
    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, float>,
    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext,
                         platform::float16>,
    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, double>);
@@ -0,0 +1,121 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#ifdef __NVCC__
#include <thrust/device_vector.h>  // for the device-side dims/strides below
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename T>
struct DiagEmbedFunctor {
  DiagEmbedFunctor(const T* input, int64_t numel, const int64_t* dim,
                   int64_t offset, int64_t dims_size, T* output,
                   const int64_t* strides)
      : input_(input),
        numel_(numel),
        dim_(dim),
        offset_(offset),
        dims_size_(dims_size),
        output_(output),
        strides_(strides) {}

  HOSTDEVICE void operator()(size_t idx) const {
    // Decompose the flat input index into per-dimension coordinates and
    // accumulate the corresponding output offset using the precomputed strides.
    int64_t position = 0;
    auto numel = numel_;
    int64_t num = idx;
    for (int64_t i = 0; i < dims_size_; i++) {
      numel = numel / dim_[i];
      position += num / numel * strides_[i];
      num = num % numel;
    }
    output_[position + offset_] = input_[idx];
  }

  const T* input_;
  int64_t numel_;
  const int64_t* dim_;
  int64_t offset_;
  int64_t dims_size_;
  T* output_;
  const int64_t* strides_;
};

template <typename DeviceContext, typename T>
class DiagEmbedKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<framework::Tensor>("Input");
    auto* out = context.Output<framework::Tensor>("Out");

    const int64_t offset = context.Attr<int>("offset");
    const int64_t dim1 = context.Attr<int>("dim1");
    const int64_t dim2 = context.Attr<int>("dim2");
    auto* input_data = input->data<T>();

    T* out_data = out->mutable_data<T>(context.GetPlace());
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    set_zero(dev_ctx, out, static_cast<T>(0.0));

    auto out_dims = out->dims();
    int dim1_ = dim1 < 0 ? out_dims.size() + dim1 : dim1;
    int dim2_ = dim2 < 0 ? out_dims.size() + dim2 : dim2;
    auto stride = framework::stride(out_dims);
    int64_t diag_size;
    int64_t storage_offset = 0;
    if (offset >= 0) {
      int64_t dim = out_dims[dim2_] - offset;
      diag_size = std::max<int64_t>(std::min(out_dims[dim1_], dim), 0);
    } else {
      int64_t dim = out_dims[dim1_] + offset;
      diag_size = std::max<int64_t>(std::min(dim, out_dims[dim2_]), 0);
    }
    if (diag_size == 0) {
      // skip
    } else if (offset >= 0) {
      storage_offset += offset * stride[dim2_];
    } else {
      storage_offset -= offset * stride[dim1_];
    }
    // Drop the two diagonal dims from the output strides and append their
    // combined stride: stepping along the last input dim then moves along
    // both diagonal dims of the output at once.
    auto strides = vectorize(stride);
    strides.erase(strides.begin() + std::max(dim1_, dim2_));
    strides.erase(strides.begin() + std::min(dim1_, dim2_));
    strides.push_back(stride[dim1_] + stride[dim2_]);
    const auto dims = vectorize(input->dims());

#ifdef __NVCC__
    thrust::device_vector<int64_t> dims_vec(dims);
    const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data());
    thrust::device_vector<int64_t> strides_vec(strides);
    const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data());
#else
    const int64_t* dims_arr = dims.data();
    const int64_t* strides_arr = strides.data();
#endif

    platform::ForRange<DeviceContext> for_range(dev_ctx, input->numel());
    DiagEmbedFunctor<T> functor(input_data, input->numel(), dims_arr,
                                storage_offset, dims.size(), out_data,
                                strides_arr);
    for_range(functor);
  }
};
}  // namespace operators
}  // namespace paddle
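DiagEmbedFunctor above writes each input element to a single output location: it decodes the flat input index against the input dims and accumulates output strides, where the stride paired with the last input dimension is stride[dim1_] + stride[dim2_], so consecutive diagonal elements advance one step along both diagonal dims at once, shifted by storage_offset when offset is nonzero. A minimal NumPy sketch of that arithmetic for the simplest case (1-D input, offset 0); the variable names are my own.

import numpy as np

n = 3
x = np.arange(1, n + 1, dtype=np.float32)
out = np.zeros((n, n), dtype=np.float32)

out_strides = (n, 1)                            # row-major strides of the output
diag_stride = out_strides[0] + out_strides[1]   # stride[dim1_] + stride[dim2_]
storage_offset = 0                              # nonzero only when offset != 0

flat_out = out.reshape(-1)
for idx in range(x.size):                       # mirrors operator()(idx)
    position = idx * diag_stride                # single input dim, single stride term
    flat_out[position + storage_offset] = x[idx]

# Element k of the input lands at row k, column k of the output plane.
assert np.allclose(out, np.diag(x))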
@@ -0,0 +1,73 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.core as core


class TestDiagEmbedOp(OpTest):
    def setUp(self):
        self.op_type = "diag_embed"
        self.init_config()
        self.outputs = {'Out': self.target}

    def test_check_output(self):
        self.check_output()

    def init_config(self):
        self.case = np.random.randn(2, 3).astype('float32')
        self.inputs = {'Input': self.case}
        self.attrs = {'offset': 0, 'dim1': -2, 'dim2': -1}
        self.target = np.stack([np.diag(r, 0) for r in self.inputs['Input']], 0)


class TestDiagEmbedOpCase1(TestDiagEmbedOp):
    def init_config(self):
        self.case = np.random.randn(2, 3).astype('float32')
        self.inputs = {'Input': self.case}
        self.attrs = {'offset': -1, 'dim1': 0, 'dim2': 2}
        self.target = np.stack([np.diag(r, -1) for r in self.inputs['Input']],
                               1)


class TestDiagEmbedAPICase(unittest.TestCase):
    def test_case1(self):
        diag_embed = np.random.randn(2, 3, 4).astype('float32')
        data1 = fluid.data(name='data1', shape=[2, 3, 4], dtype='float32')
        out1 = F.diag_embed(data1)
        out2 = F.diag_embed(data1, offset=1, dim1=-2, dim2=3)

        place = core.CPUPlace()
        exe = fluid.Executor(place)
        results = exe.run(fluid.default_main_program(),
                          feed={"data1": diag_embed},
                          fetch_list=[out1, out2],
                          return_numpy=True)
        target1 = np.stack(
            [np.stack([np.diag(s, 0) for s in r], 0) for r in diag_embed], 0)
        target2 = np.stack(
            [np.stack([np.diag(s, 1) for s in r], 0) for r in diag_embed], 0)
        self.assertTrue(np.allclose(results[0], target1))
        self.assertTrue(np.allclose(results[1], target2))


if __name__ == "__main__":
    unittest.main()