parent
6f2eba3e7e
commit
bac1426d47
@ -1,5 +1,7 @@
|
||||
#include <paddle/operators/add_op.h>
|
||||
#include <paddle/framework/op_registry.h>
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "paddle/operators/add_op.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
REGISTER_OP_GPU_KERNEL(add_two,
|
||||
paddle::operators::AddKernel<paddle::platform::GPUPlace>);
|
||||
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
|
@ -1,17 +1,26 @@
|
||||
#pragma once
|
||||
#include <glog/logging.h>
|
||||
#include <paddle/framework/operator.h>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/framework/operator.h"
|
||||
//#include "paddle/operators/add_op_functor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename Place>
|
||||
// Place can be CPUPlace or GPUPlace
|
||||
template <typename Place, typename DataType>
|
||||
class AddKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const KernelContext &context) const override {
|
||||
LOG(INFO) << "Add kernel in " << typeid(Place).name();
|
||||
void Compute(const KernelContext& context) const override {
|
||||
auto* input0 = context.Input(0);
|
||||
auto* input1 = context.Input(1);
|
||||
|
||||
auto* output = context.Output(0);
|
||||
output->mutable_data<DataType>(Place());
|
||||
|
||||
output->flat<T>().device(*(context.get_eigen_device<Place>())) =
|
||||
input0->flat<T>() + input1->flat<T>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
Loading…
Reference in new issue