parent e805d06499
commit 26f6daa850
@@ -0,0 +1,90 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"

__global__ void CopyMemKernel(const size_t thread_num, const size_t N, const size_t C,
                              float *gamma_addr, float *beta_addr,
                              float *running_mean_addr, float *running_variance_addr,
                              float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
    // Threads are split into four segments of N * C each: gamma, beta, running mean, running variance.
    size_t cur_addr = pos / (N * C);
    size_t cur_local_index = pos % (N * C);
    size_t local_index = 0;
    switch (cur_addr) {
      case 0:
        if (!(gamma_addr && ws_gamma)) break;
        local_index = cur_local_index % C;
        ws_gamma[cur_local_index] = gamma_addr[local_index];
        break;
      case 1:
        if (!(beta_addr && ws_beta)) break;
        local_index = cur_local_index % C;
        ws_beta[cur_local_index] = beta_addr[local_index];
        break;
      case 2:
        if (!(running_mean_addr && ws_mean)) break;
        local_index = cur_local_index % C;
        ws_mean[cur_local_index] = running_mean_addr[local_index];
        break;
      default:
        if (!(running_variance_addr && ws_var)) break;
        local_index = cur_local_index % C;
        ws_var[cur_local_index] = running_variance_addr[local_index];
    }
  }
  return;
}

void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, float *beta_addr,
                          float *running_mean_addr, float *running_variance_addr,
                          float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
                          cudaStream_t cuda_stream) {
  size_t thread_num = N * C * 4;
  CopyMemKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
    thread_num, N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr,
    ws_gamma, ws_beta, ws_mean, ws_var);
}

__global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
                                  float *save_mean_addr, float *save_var_addr) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
    // The first C threads average the mean buffer over N, the next C threads average the variance buffer.
    size_t cur_addr = pos / C;
    size_t cur_local_index = pos % C;
    float tmp = 0;
    if (cur_addr) {
      for (size_t i = 0; i < N; i++) {
        tmp += save_var_addr[i * C + cur_local_index];
      }
      save_var_addr[cur_local_index] = tmp / N;
    } else {
      for (size_t i = 0; i < N; i++) {
        tmp += save_mean_addr[i * C + cur_local_index];
      }
      save_mean_addr[cur_local_index] = tmp / N;
    }
  }
  return;
}

void ComputeMean(const size_t N, const size_t C,
                 float *save_mean_addr, float *save_var_addr,
                 cudaStream_t cuda_stream) {
  size_t thread_num = C * 2;
  ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
    thread_num, N, C, save_mean_addr, save_var_addr);
}
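Note on the two kernels above: this commit implements InstanceNorm by reusing cuDNN batch normalization on a (1, N*C, H, W) view of the input (see the SetTensorDescriptor methods further down). CopyMemDevice2Device therefore tiles each C-length parameter vector N times into an N*C workspace buffer, and ComputeMean averages N*C per-instance values back to C per-channel values. A minimal host-side sketch of the same index mapping follows; the function name is illustrative and not part of the commit.

#include <cstddef>

// Host-side sketch (illustrative, not part of the commit) of the index mapping
// CopyMemKernel performs on the GPU: a C-length parameter vector is tiled N
// times into an N * C workspace buffer, so every instance (n, c) gets its own
// copy of the per-channel parameter.
void TileChannelParamSketch(std::size_t N, std::size_t C, const float *gamma, float *ws_gamma) {
  for (std::size_t dst = 0; dst < N * C; ++dst) {
    ws_gamma[dst] = gamma[dst % C];  // the channel index repeats every C slots
  }
}

The other three switch branches in CopyMemKernel apply the same `dst % C` mapping to beta, the running mean, and the running variance.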
@@ -0,0 +1,27 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"
void CopyMemDevice2Device(const size_t N, const size_t C,
                          float *gamma_addr, float *beta_addr, float *running_mean_addr, float *running_variance_addr,
                          float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
                          cudaStream_t cuda_stream);
void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
                 cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
@@ -0,0 +1,44 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(InstanceNorm,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      InstanceNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(InstanceNorm,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      InstanceNormGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,240 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_

#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class InstanceNormGpuKernel : public GpuKernel {
 public:
  InstanceNormGpuKernel()
      : input_x_size_(0),
        input_z_size_(0),
        para_size_(0),
        output_size_(0),
        workspace_size_(0),
        mode_(CUDNN_BATCHNORM_SPATIAL),
        bn_ops_(CUDNN_BATCHNORM_OPS_BN),
        is_training_(true),
        epsilon_(10e-5),
        exp_avg_factor_(0.1),
        is_null_input_(false),
        x_desc_(nullptr),
        y_desc_(nullptr),
        z_desc_(nullptr),
        scale_bias_mean_var_desc_(nullptr),
        handle_(nullptr),
        cudnn_data_type_(CUDNN_DATA_FLOAT) {}
  ~InstanceNormGpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    VARIABLE_NOT_USED(stream_ptr);
    if (is_null_input_) {
      return true;
    }
    auto x_addr = GetDeviceAddress<T>(inputs, 0);
    auto gamma_addr = GetDeviceAddress<float>(inputs, 1);
    auto beta_addr = GetDeviceAddress<float>(inputs, 2);
    auto running_mean_addr = GetDeviceAddress<float>(inputs, 3);
    auto running_variance_addr = GetDeviceAddress<float>(inputs, 4);
    T *z = nullptr;

    auto y_addr = GetDeviceAddress<T>(outputs, 0);
    auto save_mean_addr = GetDeviceAddress<float>(outputs, 1);
    auto save_variance_addr = GetDeviceAddress<float>(outputs, 2);

    float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
    float *ws_beta = GetDeviceAddress<float>(workspace, 1);
    float *ws_mean = GetDeviceAddress<float>(workspace, 2);
    float *ws_var = GetDeviceAddress<float>(workspace, 3);
    T *workspace_addr = nullptr;
    if (workspace_size_ != 0) {
      workspace_addr = GetDeviceAddress<T>(workspace, 4);
    }

    size_t N = input_shape_[0];
    size_t C = input_shape_[1];
    CopyMemDevice2Device(N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr, ws_gamma, ws_beta,
                         ws_mean, ws_var, reinterpret_cast<cudaStream_t>(stream_ptr));

    const float alpha = 1;
    const float beta = 0;
    float *reserve_addr = nullptr;
    if (is_training_) {
      CHECK_CUDNN_RET_WITH_EXCEPT(
        kernel_node_,
        cudnnBatchNormalizationForwardTrainingEx(
          handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr,
          scale_bias_mean_var_desc_, ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr,
          save_variance_addr, nullptr, workspace_addr, workspace_size_, reserve_addr, 0),
        "Kernel launch failed");
    } else {
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                  cudnnBatchNormalizationForwardInference(
                                    handle_, mode_, &alpha, &beta, x_desc_, x_addr, y_desc_, y_addr,
                                    scale_bias_mean_var_desc_, ws_gamma, ws_beta, ws_mean, ws_var, epsilon_),
                                  "Kernel launch failed");
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    MS_EXCEPTION_IF_NULL(kernel_node);
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    bn_ops_ = CUDNN_BATCHNORM_OPS_BN;

    InitResource();
    is_training_ = GetAttr<bool>(kernel_node, "is_training");
    mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
    epsilon_ = GetAttr<float>(kernel_node, "epsilon");
    exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");

    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 5) {
      MS_LOG(EXCEPTION) << "Input tensor number is " << input_num << ", but " << kernel_name << " requires 5.";
    }
    input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    if (input_shape_.size() != 4) {
      MS_LOG(EXCEPTION) << "Input tensor rank is " << input_shape_.size()
                        << ", but InstanceNormGpuKernel requires a 4-D tensor.";
    }
    is_null_input_ = CHECK_NULL_INPUT(input_shape_);
    if (is_null_input_) {
      MS_LOG(WARNING) << "InstanceNormGpuKernel input is null";
      InitSizeLists();
      return true;
    }
    SetTensorDescriptor();
    InitSizeLists();
    return true;
  }

  void DestroyResource() noexcept override {
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
                               "Destroy para desc failed");
  }

 protected:
  void InitResource() override {
    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
                                "Create para desc failed");
  }

  void InitSizeLists() override {
    if (!is_null_input_) {
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_x_size_),
                                  "Get input x size failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size_),
                                  "Get para size failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size_),
                                  "Get output size failed");

      CHECK_CUDNN_RET_WITH_EXCEPT(
        kernel_node_,
        cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, z_desc_, y_desc_,
                                                                 scale_bias_mean_var_desc_, nullptr, &workspace_size_),
        "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
    }

    input_size_list_.push_back(input_x_size_);    // input x
    input_size_list_.push_back(input_shape_[1]);  // gamma
    input_size_list_.push_back(input_shape_[1]);  // beta
    input_size_list_.push_back(input_shape_[1]);  // mean
    input_size_list_.push_back(input_shape_[1]);  // variance

    output_size_list_.push_back(output_size_);  // output
    output_size_list_.push_back(para_size_);    // save mean
    output_size_list_.push_back(para_size_);    // save variance

    workspace_size_list_.push_back(para_size_);  // ws gamma
    workspace_size_list_.push_back(para_size_);  // ws beta
    workspace_size_list_.push_back(para_size_);  // ws mean
    workspace_size_list_.push_back(para_size_);  // ws variance
    workspace_size_list_.push_back(workspace_size_);
  }

 private:
  void SetTensorDescriptor() {
    cudnnTensorFormat_t cudnn_format;
    int batch, channel, height, width;
    // Fold the batch dimension into the channel dimension so that cuDNN's per-channel
    // batch norm statistics become per-(sample, channel) statistics.
    batch = 1;
    channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
    height = SizeToInt(input_shape_[2]);
    width = SizeToInt(input_shape_[3]);
    cudnn_format = CUDNN_TENSOR_NCHW;

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set x desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_, cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set y desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
      "Set para desc failed");
  }

  size_t input_x_size_;
  size_t input_z_size_;
  size_t para_size_;
  size_t output_size_;
  size_t workspace_size_;
  cudnnBatchNormMode_t mode_;
  cudnnBatchNormOps_t bn_ops_;
  bool is_training_;
  double epsilon_;
  double exp_avg_factor_;
  bool is_null_input_;
  cudnnTensorDescriptor_t x_desc_;
  cudnnTensorDescriptor_t y_desc_;
  cudnnTensorDescriptor_t z_desc_;
  cudnnTensorDescriptor_t scale_bias_mean_var_desc_;

  cudnnHandle_t handle_;
  cudnnDataType_t cudnn_data_type_;
  std::vector<size_t> input_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
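The key to reusing cuDNN here is SetTensorDescriptor: the batch dimension is folded into the channel dimension, so CUDNN_BATCHNORM_SPATIAL computes one mean/variance per (sample, channel) pair, which is exactly instance normalization. A minimal sketch of that shape folding, assuming an already validated 4-D NCHW shape; the helper name is hypothetical and not part of the commit.

#include <cstddef>
#include <vector>

// Illustrative helper (not part of the commit): fold the batch dimension of an
// NCHW shape into the channel dimension before creating the cuDNN descriptors.
struct FoldedShape {
  int batch;
  int channel;
  int height;
  int width;
};

inline FoldedShape FoldBatchIntoChannel(const std::vector<std::size_t> &shape) {
  // shape is assumed to be {N, C, H, W}; the caller has already checked the rank.
  return {1, static_cast<int>(shape[0] * shape[1]), static_cast<int>(shape[2]), static_cast<int>(shape[3])};
}

With batch = 1 and channel = N*C, the (1, N*C, 1, 1) scale/bias/mean/var descriptor matches the tiled workspace buffers produced by CopyMemDevice2Device.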
@@ -0,0 +1,44 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)   // dy
                        .AddInputAttr(kNumberTypeFloat32)   // x
                        .AddInputAttr(kNumberTypeFloat32)   // scale
                        .AddInputAttr(kNumberTypeFloat32)   // save_mean
                        .AddInputAttr(kNumberTypeFloat32)   // save_variance
                        .AddOutputAttr(kNumberTypeFloat32)  // dx
                        .AddOutputAttr(kNumberTypeFloat32)  // dscale
                        .AddOutputAttr(kNumberTypeFloat32), // dbias
                      InstanceNormGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)   // dy
                        .AddInputAttr(kNumberTypeFloat16)   // x
                        .AddInputAttr(kNumberTypeFloat32)   // scale
                        .AddInputAttr(kNumberTypeFloat32)   // save_mean
                        .AddInputAttr(kNumberTypeFloat32)   // save_variance
                        .AddOutputAttr(kNumberTypeFloat16)  // dx
                        .AddOutputAttr(kNumberTypeFloat32)  // dscale
                        .AddOutputAttr(kNumberTypeFloat32), // dbias
                      InstanceNormGradGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,238 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_

#include <string>
#include <vector>
#include "utils/utils.h"

#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class InstanceNormGradGpuKernel : public GpuKernel {
 public:
  InstanceNormGradGpuKernel()
      : x_size_(0),
        para_size_(0),
        workspace_size_(0),
        mode_(CUDNN_BATCHNORM_SPATIAL),
        bn_ops_(CUDNN_BATCHNORM_OPS_BN),
        epsilon_(10e-5),
        is_training_(true),
        is_null_input_(false),
        x_desc_(nullptr),
        y_desc_(nullptr),
        dy_desc_(nullptr),
        dx_desc_(nullptr),
        dz_desc_(nullptr),
        scale_bias_diff_desc_(nullptr),
        activation_desc_(nullptr),
        handle_(nullptr),
        cudnn_data_type_(CUDNN_DATA_FLOAT),
        beta_data_diff_(0) {}
  ~InstanceNormGradGpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    VARIABLE_NOT_USED(stream_ptr);
    if (is_null_input_) {
      return true;
    }
    auto dy = GetDeviceAddress<T>(inputs, 0);
    auto x = GetDeviceAddress<T>(inputs, 1);
    auto gamma = GetDeviceAddress<float>(inputs, 2);
    auto save_mean = GetDeviceAddress<float>(inputs, 3);
    auto save_variance = GetDeviceAddress<float>(inputs, 4);
    void *beta = nullptr;
    T *y = nullptr;

    auto dx = GetDeviceAddress<T>(outputs, 0);
    auto dgamma = GetDeviceAddress<float>(outputs, 1);
    auto dbeta = GetDeviceAddress<float>(outputs, 2);
    T *dz = nullptr;

    float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
    void *workspace_addr = nullptr;
    if (workspace_size_ != 0) {
      workspace_addr = GetDeviceAddress<T>(workspace, 3);
    }

    size_t N = input_shape_[0];
    size_t C = input_shape_[1];
    CopyMemDevice2Device(N, C, gamma, nullptr, nullptr, nullptr, ws_gamma, nullptr, nullptr, nullptr,
                         reinterpret_cast<cudaStream_t>(stream_ptr));
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaStreamSynchronize failed");

    const float alpha_data_diff = 1;
    const float alpha_param_diff = 1;
    const float beta_param_diff = 0;
    float *reserve_addr = nullptr;
    if (is_training_) {
      CHECK_CUDNN_RET_WITH_EXCEPT(
        kernel_node_,
        cudnnBatchNormalizationBackwardEx(
          handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
          y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
          epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
        "Kernel launch failed");
      ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      MS_LOG(EXCEPTION) << "The backward of the InstanceNorm operator in evaluation mode is not implemented yet.";
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    MS_EXCEPTION_IF_NULL(kernel_node);
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    bn_ops_ = CUDNN_BATCHNORM_OPS_BN;

    InitResource();
    is_training_ = GetAttr<bool>(kernel_node, "is_training");
    mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
    epsilon_ = GetAttr<float>(kernel_node, "epsilon");

    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 5) {
      MS_LOG(EXCEPTION) << "Input tensor number is " << input_num << ", but " << kernel_name << " requires 5.";
    }

    input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    if (input_shape_.size() != 4) {
      MS_LOG(EXCEPTION) << "Input tensor rank is " << input_shape_.size()
                        << ", but InstanceNormGradGpuKernel requires a 4-D tensor.";
    }
    is_null_input_ = CHECK_NULL_INPUT(input_shape_);
    if (is_null_input_) {
      MS_LOG(WARNING) << "InstanceNormGradGpuKernel input is null";
      InitSizeLists();
      return true;
    }
    beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
    SetTensorDescriptor();
    InitSizeLists();
    return true;
  }

 protected:
  void InitResource() override {
    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_diff_desc_),
                                "Create para desc failed");
  }

  void InitSizeLists() override {
    if (!is_null_input_) {
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size_), "Get x size failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_diff_desc_, &para_size_),
                                  "Get para size failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                  cudnnGetBatchNormalizationBackwardExWorkspaceSize(
                                    handle_, mode_, bn_ops_, x_desc_, y_desc_, dy_desc_, dz_desc_, dx_desc_,
                                    scale_bias_diff_desc_, activation_desc_, &workspace_size_),
                                  "cudnnGetBatchNormalizationBackwardExWorkspaceSize failed");
    }

    input_size_list_.push_back(x_size_);          // dy
    input_size_list_.push_back(x_size_);          // x
    input_size_list_.push_back(input_shape_[1]);  // scale
    input_size_list_.push_back(para_size_);       // save mean
    input_size_list_.push_back(para_size_);       // save variance

    output_size_list_.push_back(x_size_);     // dx
    output_size_list_.push_back(para_size_);  // dscale
    output_size_list_.push_back(para_size_);  // dbias

    workspace_size_list_.push_back(para_size_);  // ws gamma
    workspace_size_list_.push_back(workspace_size_);
  }
  void DestroyResource() noexcept override {
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_diff_desc_),
                               "Destroy para desc failed");
  }

 private:
  void SetTensorDescriptor() {
    int batch, channel, height, width;
    // Fold the batch dimension into the channel dimension, matching the forward kernel.
    batch = 1;
    channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
    height = SizeToInt(input_shape_[2]);
    width = SizeToInt(input_shape_[3]);
    cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set x desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set dy desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set dx desc failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
      "Set para desc failed");
  }

  size_t x_size_;
  size_t para_size_;
  size_t workspace_size_;
  cudnnBatchNormMode_t mode_;
  cudnnBatchNormOps_t bn_ops_;
  double epsilon_;
  bool is_training_;
  bool is_null_input_;

  cudnnTensorDescriptor_t x_desc_;
  cudnnTensorDescriptor_t y_desc_;
  cudnnTensorDescriptor_t dy_desc_;
  cudnnTensorDescriptor_t dx_desc_;
  cudnnTensorDescriptor_t dz_desc_;
  cudnnTensorDescriptor_t scale_bias_diff_desc_;
  cudnnActivationDescriptor_t activation_desc_;

  cudnnHandle_t handle_;
  cudnnDataType_t cudnn_data_type_;
  float beta_data_diff_;
  std::vector<size_t> input_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
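One detail of the backward pass worth spelling out: cudnnBatchNormalizationBackwardEx runs with channel = N*C, so dgamma and dbeta come back with one entry per (sample, channel) pair, and the trailing ComputeMean call reduces them over the batch dimension so the per-channel gradients end up in the first C entries. A host-side sketch of that reduction follows, assuming the same in-place layout; the function name is illustrative and not part of the commit.

#include <cstddef>

// Host-side sketch (illustrative, not part of the commit) of the reduction
// ComputeMean applies to dgamma/dbeta after cudnnBatchNormalizationBackwardEx:
// the per-(n, c) parameter gradients are averaged over the batch dimension in
// place, leaving the per-channel result in the first C entries.
void ReduceParamGradSketch(std::size_t N, std::size_t C, float *dgamma) {
  for (std::size_t c = 0; c < C; ++c) {
    float acc = 0.0f;
    for (std::size_t n = 0; n < N; ++n) {
      acc += dgamma[n * C + c];
    }
    dgamma[c] = acc / N;
  }
}

Averaging (rather than summing) the per-instance parameter gradients over the batch is the choice made by ComputeMeanKernel in this commit.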
@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import functional as F
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = GradOperation(get_all=True, sens_param=True)
        self.network = network

    @ms_function
    def construct(self, input_x, grad):
        return self.grad(self.network)(input_x, grad)


class Net(nn.Cell):
    def __init__(self, n):
        super(Net, self).__init__()
        self.ops = nn.BatchNorm2d(n, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)

    def construct(self, x):
        # Reference implementation: reshape (N, C, H, W) to (1, N*C, H, W) and apply
        # BatchNorm2d, which is numerically equivalent to InstanceNorm2d.
        shape = F.shape(x)
        return F.reshape(self.ops(F.reshape(x, (1, -1, shape[2], shape[3]))), shape)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_InstanceNorm2d_fp32():
    x_np = np.random.randn(3, 3, 2, 2).astype(np.float32)
    bn_instance_comp = Net(3 * 3)
    bn_instance_op = nn.InstanceNorm2d(3, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
    comp_out = bn_instance_comp(Tensor(x_np))
    op_out = bn_instance_op(Tensor(x_np))
    assert np.allclose(comp_out.asnumpy(), op_out.asnumpy())

    sens = np.random.randn(3, 3, 2, 2).astype(np.float32)
    bn_comp_backward_net = Grad(bn_instance_comp)
    bn_op_backward_net = Grad(bn_instance_op)
    output1 = bn_comp_backward_net(Tensor(x_np), Tensor(sens))
    output2 = bn_op_backward_net(Tensor(x_np), Tensor(sens))
    assert np.allclose(output1[0].asnumpy(), output2[0].asnumpy())