@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License . */
# include <gflags/gflags.h>
# include <glog/logging.h>
# include "paddle/fluid/framework/operator.h"
# include <algorithm>
# include <sstream>
# include <string>
# include <vector>
# include "gflags/gflags.h"
# include "glog/logging.h"
# include "paddle/fluid/framework/data_transform.h"
# include "paddle/fluid/framework/executor.h"
# include "paddle/fluid/framework/lod_tensor.h"
# include "paddle/fluid/framework/operator.h"
# include "paddle/fluid/framework/op _proto_make r.h"
# include "paddle/fluid/framework/shape_inference.h"
# include "paddle/fluid/framework/var_type.h"
# include "paddle/fluid/platform/profiler.h"
@ -127,19 +129,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
void OperatorBase : : Run ( const Scope & scope , const platform : : Place & place ) {
VLOG ( 4 ) < < place < < " " < < DebugStringEx ( & scope ) ;
if ( platform : : is_gpu_place ( place ) ) {
try {
if ( VLOG_IS_ON ( 4 ) ) {
VLOG ( 4 ) < < place < < " " < < DebugStringEx ( & scope ) ;
}
if ( platform : : is_gpu_place ( place ) ) {
# ifndef PADDLE_WITH_CUDA
PADDLE_THROW ( " Cannot run operator on place %s " , place ) ;
PADDLE_THROW ( " Cannot run operator on place %s " , place ) ;
# else
auto dev_id = boost : : get < platform : : CUDAPlace > ( place ) . device ;
platform : : SetDeviceId ( dev_id ) ;
auto dev_id = boost : : get < platform : : CUDAPlace > ( place ) . device ;
platform : : SetDeviceId ( dev_id ) ;
# endif
}
platform : : DeviceContextPool & pool = platform : : DeviceContextPool : : Instance ( ) ;
platform : : RecordEvent record_event ( Type ( ) , pool . Get ( place ) ) ;
RunImpl ( scope , place ) ;
if ( VLOG_IS_ON ( 3 ) ) {
VLOG ( 3 ) < < place < < " " < < DebugStringEx ( & scope ) ;
}
} catch ( platform : : EnforceNotMet exception ) {
if ( Attrs ( ) . count ( " sub_block " ) ! = 0 ) {
throw exception ;
}
auto & callstack = Attr < std : : vector < std : : string > > (
OpProtoAndCheckerMaker : : OpCreationCallstackAttrName ( ) ) ;
if ( callstack . empty ( ) ) {
throw exception ;
}
std : : ostringstream sout ;
sout < < " Invoke operator " < < Type ( ) < < " error. \n " ;
sout < < " Python Callstacks: \n " ;
for ( auto & line : callstack ) {
sout < < line ;
}
sout < < " C++ Callstacks: \n " ;
sout < < exception . err_str_ ;
exception . err_str_ = sout . str ( ) ;
throw exception ;
} catch ( . . . ) {
std : : rethrow_exception ( std : : current_exception ( ) ) ;
}
platform : : DeviceContextPool & pool = platform : : DeviceContextPool : : Instance ( ) ;
platform : : RecordEvent record_event ( Type ( ) , pool . Get ( place ) ) ;
RunImpl ( scope , place ) ;
VLOG ( 3 ) < < place < < " " < < DebugStringEx ( & scope ) ;
}
bool OperatorBase : : HasInputs ( const std : : string & name ) const {
@ -167,7 +198,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
}
bool OperatorBase : : HasOutputs ( const std : : string & name ) const {
if ( outputs_ . find( name ) ! = outputs_ . end ( ) ) {
if ( outputs_ . end( ) ! = outputs_ . find ( name ) ) {
return true ;
} else {
return false ;