|
|
|
@ -29,20 +29,23 @@ namespace framework {
|
|
|
|
|
return (callback); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define PADDLE_VISIT_DDIM(rank, callback) \
|
|
|
|
|
switch (rank) { \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(0, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(1, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(2, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(3, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(4, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(5, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(6, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(7, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(8, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(9, callback); \
|
|
|
|
|
default: \
|
|
|
|
|
PADDLE_THROW("Invalid rank %d", rank); \
|
|
|
|
|
#define PADDLE_VISIT_DDIM(rank, callback) \
|
|
|
|
|
switch (rank) { \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(0, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(1, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(2, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(3, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(4, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(5, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(6, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(7, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(8, callback); \
|
|
|
|
|
PADDLE_VISIT_DDIM_BASE(9, callback); \
|
|
|
|
|
default: \
|
|
|
|
|
PADDLE_THROW(platform::errors::Unimplemented( \
|
|
|
|
|
"Invalid dimension to be accessed. Now only supports access to " \
|
|
|
|
|
"dimension 0 to 9, but received dimension is %d.", \
|
|
|
|
|
rank)); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
@ -92,13 +95,31 @@ class DDim {
|
|
|
|
|
|
|
|
|
|
inline int64_t operator[](int idx) const { return dim_[idx]; }
|
|
|
|
|
|
|
|
|
|
inline int64_t& at(int idx) {
|
|
|
|
|
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
|
|
|
|
|
int64_t& at(int idx) {
|
|
|
|
|
PADDLE_ENFORCE_GE(idx, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid DDim index to be accessed. The valid index "
|
|
|
|
|
"is between 0 and %d, but received index is %d.",
|
|
|
|
|
rank_, idx));
|
|
|
|
|
PADDLE_ENFORCE_LT(idx, rank_,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid DDim index to be accessed. The valid index "
|
|
|
|
|
"is between 0 and %d, but received index is %d.",
|
|
|
|
|
rank_, idx));
|
|
|
|
|
return dim_[idx];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline int64_t at(int idx) const {
|
|
|
|
|
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
|
|
|
|
|
int64_t at(int idx) const {
|
|
|
|
|
PADDLE_ENFORCE_GE(idx, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid DDim index to be accessed. The valid index "
|
|
|
|
|
"is between 0 and %d, but received index is %d.",
|
|
|
|
|
rank_, idx));
|
|
|
|
|
PADDLE_ENFORCE_LT(idx, rank_,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid DDim index to be accessed. The valid index "
|
|
|
|
|
"is between 0 and %d, but received index is %d.",
|
|
|
|
|
rank_, idx));
|
|
|
|
|
return dim_[idx];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|