|  |  |  | @ -14,8 +14,9 @@ limitations under the License. */ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #pragma once | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include <sstream> | 
			
		
	
		
			
				
					|  |  |  |  | #include <string> | 
			
		
	
		
			
				
					|  |  |  |  | #include <vector> | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/fluid/framework/data_type.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/fluid/framework/op_registry.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/fluid/operators/math/math_function.h" | 
			
		
	
	
		
			
				
					|  |  |  | @ -75,13 +76,28 @@ class FillConstantKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |   void Compute(const paddle::framework::ExecutionContext &ctx) const override { | 
			
		
	
		
			
				
					|  |  |  |  |     auto data_type = | 
			
		
	
		
			
				
					|  |  |  |  |         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); | 
			
		
	
		
			
				
					|  |  |  |  |     auto value = ctx.Attr<float>("value"); | 
			
		
	
		
			
				
					|  |  |  |  |     auto str_value = ctx.Attr<std::string>("str_value"); | 
			
		
	
		
			
				
					|  |  |  |  |     auto float_value = ctx.Attr<float>("value"); | 
			
		
	
		
			
				
					|  |  |  |  |     auto force_cpu = ctx.Attr<bool>("force_cpu"); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     framework::Tensor *tensor = nullptr; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     framework::Variable *out_var = ctx.OutputVar("Out"); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     T value; | 
			
		
	
		
			
				
					|  |  |  |  |     if (str_value.empty()) { | 
			
		
	
		
			
				
					|  |  |  |  |       value = static_cast<T>(float_value); | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       std::stringstream convert_stream(str_value); | 
			
		
	
		
			
				
					|  |  |  |  |       if (std::is_same<int64_t, T>::value) { | 
			
		
	
		
			
				
					|  |  |  |  |         int64_t tmp_value; | 
			
		
	
		
			
				
					|  |  |  |  |         convert_stream >> tmp_value; | 
			
		
	
		
			
				
					|  |  |  |  |         value = static_cast<T>(tmp_value); | 
			
		
	
		
			
				
					|  |  |  |  |       } else { | 
			
		
	
		
			
				
					|  |  |  |  |         double tmp_value; | 
			
		
	
		
			
				
					|  |  |  |  |         convert_stream >> tmp_value; | 
			
		
	
		
			
				
					|  |  |  |  |         value = static_cast<T>(tmp_value); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     auto shape = GetShape(ctx); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     if (out_var->IsType<framework::LoDTensor>()) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -96,15 +112,23 @@ class FillConstantKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |           "supports SelectedRows and LoDTensor"); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     if (force_cpu) { | 
			
		
	
		
			
				
					|  |  |  |  |     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); | 
			
		
	
		
			
				
					|  |  |  |  |     auto &dev_ctx = *pool.Get(ctx.GetPlace()); | 
			
		
	
		
			
				
					|  |  |  |  |     bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); | 
			
		
	
		
			
				
					|  |  |  |  |     if (cpu_place) { | 
			
		
	
		
			
				
					|  |  |  |  |       tensor->mutable_data(platform::CPUPlace(), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       math::SetConstant<platform::CPUDeviceContext, T> functor; | 
			
		
	
		
			
				
					|  |  |  |  |       functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), | 
			
		
	
		
			
				
					|  |  |  |  |               tensor, static_cast<T>(value)); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | #ifdef PADDLE_WITH_CUDA | 
			
		
	
		
			
				
					|  |  |  |  |     if (!cpu_place) { | 
			
		
	
		
			
				
					|  |  |  |  |       tensor->mutable_data(ctx.GetPlace(), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |       math::SetConstant<platform::CUDADeviceContext, T> functor; | 
			
		
	
		
			
				
					|  |  |  |  |       functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx), | 
			
		
	
		
			
				
					|  |  |  |  |               tensor, static_cast<T>(value)); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); | 
			
		
	
		
			
				
					|  |  |  |  |     auto &dev_ctx = *pool.Get(ctx.GetPlace()); | 
			
		
	
		
			
				
					|  |  |  |  |     math::set_constant(dev_ctx, tensor, value); | 
			
		
	
		
			
				
					|  |  |  |  | #endif | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | }; | 
			
		
	
		
			
				
					|  |  |  |  | }  // namespace operators
 | 
			
		
	
	
		
			
				
					|  |  |  | 
 |