|
|
|
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
@ -190,7 +189,7 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
|
|
|
|
|
// add or mul.
|
|
|
|
|
namespace scatter {
|
|
|
|
|
|
|
|
|
|
// Returns the zero-based position of the first element of `rows` that equals
// `value`. The lookup is a linear std::find over the vector; when no element
// matches, std::find yields rows.end(), so the result equals rows.size().
// NOTE(review): two signature lines follow (one without and one with
// `static`); this looks like the before/after pair of a diff hunk — confirm
// which one is current before compiling.
size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
|
|
|
|
|
static size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
|
|
|
|
|
// Iterator difference from begin() is the index of the match.
return std::find(rows.begin(), rows.end(), value) - rows.begin();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -206,14 +205,31 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
// Single-input call operator of the CPU MergeAdd functor.
//   context: device context, passed through unchanged.
//   input:   the one SelectedRows to merge.
//   output:  destination SelectedRows, passed through unchanged.
// The closing statements of this body wrap `input` in a one-element pointer
// list and delegate to the vector-taking overload of operator().
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
// NOTE(review): the next four statements (out, input_rows, row_set,
// merge_rows) are not used by the forwarding code below; presumably they are
// the removed side of a diff hunk — confirm against the real file.
framework::SelectedRows& out = *output;
|
|
|
|
|
auto input_rows = input.rows();
|
|
|
|
|
// Collecting the row ids into a std::set de-duplicates and orders them.
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
|
|
|
|
|
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
|
|
|
|
|
// Build the single-element input list expected by the vector overload.
std::vector<const framework::SelectedRows*> inputs;
|
|
|
|
|
inputs.push_back(&input);
|
|
|
|
|
// Invoke the vector overload on this same functor object.
(*this)(context, inputs, output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_width = input.value().dims()[1];
|
|
|
|
|
// Multi-input call operator of the CPU MergeAdd functor.
//   context: device context handed to constant_functor.
//   inputs:  non-empty list of SelectedRows pointers; every entry must have
//            the same value().dims()[1] and the same height() as inputs[0].
//   output:  receives the union of all input row ids and, per row id, the
//            element-wise sum of every input row carrying that id.
// NOTE(review): this span covers two diff hunks; lines between them (such as
// the declaration of constant_functor and the end of the mutable_data call)
// are not visible here.
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const std::vector<const framework::SelectedRows*>& inputs,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
// inputs[0] is dereferenced just below, so an empty list is rejected first.
PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input");
|
|
|
|
|
// The first input supplies the reference width and height for the checks.
auto input_width = inputs[0]->value().dims()[1];
|
|
|
|
|
auto input_height = inputs[0]->height();
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
// A std::set keeps the collected row ids unique and in ascending order.
std::set<int64_t> merged_row_set;
|
|
|
|
|
// Pass 1: validate each input against inputs[0] and gather its row ids.
for (auto* input : inputs) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
|
|
|
|
|
"all input should have same "
|
|
|
|
|
"dimension except for the first one");
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_height, input->height(),
|
|
|
|
|
"all input should have same height");
|
|
|
|
|
merged_row_set.insert(input->rows().begin(), input->rows().end());
|
|
|
|
|
}
|
|
|
|
|
// Flatten the set into a vector; an id's index here is its output row.
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
// NOTE(review): `input` is not declared at this point in this overload;
// presumably this is the removed line that the following one replaces.
out.set_height(input.height());
|
|
|
|
|
out.set_height(input_height);
|
|
|
|
|
// Size the output value as {number of merged rows, input_width}.
out.mutable_value()->mutable_data<T>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
@ -223,12 +239,16 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
// Presumably fills the output with zeros so the `+=` below starts from 0;
// constant_functor is declared outside this view — TODO confirm.
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<T>();
|
|
|
|
|
// NOTE(review): the next five statements use `input` / `input_rows` outside
// the per-input loop and duplicate its body; presumably they are the removed
// single-input version of the accumulation — confirm against the real file.
auto* input_data = input.value().data<T>();
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = FindPos(merge_rows, input_rows[i]);
|
|
|
|
|
for (int64_t j = 0; j < input_width; j++) {
|
|
|
|
|
out_data[out_i * input_width + j] += input_data[i * input_width + j];
|
|
|
|
|
// Pass 2: accumulate every row of every input into its merged output row.
for (auto* input : inputs) {
|
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
// Map this input row's id to its row index in the merged output.
size_t out_i = FindPos(merge_rows, input_rows[i]);
|
|
|
|
|
for (int64_t j = 0; j < input_width; j++) {
|
|
|
|
|
// Both buffers are indexed as row * input_width + column.
out_data[out_i * input_width + j] += input_data[i * input_width + j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|