|
|
|
@ -171,14 +171,17 @@ void SliceOneClass(const platform::DeviceContext& ctx,
|
|
|
|
|
const T* items_data = items.data<T>();
|
|
|
|
|
const int64_t num_item = items.dims()[0];
|
|
|
|
|
const int class_num = items.dims()[1];
|
|
|
|
|
int item_size = 1;
|
|
|
|
|
if (items.dims().size() == 3) {
|
|
|
|
|
item_size = items.dims()[2];
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < num_item; ++i) {
|
|
|
|
|
std::memcpy(item_data + i * item_size,
|
|
|
|
|
items_data + i * class_num * item_size + class_id * item_size,
|
|
|
|
|
sizeof(T) * item_size);
|
|
|
|
|
int item_size = items.dims()[2];
|
|
|
|
|
for (int i = 0; i < num_item; ++i) {
|
|
|
|
|
std::memcpy(item_data + i * item_size,
|
|
|
|
|
items_data + i * class_num * item_size + class_id * item_size,
|
|
|
|
|
sizeof(T) * item_size);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < num_item; ++i) {
|
|
|
|
|
item_data[i] = items_data[i * class_num + class_id];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|