|
|
|
@ -116,7 +116,7 @@ int arity(const DDim& ddim);
|
|
|
|
|
std::ostream& operator<<(std::ostream&, const DDim&);
|
|
|
|
|
|
|
|
|
|
// Reshape a tensor to a matrix. The matrix's first dimension(column length)
|
|
|
|
|
// will be the product of tensor's first `num_col_dims` dimensions
|
|
|
|
|
// will be the product of tensor's first `num_col_dims` dimensions.
|
|
|
|
|
DDim flatten_to_2d(const DDim& src, int num_col_dims);
|
|
|
|
|
|
|
|
|
|
DDim flatten_to_1d(const DDim& src);
|
|
|
|
|