|
|
|
@ -11,17 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include <gflags/gflags.h>
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "gflags/gflags.h"
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_transform.h"
|
|
|
|
|
#include "paddle/fluid/framework/executor.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_proto_maker.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include "paddle/fluid/framework/shape_inference.h"
|
|
|
|
|
#include "paddle/fluid/framework/var_type.h"
|
|
|
|
|
#include "paddle/fluid/platform/profiler.h"
|
|
|
|
@ -139,48 +137,19 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
|
|
|
|
|
try {
|
|
|
|
|
if (VLOG_IS_ON(4)) {
|
|
|
|
|
VLOG(4) << place << " " << DebugStringEx(&scope);
|
|
|
|
|
}
|
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
|
VLOG(4) << place << " " << DebugStringEx(&scope);
|
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
|
#ifndef PADDLE_WITH_CUDA
|
|
|
|
|
PADDLE_THROW("Cannot run operator on place %s", place);
|
|
|
|
|
PADDLE_THROW("Cannot run operator on place %s", place);
|
|
|
|
|
#else
|
|
|
|
|
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
|
|
|
|
|
platform::SetDeviceId(dev_id);
|
|
|
|
|
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
|
|
|
|
|
platform::SetDeviceId(dev_id);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
platform::RecordEvent record_event(Type(), pool.Get(place));
|
|
|
|
|
RunImpl(scope, place);
|
|
|
|
|
if (VLOG_IS_ON(3)) {
|
|
|
|
|
VLOG(3) << place << " " << DebugStringEx(&scope);
|
|
|
|
|
}
|
|
|
|
|
} catch (platform::EnforceNotMet exception) {
|
|
|
|
|
if (Attrs().count("sub_block") != 0) {
|
|
|
|
|
throw exception;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto& callstack = Attr<std::vector<std::string>>(
|
|
|
|
|
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
|
|
|
|
|
|
|
|
|
|
if (callstack.empty()) {
|
|
|
|
|
throw exception;
|
|
|
|
|
}
|
|
|
|
|
std::ostringstream sout;
|
|
|
|
|
sout << "Invoke operator " << Type() << " error.\n";
|
|
|
|
|
sout << "Python Callstacks: \n";
|
|
|
|
|
for (auto& line : callstack) {
|
|
|
|
|
sout << line;
|
|
|
|
|
}
|
|
|
|
|
sout << "C++ Callstacks: \n";
|
|
|
|
|
sout << exception.err_str_;
|
|
|
|
|
exception.err_str_ = sout.str();
|
|
|
|
|
throw exception;
|
|
|
|
|
} catch (...) {
|
|
|
|
|
std::rethrow_exception(std::current_exception());
|
|
|
|
|
}
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
platform::RecordEvent record_event(Type(), pool.Get(place));
|
|
|
|
|
RunImpl(scope, place);
|
|
|
|
|
VLOG(3) << place << " " << DebugStringEx(&scope);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OperatorBase::HasInputs(const std::string& name) const {
|
|
|
|
@ -208,7 +177,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OperatorBase::HasOutputs(const std::string& name) const {
|
|
|
|
|
if (outputs_.end() != outputs_.find(name)) {
|
|
|
|
|
if (outputs_.find(name) != outputs_.end()) {
|
|
|
|
|
return true;
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|