remove conflict

emailweixu-patch-1
chengduoZH 7 years ago
commit 8ea2288e10

@ -16,12 +16,10 @@ function(copy TARGET)
foreach(index RANGE ${len}) foreach(index RANGE ${len})
list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_SRCS ${index} src)
list(GET copy_lib_DSTS ${index} dst) list(GET copy_lib_DSTS ${index} dst)
add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}") add_custom_command(TARGET ${TARGET} PRE_BUILD
if(IS_DIRECTORY ${src}) COMMAND mkdir -p "${dst}"
add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND cp -r "${src}" "${dst}") COMMAND cp -r "${src}" "${dst}"
else() COMMENT "copying ${src} -> ${dst}")
add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND cp "${src}" "${dst}")
endif()
endforeach() endforeach()
endfunction() endfunction()
@ -53,11 +51,11 @@ IF(NOT PROTOBUF_FOUND)
ENDIF(NOT PROTOBUF_FOUND) ENDIF(NOT PROTOBUF_FOUND)
# paddle fluid module # paddle fluid module
set(src_dir "${PADDLE_SOURCE_DIR}/paddle") set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
set(dst_dir "${CMAKE_INSTALL_PREFIX}/paddle") set(dst_dir "${CMAKE_INSTALL_PREFIX}/paddle/fluid")
set(module "framework") set(module "framework")
copy(framework_lib DEPS framework_py_proto copy(framework_lib DEPS framework_py_proto
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/framework/framework.pb.h SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module}
) )
@ -69,7 +67,7 @@ copy(memory_lib
set(module "inference") set(module "inference")
copy(inference_lib DEPENDS paddle_fluid_shared copy(inference_lib DEPENDS paddle_fluid_shared
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/inference/libpaddle_fluid.so SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.so
DSTS ${dst_dir}/${module} ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module}
) )

@ -25,7 +25,10 @@ namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
using OperatorBase::OperatorBase; using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
using OperatorBase::OperatorBase; using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
}; };
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
} }
} }
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
#else
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
#endif
}
RunImpl(scope, place);
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_; const Scope& scope_;
}; };
void OperatorWithKernel::Run(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

@ -89,8 +89,9 @@ class OperatorBase {
std::string DebugString() const { return DebugStringEx(nullptr); } std::string DebugString() const { return DebugStringEx(nullptr); }
/// Net will call this function to Run an op. /// Net will call this interface function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0; // The implementation should be written at RunImpl
void Run(const Scope& scope, const platform::Place& place);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop. // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual void Stop() {} virtual void Stop() {}
@ -144,6 +145,8 @@ class OperatorBase {
private: private:
void GenerateTemporaryNames(); void GenerateTemporaryNames();
void CheckAllInputOutputSet() const; void CheckAllInputOutputSet() const;
virtual void RunImpl(const Scope& scope,
const platform::Place& place) const = 0;
}; };
// Macro for define a clone method. // Macro for define a clone method.
@ -168,10 +171,13 @@ class OperatorBase {
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
using OperatorBase::OperatorBase; using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
std::unique_ptr<OperatorBase> Clone() const override { std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this)); return std::unique_ptr<OperatorBase>(new NOP(*this));
} }
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
}; };
class ExecutionContext { class ExecutionContext {
@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const Scope& scope, const platform::Place& place) const final;
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() { AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels; static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
// same. // same.
proto::DataType IndicateDataType(const ExecutionContext& ctx) const; proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);

@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {} : OperatorBase(type, inputs, outputs, attrs), x(1) {}
void Run(const Scope& scope, const platform::Place& place) const override {
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {
++op_run_num; ++op_run_num;
ASSERT_EQ(static_cast<int>(inputs_.size()), 1); ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
ASSERT_EQ(static_cast<int>(outputs_.size()), 1); ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
const paddle::framework::VariableNameMap& outputs, const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs) const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::Place& place) const override {} private:
void RunImpl(const paddle::framework::Scope& scope,
const paddle::platform::Place& place) const override {}
}; };
TEST(Operator, Clone) { TEST(Operator, Clone) {

@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &rank_table = auto &rank_table =
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>(); scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();

@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
if (x == nullptr) { if (x == nullptr) {
return; return;

@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place); auto& dev_ctx = *pool.Get(dev_place);

@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
PADDLE_THROW("Not Implemented"); PADDLE_THROW("Not Implemented");
} }
void Run(const framework::Scope& scope, private:
const platform::Place& dev_place) const override { void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto ids_var = scope.FindVar(Input("ids")); auto ids_var = scope.FindVar(Input("ids"));
auto scores_var = scope.FindVar(Input("scores")); auto scores_var = scope.FindVar(Input("scores"));
auto pre_ids_var = scope.FindVar(Input("pre_ids")); auto pre_ids_var = scope.FindVar(Input("pre_ids"));

@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel<T> {
auto in_stride = framework::stride_numel(in->dims()); auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride, out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride); in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis]; output_offset += in_stride[axis];
} }
} }
@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto out_stride = framework::stride_numel(out->dims()); auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(), StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset, out_stride, in->data<T>() + input_offset,
in_stride); in_stride, out_stride[axis]);
input_offset += out_stride[axis]; input_offset += out_stride[axis];
} }
} }

@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
} }
} }
void CondOp::Run(const Scope& scope, const platform::Place& place) const { void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const {
// get device context from pool // get device context from pool
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(place); auto& dev_ctx = *pool.Get(place);

@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
sub_net_op_[FALSE_BRANCH] = std::move(net); sub_net_op_[FALSE_BRANCH] = std::move(net);
} }
void Run(const framework::Scope& scope, private:
const platform::Place& place) const override; void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override;
private: private:
const int TRUE_BRANCH = 0; const int TRUE_BRANCH = 0;

@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {} : ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = InputTensors(scope); auto xs = InputTensors(scope);
bool need_run; bool need_run;
@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {} : ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope); auto xs = this->InputTensors(scope);
bool need_run; bool need_run;

@ -106,8 +106,10 @@ template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase { class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat"); const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
@ -155,8 +157,10 @@ class CreateRandomDataGeneratorOpMaker
class CreateShuffleReaderOp : public framework::OperatorBase { class CreateShuffleReaderOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
@ -187,8 +191,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class CreateBatchReaderOp : public framework::OperatorBase { class CreateBatchReaderOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))

@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto feed_var_name = Input("X"); auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name); auto *feed_var = scope.FindVar(feed_var_name);

@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
const platform::Place &place) const override { void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto fetch_var_name = Input("X"); auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name); auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr, PADDLE_ENFORCE(fetch_var != nullptr,

@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
class FillConstantOp : public framework::OperatorBase { class FillConstantOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto data_type = auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype")); static_cast<framework::proto::DataType>(Attr<int>("dtype"));
auto value = Attr<float>("value"); auto value = Attr<float>("value");

@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &out = auto &out =
detail::Ref(detail::Ref(scope.FindVar(Output("Out")), detail::Ref(detail::Ref(scope.FindVar(Output("Out")),
"Cannot find variable %s", Output("Out")) "Cannot find variable %s", Output("Out"))

@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
bool is_gpu; bool is_gpu;
if (Attr<std::string>("device_type") == "AUTO") { if (Attr<std::string>("device_type") == "AUTO") {
is_gpu = platform::is_gpu_place(place); is_gpu = platform::is_gpu_place(place);

@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
const platform::Place &place) const override { void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &out = auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>(); *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();

@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
const platform::Place &place) const override { void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get input // get input
auto *var = scope.FindVar(Input(kInput)); auto *var = scope.FindVar(Input(kInput));
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);

@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); std::ifstream fin(filename);

@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",

@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &out = auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>(); *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();

@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto *out = auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>(); scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save