|
|
|
@ -96,6 +96,7 @@ class BatchNormGpuKernel : public GpuKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
InitResource();
|
|
|
|
|
is_train_ = GetAttr<bool>(kernel_node, "is_training");
|
|
|
|
|
if (is_train_) {
|
|
|
|
|
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
|
|
|
|
|
} else {
|
|
|
|
@ -133,7 +134,6 @@ class BatchNormGpuKernel : public GpuKernel {
|
|
|
|
|
}
|
|
|
|
|
SetTensorDescriptor(format, shape);
|
|
|
|
|
InitSizeLists();
|
|
|
|
|
is_train_ = GetAttr<bool>(kernel_node, "is_training");
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -229,8 +229,8 @@ class BatchNormGpuKernel : public GpuKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
output_size_list_.push_back(output_size_); // output
|
|
|
|
|
output_size_list_.push_back(reserve_size_); // reserve space
|
|
|
|
|
output_size_list_.push_back(para_size_); // save scale
|
|
|
|
|
output_size_list_.push_back(reserve_size_); // reserve space
|
|
|
|
|
output_size_list_.push_back(para_size_); // save mean
|
|
|
|
|
output_size_list_.push_back(para_size_); // save variance
|
|
|
|
|
|
|
|
|
|