[CustomOp] Support duplicable op input and output (#31535)
* support duplicable op input and output
* add custom concat op test
parent def27bc801
commit 95cceb2dd7
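The change lets a custom operator declare duplicable (variable-length) inputs and outputs: the kernel function receives or returns a std::vector<paddle::Tensor>, and the registration wraps the name in paddle::Vec. A minimal sketch of the pattern, condensed from the custom_concat_op.cc diff below (the op name my_vec_op and its identity kernel are illustrative only, not part of the commit):

#include <vector>
#include "paddle/extension.h"

// A duplicable input arrives as a vector of tensors, and a duplicable
// output is returned the same way.
std::vector<paddle::Tensor> MyVecOpForward(
    const std::vector<paddle::Tensor>& xs) {
  return xs;  // illustrative identity: pass every tensor through
}

PD_BUILD_OP(my_vec_op)
    .Inputs({paddle::Vec("X")})     // paddle::Vec marks "X" as duplicable
    .Outputs({paddle::Vec("Out")})  // same for the output
    .SetKernelFn(PD_KERNEL(MyVecOpForward));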
Two file diffs suppressed because they are too large.
concat_and_split.h
@@ -0,0 +1,84 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>  // std::memcpy
#include <iostream>
#include <vector>

#include "paddle/extension.h"

// Number of "rows" when a tensor is viewed as a 2-D matrix split at `axis`:
// the product of all dimensions before the axis.
int64_t GetRows(const std::vector<int64_t>& shape, int64_t axis) {
  int64_t rows = 1;
  for (int64_t i = 0; i < axis; ++i) {
    rows *= shape[i];
  }
  return rows;
}

// Per-input column widths (elements per row from `axis` onward); also
// accumulates the total width into `*cols`.
std::vector<int64_t> GetCols(const std::vector<paddle::Tensor>& ins,
                             int64_t rows,
                             int64_t* cols) {
  std::vector<int64_t> cols_vec(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
    int64_t t_cols = ins[i].size() / rows;
    *cols += t_cols;
    cols_vec[i] = t_cols;
  }
  return cols_vec;
}

// Concatenate `ins` along `axis` into `out`: copy each input's row slice
// into its column range of the output, row by row.
template <typename data_t>
void ConcatCpuKernel(const std::vector<paddle::Tensor>& ins,
                     paddle::Tensor* out,
                     int64_t axis) {
  size_t num = ins.size();
  int64_t out_rows = GetRows(ins[0].shape(), axis);
  int64_t out_cols = 0;
  auto ins_cols = GetCols(ins, out_rows, &out_cols);

  auto* out_data = out->mutable_data<data_t>();
  int64_t col_idx = 0;
  for (size_t i = 0; i < num; ++i) {
    int64_t col_len = ins_cols[i];
    auto* in_data = ins[i].data<data_t>();
    for (int64_t j = 0; j < out_rows; ++j) {
      std::memcpy(out_data + j * out_cols + col_idx,
                  in_data + j * col_len,
                  sizeof(data_t) * col_len);
    }
    col_idx += col_len;
  }
}

// Inverse of concat: split `in` along `axis` into `outs`, using `ref_ins`
// (the forward inputs) to recover each output's column width.
template <typename data_t>
void SplitCpuKernel(const paddle::Tensor& in,
                    const std::vector<paddle::Tensor>& ref_ins,
                    std::vector<paddle::Tensor>* outs,
                    int64_t axis) {
  size_t num = outs->size();
  int64_t in_rows = GetRows(ref_ins[0].shape(), axis);
  int64_t in_cols = 0;
  auto out_cols = GetCols(ref_ins, in_rows, &in_cols);

  for (int64_t i = 0; i < in_rows; ++i) {
    auto* in_data = in.data<data_t>() + i * in_cols;
    int64_t col_idx = 0;
    for (size_t j = 0; j < num; ++j) {
      int64_t col_len = out_cols[j];
      auto* out_data = outs->at(j).mutable_data<data_t>() + i * col_len;
      std::memcpy(out_data, in_data + col_idx, sizeof(data_t) * col_len);
      col_idx += col_len;
    }
  }
}
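Both kernels reduce concat and split to 2-D copies: a tensor is viewed as a rows x cols matrix, where rows is the product of the dimensions before the axis (GetRows) and cols is the number of elements from the axis onward (GetCols). Concatenating two {2, 3} tensors along axis 1, for instance, gives rows = 2 and cols = 3 per input, so each output row is the first input's 3 columns followed by the second's. A hypothetical usage sketch (not part of the commit), assuming the paddle::Tensor construction API used in custom_concat_op.cc below:

// Concatenate two 2x3 float tensors along axis 1 into a 2x6 output.
auto a = paddle::Tensor(paddle::PlaceType::kCPU);
a.reshape({2, 3});
a.mutable_data<float>();  // allocate, then fill the 6 values ...
auto b = paddle::Tensor(paddle::PlaceType::kCPU);
b.reshape({2, 3});
b.mutable_data<float>();

auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape({2, 6});  // axis = 1: out_rows = 2, out_cols = 3 + 3
ConcatCpuKernel<float>({a, b}, &out, /*axis=*/1);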
custom_concat_op.cc
@@ -0,0 +1,145 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "concat_and_split.h"  // NOLINT
#include "paddle/extension.h"

#define CHECK_INPUT(x) \
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")

// Normalize a possibly negative axis into [0, rank).
int64_t ComputeAxis(int64_t axis, int64_t rank) {
  PD_CHECK(axis >= -rank && axis < rank,
           "The axis is expected to be in range of [",
           -rank,
           ", ",
           rank,
           ").");
  if (axis < 0) {
    axis = axis + rank;
  }
  return axis > 0 ? axis : 0;
}

std::vector<int64_t> ComputeOutShape(
    std::vector<std::vector<int64_t>> in_shapes, int64_t axis) {
  size_t n = in_shapes.size();
  auto out_shape = in_shapes[0];
  size_t zero_dim_size = out_shape.size();
  for (size_t i = 1; i < n; ++i) {
    PD_CHECK(in_shapes[i].size() == out_shape.size(),
             "Input tensors must have the same rank.");
    for (size_t j = 0; j < zero_dim_size; ++j) {
      if (static_cast<int64_t>(j) == axis) {
        out_shape[axis] += in_shapes[i][j];
      } else {
        PD_CHECK(in_shapes[0][j] == in_shapes[i][j],
                 "The ",
                 j,
                 "-th dimension of the inputs must be the same.");
      }
    }
  }
  return out_shape;
}

std::vector<paddle::Tensor> ConcatForwardDynamicAxis(
    const std::vector<paddle::Tensor>& inputs, const paddle::Tensor& axis_t) {
  // check inputs
  PD_CHECK(inputs.size() >= 1, "No Tensor needs to be concatenated.");
  for (auto& t : inputs) {
    CHECK_INPUT(t);
  }
  CHECK_INPUT(axis_t);

  // compute output shape
  int64_t rank = static_cast<int64_t>(inputs[0].shape().size());
  int64_t axis = axis_t.data<int64_t>()[0];
  axis = ComputeAxis(axis, rank);
  std::vector<std::vector<int64_t>> in_shapes;
  for (auto& t : inputs) {
    in_shapes.emplace_back(t.shape());
  }
  auto out_shape = ComputeOutShape(in_shapes, axis);

  // create output
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(out_shape);

  // calc
  PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
      inputs[0].type(), "ConcatCpuKernel", ([&] {
        ConcatCpuKernel<data_t>(inputs, &out, axis);
      }));

  return {out};
}

std::vector<paddle::Tensor> ConcatBackwardDynamicAxis(
    const std::vector<paddle::Tensor>& inputs,
    const paddle::Tensor& grad_out,
    const paddle::Tensor& axis_t) {
  // check inputs
  PD_CHECK(inputs.size() >= 1, "No Tensor needs to be concatenated.");
  for (auto& t : inputs) {
    CHECK_INPUT(t);
  }
  CHECK_INPUT(axis_t);
  CHECK_INPUT(grad_out);

  // compute axis
  int64_t rank = static_cast<int64_t>(inputs[0].shape().size());
  int64_t axis = axis_t.data<int64_t>()[0];
  axis = ComputeAxis(axis, rank);

  // create outputs: one gradient tensor per (duplicable) input
  std::vector<paddle::Tensor> grad_inputs;
  for (auto& t : inputs) {
    auto grad = paddle::Tensor(paddle::PlaceType::kCPU);
    grad.reshape(t.shape());
    grad_inputs.emplace_back(grad);
  }

  // calc: the gradient of concat is a split of grad_out
  PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
      grad_out.type(), "SplitCpuKernel", ([&] {
        SplitCpuKernel<data_t>(grad_out, inputs, &grad_inputs, axis);
      }));

  return grad_inputs;
}

std::vector<std::vector<int64_t>> ConcatInferShapeDynamicAxis(
    std::vector<std::vector<int64_t>> input_shapes,
    std::vector<int64_t> axis_shape) {
  // The axis value is only known at runtime, so no dimension of the
  // output can be inferred statically; mark them all as unknown (-1).
  return {std::vector<int64_t>(input_shapes[0].size(), -1)};
}

std::vector<paddle::DataType> ConcatInferDtypeDynamicAxis(
    std::vector<paddle::DataType> input_dtypes, paddle::DataType axis_dtype) {
  return {input_dtypes[0]};
}

PD_BUILD_OP(custom_concat)
    .Inputs({paddle::Vec("X"), "Axis"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ConcatForwardDynamicAxis))
    .SetInferShapeFn(PD_INFER_SHAPE(ConcatInferShapeDynamicAxis))
    .SetInferDtypeFn(PD_INFER_DTYPE(ConcatInferDtypeDynamicAxis));

PD_BUILD_GRAD_OP(custom_concat)
    .Inputs({paddle::Vec("X"), paddle::Grad("Out"), "Axis"})
    .Outputs({paddle::Grad(paddle::Vec("X"))})
    .SetKernelFn(PD_KERNEL(ConcatBackwardDynamicAxis));
@@ -0,0 +1,148 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import numpy as np

import paddle
import paddle.static as static
from paddle.utils.cpp_extension import load, get_build_directory
from paddle.utils.cpp_extension.extension_utils import run_cmd
from utils import paddle_includes, extra_cc_args, extra_nvcc_args

# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be compiled again unless the shared lib is removed.
file = '{}\\custom_concat_jit\\custom_concat_jit.pyd'.format(
    get_build_directory())
if os.name == 'nt' and os.path.isfile(file):
    cmd = 'del {}'.format(file)
    run_cmd(cmd, True)

if os.name == 'nt':
    test_include = "..\\python\\paddle\\fluid\\tests\\custom_op"
else:
    test_include = "../python/paddle/fluid/tests/custom_op"
paddle_includes.append(test_include)

custom_ops = load(
    name='custom_concat_jit',
    sources=['custom_concat_op.cc'],
    extra_include_paths=paddle_includes,  # add for Coverage CI
    extra_cxx_cflags=extra_cc_args,  # test for cc flags
    extra_cuda_cflags=extra_nvcc_args,  # test for nvcc flags
    verbose=True)


def concat_dynamic(func, device, dtype, np_inputs, axis_v):
    paddle.set_device(device)
    inputs = [
        paddle.to_tensor(
            x, dtype=dtype, place=device, stop_gradient=False)
        for x in np_inputs
    ]
    axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
    out = func(inputs, axis)
    out.stop_gradient = False
    out.backward()
    grad_inputs = [x.grad for x in inputs]
    return out.numpy(), grad_inputs


def concat_static(func, device, dtype, np_inputs, axis_v):
    paddle.enable_static()
    paddle.set_device(device)
    with static.scope_guard(static.Scope()):
        with static.program_guard(static.Program()):
            x1 = static.data(name="x1", shape=[2, 3], dtype=dtype)
            x2 = static.data(name="x2", shape=[2, 3], dtype=dtype)
            axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
            x1.stop_gradient = False
            x2.stop_gradient = False
            out = func([x1, x2], axis)
            # mean only supports float, so use sum here
            sum_out = paddle.sum(out)
            static.append_backward(sum_out)

            exe = static.Executor()
            exe.run(static.default_startup_program())

            # `axis` is produced by paddle.full inside the program, so only
            # the data inputs need to be fed.
            out_v, x1_grad_v, x2_grad_v = exe.run(
                static.default_main_program(),
                feed={
                    "x1": np_inputs[0].astype(dtype),
                    "x2": np_inputs[1].astype(dtype)
                },
                fetch_list=[out.name, x1.name + "@GRAD", x2.name + "@GRAD"])
    paddle.disable_static()
    return out_v, x1_grad_v, x2_grad_v


class TestCustomConcatDynamicAxisJit(unittest.TestCase):
    def setUp(self):
        self.dtypes = ['float32', 'float64', 'int32', 'int64']
        self.devices = ['cpu']
        self.np_inputs = [
            np.array([[1, 2, 3], [4, 5, 6]]),
            np.array([[11, 12, 13], [14, 15, 16]])
        ]
        self.axises = [0, 1]

    def test_dynamic(self):
        for device in self.devices:
            for dtype in self.dtypes:
                for axis in self.axises:
                    out, grad_inputs = concat_dynamic(
                        custom_ops.custom_concat, device, dtype,
                        self.np_inputs, axis)
                    pd_out, pd_grad_inputs = concat_dynamic(
                        paddle.concat, device, dtype, self.np_inputs, axis)

                    self.assertTrue(
                        np.array_equal(out, pd_out),
                        "custom op out: {},\n paddle api out: {}".format(
                            out, pd_out))
                    for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs):
                        self.assertTrue(
                            np.array_equal(x_grad, pd_x_grad),
                            "custom op x grad: {},\n paddle api x grad: {}".
                            format(x_grad, pd_x_grad))

    def test_static(self):
        for device in self.devices:
            for dtype in self.dtypes:
                for axis in self.axises:
                    out, x1_grad, x2_grad = concat_static(
                        custom_ops.custom_concat, device, dtype,
                        self.np_inputs, axis)
                    pd_out, pd_x1_grad, pd_x2_grad = concat_static(
                        paddle.concat, device, dtype, self.np_inputs, axis)

                    self.assertTrue(
                        np.array_equal(out, pd_out),
                        "custom op out: {},\n paddle api out: {}".format(
                            out, pd_out))
                    self.assertTrue(
                        np.array_equal(x1_grad, pd_x1_grad),
                        "custom op x1_grad: {},\n paddle api x1_grad: {}".
                        format(x1_grad, pd_x1_grad))
                    self.assertTrue(
                        np.array_equal(x2_grad, pd_x2_grad),
                        "custom op x2_grad: {},\n paddle api x2_grad: {}".
                        format(x2_grad, pd_x2_grad))


if __name__ == "__main__":
    unittest.main()