fix get_mid_dims annotation (#8490)

tonyyang-svail-patch-1
chengduo 8 years ago committed by Abhinav Arora
parent 77ee8fb240
commit 0e187bc93e

@ -35,10 +35,10 @@ namespace operators {
* For example:
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* x.shape(2, 12, 5) * y.shape(1,12,1).broadcast(2,12,5)
* x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* x.shape(2, 3, 20) * y.shape(1,1,20).broadcast(2,3,20)
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*/
inline void get_mid_dims(const framework::DDim& x_dims,
const framework::DDim& y_dims, const int axis,

Loading…
Cancel
Save