|
|
|
@ -27,8 +27,6 @@ class AssignValueKernel : public framework::OpKernel<T> {
|
|
|
|
|
virtual void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
auto shape = ctx.Attr<std::vector<int>>("shape");
|
|
|
|
|
auto* out = ctx.Output<framework::Tensor>("Out");
|
|
|
|
|
out->Resize(framework::make_ddim(shape));
|
|
|
|
|
auto* dst = out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
int dtype = ctx.Attr<int>("dtype");
|
|
|
|
|
const char* value_name = nullptr;
|
|
|
|
|
switch (dtype) {
|
|
|
|
@ -43,12 +41,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto values = ctx.Attr<std::vector<T>>(value_name);
|
|
|
|
|
Copy(dst, values.data(), sizeof(T) * values.size(), ctx);
|
|
|
|
|
framework::CopyFromVector(values, ctx.device_context(), out);
|
|
|
|
|
out->Resize(framework::make_ddim(shape));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual void Copy(void* dst, const void* src, size_t size,
|
|
|
|
|
const framework::ExecutionContext& ctx) const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|