|
|
|
@ -288,14 +288,11 @@ DDim::DDim(std::initializer_list<int64_t> init_list) {
|
|
|
|
|
// will be the product of tensor's first `num_col_dims` dimensions
|
|
|
|
|
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
|
|
|
|
|
int rank = src.size();
|
|
|
|
|
return make_ddim(
|
|
|
|
|
{static_cast<int>(product(slice_ddim(src, 0, num_col_dims))),
|
|
|
|
|
static_cast<int>(product(slice_ddim(src, num_col_dims, rank)))});
|
|
|
|
|
return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
|
|
|
|
|
product(slice_ddim(src, num_col_dims, rank))});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim flatten_to_1d(const DDim& src) {
|
|
|
|
|
return make_ddim({static_cast<int>(product(src))});
|
|
|
|
|
}
|
|
|
|
|
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|