|
|
|
@ -22,6 +22,14 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
|
|
|
|
|
out << "[";
|
|
|
|
|
for (auto const& tmp : v) out << tmp << ",";
|
|
|
|
|
out << "]";
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using framework::AlgorithmsCache;
|
|
|
|
|
|
|
|
|
|
struct ConvArgs {
|
|
|
|
@ -119,6 +127,11 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
|
auto x_dims = framework::vectorize(args.x->dims());
|
|
|
|
|
auto w_dims = framework::vectorize(args.w->dims());
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
|
|
|
|
|
<< algo_cache_id << ", x_dims:" << x_dims
|
|
|
|
|
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
|
|
|
|
|
<< args.p << ", args.d" << args.d;
|
|
|
|
|
|
|
|
|
|
algo = algo_cache.GetAlgorithm(
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
@ -247,6 +260,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
auto x_dims = framework::vectorize(args.x->dims());
|
|
|
|
|
auto w_dims = framework::vectorize(args.w->dims());
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
|
|
|
|
|
<< algo_cache_id << ", x_dims:" << x_dims
|
|
|
|
|
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
|
|
|
|
|
<< args.p << ", args.d" << args.d;
|
|
|
|
|
|
|
|
|
|
algo = algo_cache.GetAlgorithm(
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
@ -368,6 +386,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
|
auto x_dims = framework::vectorize(args.x->dims());
|
|
|
|
|
auto w_dims = framework::vectorize(args.w->dims());
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
|
|
|
|
|
<< algo_cache_id << ", x_dims:" << x_dims
|
|
|
|
|
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
|
|
|
|
|
<< args.p << ", args.d" << args.d;
|
|
|
|
|
|
|
|
|
|
algo = algo_cache.GetAlgorithm(
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|