add dynamic shape support and test cases to GPU BiasAdd

pull/11247/head
TFBunny 4 years ago
parent cf2734da8e
commit 6a58479e42

@@ -14,7 +14,7 @@
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h"
#include "backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.h"
namespace mindspore {
namespace kernel {

@@ -14,8 +14,8 @@
 * limitations under the License.
 */
#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H
#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <string>
#include <algorithm>
@@ -30,13 +30,7 @@ namespace kernel {
template <typename T>
class BiasAddGpuKernel : public GpuKernel {
 public:
  BiasAddGpuKernel()
      : cudnn_handle_(nullptr),
        cudnn_data_type_(CUDNN_DATA_FLOAT),
        x_desc_(nullptr),
        b_desc_(nullptr),
        op_desc_(nullptr),
        is_null_input_(false) {}
  BiasAddGpuKernel() { ResetResource(); }
  ~BiasAddGpuKernel() override { DestroyResource(); }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -117,6 +111,18 @@ class BiasAddGpuKernel : public GpuKernel {
    return true;
  }
  void ResetResource() noexcept override {
    cudnn_handle_ = nullptr;
    cudnn_data_type_ = CUDNN_DATA_FLOAT;
    x_desc_ = nullptr;
    b_desc_ = nullptr;
    op_desc_ = nullptr;
    is_null_input_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
  void DestroyResource() noexcept override {
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyOpTensorDescriptor(op_desc_),
                               "cudnnDestroyOpTensorDescriptor failed");
@@ -136,6 +142,7 @@ class BiasAddGpuKernel : public GpuKernel {
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateOpTensorDescriptor(&op_desc_),
                                "cudnnCreateOpTensorDescriptor failed");
  }
  void InitSizeLists() override {
    size_t x_size, b_size;
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size),
@@ -161,4 +168,4 @@ class BiasAddGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_BIAS_ADD_GPU_KERNEL_H
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_

@@ -63,6 +63,8 @@ AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const Pr
                                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -470,6 +470,41 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P
  return args_spec_list[2]->Broaden();
}

AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto bias = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  ShapeVector x_shape = x->shape()->shape();
  MS_EXCEPTION_IF_NULL(bias);
  MS_EXCEPTION_IF_NULL(bias->shape());
  ShapeVector bias_shape = bias->shape()->shape();
  ShapeVector x_min_shape = x->shape()->min_shape();
  ShapeVector x_max_shape = x->shape()->max_shape();
  std::set<std::string> available_data_format{"NCHW", "NHWC"};
  auto data_format_ptr = primitive->GetAttr("data_format");
  std::string data_format = "NCHW";
  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
    data_format = data_format_ptr->cast<StringImmPtr>()->value();
  }
  if (available_data_format.find(data_format) == available_data_format.end()) {
    MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC.";
  }
  auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1];
  // Additional check for dynamic shape:
  // only the last inference pass sees the real shape values.
  bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
  if (x_not_dyn && bias_shape[0] != x_channel) {
    MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
                      << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
  }
  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape));
}

AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: at least one tensor(y_backprop)
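A note on the new C++ inference above: the bias-vs-channel consistency check is only enforced once every dimension of x is concrete; while any dimension is still Shape::SHP_ANY (-1, a dynamic dimension), the check is deferred to the final inference pass. A minimal Python sketch of that rule, for illustration only (the helper name check_bias_channel is not part of this patch):

def check_bias_channel(x_shape, bias_shape, data_format="NCHW"):
    # -1 marks a dimension whose size is not yet known (dynamic shape)
    if any(dim == -1 for dim in x_shape):
        return  # defer the check; the final inference pass sees concrete sizes
    channel = x_shape[-1] if data_format == "NHWC" else x_shape[1]
    if bias_shape[0] != channel:
        raise ValueError("BiasAdd shape error: bias_shape[0]=%d, x_channel=%d"
                         % (bias_shape[0], channel))

check_bias_channel([-1, 3, 32, 32], [3])  # dynamic batch dimension: check deferred
check_bias_channel([8, 3, 32, 32], [3])   # fully known shape: channels match, no error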

@@ -114,6 +114,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimConv2D, {InferImplConv2D, true}},
    {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
    {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
    {prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
    {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
    {prim::kPrimRelu, {InferImplRelu, true}},
    {prim::kPrimZerosLike, {InferImplZerosLike, true}},

@@ -1887,19 +1887,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
        return out


class BiasAdd(PrimitiveWithInfer):
class BiasAdd(PrimitiveWithCheck):
    r"""
    Returns sum of input and bias tensor.
    Adds the 1-D bias tensor to the input tensor, and broadcasts the shape on all axes
    except for the channel axis.

    Args:
        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW',
            default is 'NCHW'.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`.
        - **data_format** (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',
          default is 'NCHW'.
          The shape of `bias` must be the same as `input_x` in the second dimension.
        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
          `bias` must be the same as `input_x`'s channel dimension.

    Outputs:
        Tensor, with the same shape and type as `input_x`.
@@ -1924,17 +1926,16 @@ class BiasAdd(PrimitiveWithInfer):
            raise ValueError("NHWC format is only supported on the GPU target.")
        self.add_prim_attr('data_format', self.format)

    def infer_shape(self, x_shape, b_shape):
    def check_shape(self, x_shape, b_shape):
        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
        validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
        x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
        validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_channel, Rel.EQ, self.name)
        return x_shape
        if np.all(np.array(x_shape) != -1):
            validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)

    def infer_dtype(self, x_type, b_type):
    def check_dtype(self, x_type, b_type):
        args = {"input_x": x_type, "bias": b_type}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x_type
class TopK(PrimitiveWithInfer):
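For reference, a minimal usage sketch of the reworked BiasAdd primitive (the tensor values are illustrative; the call pattern follows the standard MindSpore operator API):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

bias_add = P.BiasAdd()  # data_format defaults to 'NCHW'
input_x = Tensor(np.ones((2, 3, 4, 4)), mindspore.float32)   # channel dimension C = 3
bias = Tensor(np.array([0.1, 0.2, 0.3]), mindspore.float32)  # shape (C,)
output = bias_add(input_x, bias)  # bias is broadcast over N, H and W; output shape is (2, 3, 4, 4)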

@@ -23,7 +23,7 @@ from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
from mindspore.ops.operations import _inner_ops as inner


class BiasAdd(nn.Cell):
    def __init__(self):
@@ -442,3 +442,66 @@ def test_biasadd_4d():
    error = np.ones(shape=[3]) * 1.0e-6
    assert np.all(diff < error)
    assert np.all(-diff < error)


class BiasAddDynamic(nn.Cell):
    def __init__(self):
        super(BiasAddDynamic, self).__init__()
        self.ba = P.BiasAdd()
        self.test_dynamic = inner.GpuConvertToDynamicShape()

    def construct(self, x, b):
        x = self.test_dynamic(x)
        output = self.ba(x, b)
        return output

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bias_add_dynamic_two_inputs():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BiasAddDynamic()

    x_1 = Tensor(np.array([[0.1, 0.2, 0.3, 0.4],
                           [0.5, 0.6, 0.7, 0.8],
                           [0.9, 1.0, 1.1, 1.2]]).astype(np.float32))
    b_1 = Tensor(np.array([0.1, 0.2, 0.3, 0.4]).astype(np.float32))
    expect_1 = np.array([[0.2, 0.4, 0.6, 0.8],
                         [0.6, 0.8, 1.0, 1.2],
                         [1.0, 1.2, 1.4, 1.6]])
    error_1 = np.ones(shape=[3, 4]) * 1.0e-6
    result_1 = net(x_1, b_1)
    diff_1 = result_1.asnumpy() - expect_1
    assert np.all(diff_1 < error_1)
    assert np.all(-diff_1 < error_1)

    x_2 = Tensor(np.array([[[1, 2, 3, 4, 5, 6, 7, 8],
                            [9, 10, 11, 12, 13, 14, 15, 16],
                            [17, 18, 19, 20, 21, 22, 23, 24],
                            [25, 26, 27, 28, 29, 30, 31, 32]],
                           [[33, 34, 35, 36, 37, 38, 39, 40],
                            [41, 42, 43, 44, 45, 46, 47, 48],
                            [49, 50, 51, 52, 53, 54, 55, 56],
                            [57, 58, 59, 60, 61, 62, 63, 64]],
                           [[65, 66, 67, 68, 69, 70, 71, 72],
                            [73, 74, 75, 76, 77, 78, 79, 80],
                            [81, 82, 83, 84, 85, 86, 87, 88],
                            [89, 90, 91, 92, 93, 94, 95, 96]]]).astype(np.float32))
    b_2 = Tensor(np.array([1, 2, 3, 4]).astype(np.float32))
    expect_2 = np.array([[[2, 3, 4, 5, 6, 7, 8, 9],
                          [11, 12, 13, 14, 15, 16, 17, 18],
                          [20, 21, 22, 23, 24, 25, 26, 27],
                          [29, 30, 31, 32, 33, 34, 35, 36]],
                         [[34, 35, 36, 37, 38, 39, 40, 41],
                          [43, 44, 45, 46, 47, 48, 49, 50],
                          [52, 53, 54, 55, 56, 57, 58, 59],
                          [61, 62, 63, 64, 65, 66, 67, 68]],
                         [[66, 67, 68, 69, 70, 71, 72, 73],
                          [75, 76, 77, 78, 79, 80, 81, 82],
                          [84, 85, 86, 87, 88, 89, 90, 91],
                          [93, 94, 95, 96, 97, 98, 99, 100]]])
    error_2 = np.ones(shape=[3, 4, 8]) * 1.0e-6
    result_2 = net(x_2, b_2)
    diff_2 = result_2.asnumpy() - expect_2
    assert np.all(diff_2 < error_2)
    assert np.all(-diff_2 < error_2)
