|
|
|
@ -62,7 +62,8 @@ struct DDimPlusVisitor {
|
|
|
|
|
|
|
|
|
|
DDim DDim::operator+(const DDim& d) const {
|
|
|
|
|
PADDLE_ENFORCE(rank_ == d.rank_);
|
|
|
|
|
DDim ret(rank_);
|
|
|
|
|
DDim ret;
|
|
|
|
|
ret.rank_ = rank_;
|
|
|
|
|
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
@ -82,7 +83,8 @@ struct DDimMulVisitor {
|
|
|
|
|
|
|
|
|
|
DDim DDim::operator*(const DDim& d) const {
|
|
|
|
|
PADDLE_ENFORCE(rank_ == d.rank_);
|
|
|
|
|
DDim ret(rank_);
|
|
|
|
|
DDim ret;
|
|
|
|
|
ret.rank_ = rank_;
|
|
|
|
|
ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
@ -121,7 +123,9 @@ int64_t product(const DDim& ddim) {
|
|
|
|
|
DDim slice_ddim(const DDim& dim, int begin, int end) {
|
|
|
|
|
PADDLE_ENFORCE(begin >= 0,
|
|
|
|
|
"Begin index can't be less than zero in ddim slice.");
|
|
|
|
|
DDim ret(end - begin);
|
|
|
|
|
int len = end - begin;
|
|
|
|
|
DDim ret;
|
|
|
|
|
ret.rank_ = len;
|
|
|
|
|
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
@ -156,7 +160,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
|
|
|
|
|
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
|
|
|
|
|
|
|
|
|
|
DDim stride(const DDim& ddim) {
|
|
|
|
|
DDim strides(ddim.size());
|
|
|
|
|
DDim strides;
|
|
|
|
|
strides.rank_ = ddim.size();
|
|
|
|
|
strides[ddim.size() - 1] = 1;
|
|
|
|
|
for (int i = ddim.size() - 2; i >= 0; --i) {
|
|
|
|
|
strides[i] = strides[i + 1] * ddim[i + 1];
|
|
|
|
@ -165,7 +170,8 @@ DDim stride(const DDim& ddim) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim stride_numel(const DDim& ddim) {
|
|
|
|
|
DDim strides(ddim.size());
|
|
|
|
|
DDim strides;
|
|
|
|
|
strides.rank_ = ddim.size();
|
|
|
|
|
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
|
|
|
|
|
for (int i = ddim.size() - 2; i >= 0; --i) {
|
|
|
|
|
strides[i] = strides[i + 1] * ddim[i];
|
|
|
|
|