[ROCM] fix conv2d and conv3d op, test=develop (#31553)

Qi Li authored 4 years ago, committed by GitHub
parent f302bb4f8b
commit 3d5aa9d10a

File diff suppressed because it is too large.

@@ -127,17 +127,29 @@ struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
bool deterministic, size_t workspace_size,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
algo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
int find_count;
miopenConvAlgoPerf_t find_result;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionForwardAlgorithm(
args.handle, args.idesc.desc(), args.x->data<T>(),
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
args.odesc.desc(), const_cast<T*>(args.o->data<T>()),
kNUM_CUDNN_FWD_ALGS, &find_count, &find_result,
cudnn_workspace_ptr, workspace_size, false));
};
if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
} else {
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetForward());
@@ -152,32 +164,15 @@ struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionForwardAlgorithm(
args.handle, args.idesc.desc(), args.x->data<T>(),
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
args.odesc.desc(), const_cast<T*>(args.o->data<T>()),
kNUM_CUDNN_FWD_ALGS, &returned_algo_count, perf_stat.data(),
cudnn_workspace_ptr, workspace_size_limit, false));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "FwdAlgo Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = perf_stat[i];
VLOG(3) << stat.fwd_algo;
}
return perf_stat[0].fwd_algo;
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.fwd_algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
static size_t GetWorkspaceSize(const ConvArgs& args) {
size_t workspace_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
@@ -194,17 +189,29 @@ struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
bool deterministic, size_t workspace_size,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
bool has_got_workspace_size = true;
algo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
int find_count;
miopenConvAlgoPerf_t find_result;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionBackwardDataAlgorithm(
args.handle, args.odesc.desc(), args.o->data<T>(),
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
args.idesc.desc(), const_cast<T*>(args.x->data<T>()),
kNUM_CUDNN_BWD_DATA_ALGS, &find_count, &find_result,
cudnn_workspace_ptr, workspace_size, false));
};
if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_data_algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardData());
@@ -218,34 +225,15 @@ struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionBackwardDataAlgorithm(
args.handle, args.odesc.desc(), args.o->data<T>(),
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
args.idesc.desc(), const_cast<T*>(args.x->data<T>()),
kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
perf_stat.data(), cudnn_workspace_ptr, workspace_size_limit,
false));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = perf_stat[i];
VLOG(3) << stat.bwd_data_algo;
}
return perf_stat[0].bwd_data_algo;
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.bwd_data_algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
static size_t GetWorkspaceSize(const ConvArgs& args) {
size_t workspace_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardDataGetWorkSpaceSize(
@@ -262,16 +250,29 @@ struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
bool deterministic, size_t workspace_size,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
bool has_got_workspace_size = true;
algo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
int find_count;
miopenConvAlgoPerf_t find_result;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionBackwardWeightsAlgorithm(
args.handle, args.odesc.desc(), args.o->data<T>(),
args.idesc.desc(), args.x->data<T>(), args.cdesc.desc(),
args.wdesc.desc(), const_cast<T*>(args.w->data<T>()),
kNUM_CUDNN_BWD_FILTER_ALGS, &find_count, &find_result,
cudnn_workspace_ptr, workspace_size, false));
};
if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_weights_algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardFilter());
@@ -285,33 +286,15 @@ struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::
miopenFindConvolutionBackwardWeightsAlgorithm(
args.handle, args.odesc.desc(), args.o->data<T>(),
args.idesc.desc(), args.x->data<T>(), args.cdesc.desc(),
args.wdesc.desc(), const_cast<T*>(args.w->data<T>()),
kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
perf_stat.data(), cudnn_workspace_ptr,
workspace_size_limit, false));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = perf_stat[i];
VLOG(3) << stat.bwd_weights_algo;
}
return perf_stat[0].bwd_weights_algo;
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.bwd_weights_algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
static size_t GetWorkspaceSize(const ConvArgs& args) {
size_t workspace_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
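
For orientation, the hunks above change the MIOpen search helpers so that Find() receives the workspace size as an input and GetWorkspaceSize() no longer depends on a chosen algorithm; callers therefore query the workspace size first and then run the search. The sketch below is a minimal, hypothetical illustration of that call ordering only; ConvArgsStub, GetWorkspaceSizeStub, and FindStub are stand-ins, not Paddle or MIOpen APIs.

// Hypothetical stand-ins that mirror the new calling convention: on ROCm the
// workspace size is computed before Find(), because MIOpen's Find API actually
// runs the convolution and needs a real workspace buffer of that size.
#include <algorithm>
#include <cstddef>
#include <iostream>

struct ConvArgsStub {};  // placeholder for the ConvArgs problem description

// After the patch, workspace size depends only on the problem, not the algo.
size_t GetWorkspaceSizeStub(const ConvArgsStub&) { return size_t{4} << 20; }

// After the patch, Find() takes the workspace size as an extra argument.
int FindStub(const ConvArgsStub&, bool /*exhaustive_search*/,
             bool /*deterministic*/, size_t workspace_size) {
  std::cout << "search uses a " << workspace_size << "-byte workspace\n";
  return 0;  // pretend algorithm handle
}

int main() {
  ConvArgsStub args;
  size_t workspace_size = 0;
  // ROCm call ordering after this patch: size first, then search.
  workspace_size = std::max(workspace_size, GetWorkspaceSizeStub(args));
  int algo = FindStub(args, /*exhaustive_search=*/false,
                      /*deterministic=*/false, workspace_size);
  (void)algo;
  return 0;
}

On the cuDNN path the previous ordering (Find first, then GetWorkspaceSize(args, algo)) is kept, which is why the call sites in the next file gain separate #ifdef PADDLE_WITH_HIP branches.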

@@ -244,13 +244,14 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
using search = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args));
algo = search::Find<T>(args, false, deterministic, workspace_size, ctx);
#else
using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
#endif
algo = search::Find<T>(args, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search::GetWorkspaceSize(args, algo));
#endif
// ------------------- cudnn conv transpose forward ---------------------
int input_offset =
@@ -504,12 +505,16 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
platform::AllowTF32Cudnn(), c_groups);
#ifdef PADDLE_WITH_HIP
using search1 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1));
data_algo =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
#else
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
#endif
data_algo = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
#endif
}
if (filter_grad) {
@@ -522,12 +527,16 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
platform::AllowTF32Cudnn(), c_groups);
#ifdef PADDLE_WITH_HIP
using search2 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
filter_algo =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
#else
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
#endif
filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, filter_algo));
#endif
}
// ------------------- cudnn conv backward data ---------------------
@@ -942,11 +951,14 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
#ifdef PADDLE_WITH_HIP
using search1 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
bwd_algo1 =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
#else
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
#endif
bwd_algo1 = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1);
#endif
}
if (ddW) {
@@ -958,12 +970,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
#ifdef PADDLE_WITH_HIP
using search2 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
bwd_algo2 =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
#else
using search2 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
#endif
bwd_algo2 = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, bwd_algo2));
#endif
}
}
@@ -978,12 +994,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
args3.cdesc.set(dtype, padding_common, strides, dilations, c_group);
#ifdef PADDLE_WITH_HIP
using search3 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size =
std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_algo =
search3::Find<T>(args3, false, deterministic, workspace_size, ctx);
#else
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
#endif
filter_algo = search3::Find<T>(args3, false, deterministic, ctx);
workspace_size = std::max(workspace_size,
search3::GetWorkspaceSize(args3, filter_algo));
#endif
}
if (ddW && dX) {
@@ -996,12 +1016,16 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
args4.cdesc.set(dtype, padding_common, strides, dilations, c_group);
#ifdef PADDLE_WITH_HIP
using search4 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_algo =
search4::Find<T>(args4, false, deterministic, workspace_size, ctx);
#else
using search4 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
#endif
data_algo = search4::Find<T>(args4, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
#endif
}
int i_n, i_c, i_d, i_h, i_w;

@@ -199,19 +199,24 @@ class FilterDescriptor {
void set(const Tensor& tensor, const miopenTensorFormat_t format,
const int groups = 1) {
auto dims = framework::vectorize<int>(tensor.dims());
std::vector<int> transformed_dims;
PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW,
platform::errors::InvalidArgument(
"format should ONLY be NCHW in MIOPEN."));
transformed_dims = dims;
// if (groups > 1) {
// transformed_dims[1] = transformed_dims[1] / groups;
// }
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet4dTensorDescriptor(
(miopenTensorDescriptor_t)desc_.get(), ToCudnnDataType(tensor.type()),
transformed_dims[0], transformed_dims[1], transformed_dims[2],
transformed_dims[3]));
auto dims = framework::vectorize<int>(tensor.dims());
std::vector<int> strides(dims.size());
strides[dims.size() - 1] = 1;
for (int i = dims.size() - 2; i >= 0; i--) {
strides[i] = dims[i + 1] * strides[i + 1];
}
std::vector<int> dims_with_group(dims.begin(), dims.end());
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
(miopenTensorDescriptor_t)(desc_.get()), ToCudnnDataType(tensor.type()),
static_cast<int>(dims_with_group.size()),
const_cast<int*>(dims_with_group.data()),
const_cast<int*>(strides.data())));
}
private:
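
As a worked example of the descriptor logic above (illustrative only; the shape [64, 32, 7, 7] and groups = 2 are assumed, not taken from the patch): the strides are computed from the full dims, while only the channel count is divided by the group count before being handed to miopenSetTensorDescriptor.

// Standalone illustration of the stride / dims_with_group computation used by
// FilterDescriptor::set above, with assumed example dimensions.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims = {64, 32, 7, 7};  // K, C, R, S (NCHW filter)
  const int groups = 2;

  // Packed NCHW strides, computed from the full (undivided) dims.
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = 1;
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = dims[i + 1] * strides[i + 1];
  }

  // Only the channel dimension is divided by the group count.
  std::vector<int> dims_with_group(dims);
  if (groups > 1) {
    dims_with_group[1] /= groups;
  }

  // Prints: dims_with_group = 64 16 7 7, strides = 1568 49 7 1
  std::printf("dims_with_group =");
  for (int d : dims_with_group) std::printf(" %d", d);
  std::printf(", strides =");
  for (int s : strides) std::printf(" %d", s);
  std::printf("\n");
  return 0;
}

As in the descriptor code, the strides describe the layout of the full filter tensor, so MIOpen sees the per-group channel count together with full-tensor strides.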

@@ -128,6 +128,8 @@ def create_test_cudnn_class(parent):
class TestCUDNNCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name
@@ -185,6 +187,8 @@ def create_test_cudnn_channel_last_class(parent):
class TestCudnnChannelLastCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
def init_data_format(self):
self.data_format = "NHWC"
@@ -264,6 +268,8 @@ def create_test_cudnn_padding_SAME_class(parent):
class TestCUDNNPaddingSMAECase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
def init_paddings(self):
self.pad = [1, 1]
@@ -280,6 +286,8 @@ def create_test_cudnn_padding_VALID_class(parent):
class TestCUDNNPaddingVALIDCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
def init_paddings(self):
self.pad = [1, 1]
@@ -299,8 +307,7 @@ class TestConv2DOp(OpTest):
self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
self.data_format = "AnyLayout"
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.dtype = np.float64
self.init_kernel_type()
self.init_group()
self.init_dilation()
@@ -693,6 +700,7 @@ class TestCUDNNExhaustiveSearch(TestConv2DOp):
def init_kernel_type(self):
self.use_cudnn = True
self.exhaustive_search = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
class TestConv2DOpError(unittest.TestCase):
@@ -734,8 +742,7 @@ class TestConv2DOp_v2(OpTest):
self.use_cuda = False
self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.dtype = np.float64
self.init_kernel_type()
self.init_group()
self.init_dilation()

@@ -135,6 +135,8 @@ def create_test_cudnn_class(parent):
class TestCUDNNCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name
@@ -169,6 +171,8 @@ def create_test_cudnn_padding_SAME_class(parent):
class TestCUDNNPaddingSMAECase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
def init_paddings(self):
self.pad = [1, 1, 1]
@@ -185,6 +189,8 @@ def create_test_cudnn_padding_VALID_class(parent):
class TestCUDNNPaddingVALIDCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
def init_paddings(self):
self.pad = [1, 1, 1]
@@ -215,6 +221,8 @@ def create_test_cudnn_channel_last_class(parent):
class TestCudnnChannelLastCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm(
) else np.float64
def init_data_format(self):
self.data_format = "NDHWC"
@@ -410,6 +418,7 @@ class TestWithDilation(TestConv3DOp):
class TestCUDNN(TestConv3DOp):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -431,6 +440,7 @@ class TestFP16CUDNN(TestConv3DOp):
class TestWithGroup1CUDNN(TestWithGroup1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -452,6 +462,7 @@ class TestFP16WithGroup1CUDNN(TestWithGroup1):
class TestWithGroup2CUDNN(TestWithGroup2):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -473,6 +484,7 @@ class TestFP16WithGroup2CUDNN(TestWithGroup2):
class TestWith1x1CUDNN(TestWith1x1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -494,6 +506,7 @@ class TestFP16With1x1CUDNN(TestWith1x1):
class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -514,6 +527,7 @@ class TestCUDNNExhaustiveSearch(TestCUDNN):
def init_kernel_type(self):
self.use_cudnn = True
self.exhaustive_search = True
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
# ---- test asymmetric padding ----

@@ -50,7 +50,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
def setUp(self):
"""Setup."""
#self.dtype = np.float32
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.N = 8
self.C = 16
self.H = 32
@@ -92,6 +92,9 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward)
if core.is_compiled_with_rocm():
bn = fluid.layers.cast(bn, 'float32')
else:
bn = fluid.layers.cast(bn, 'float64')
sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid)
