|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -113,6 +113,9 @@ void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2,
|
|
|
|
|
NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
|
|
|
|
|
const double *x1, const double *x2, const double *dy, double *dx1, double *dx2,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
|
|
|
|
|
const float *x1, const float *x2, const float *dy, float *dx1, float *dx2,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
@ -124,6 +127,10 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &
|
|
|
|
|
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
|
|
|
|
|
const int64_t *x1, const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
|
const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const double *x1,
|
|
|
|
|
const double *x2, const double *dy, double *dx1, double *dx2, cudaStream_t stream);
|
|
|
|
|
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
|
const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
|
|
|
|
|