Add IPT Ascend

* Merge branch 'IPT' of gitee.com:xiaoan95/mindspore into ipt_ascend
* Add IPT Ascend
pull/13726/head
An Xiao 4 years ago
parent 9359983123
commit 44942be2df

@ -23,7 +23,7 @@ from mindspore import context
import mindspore.dataset as de
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
context.set_context(mode=context.GRAPH_MODE, device_target="ASCEND", device_id=0)
def main():
@ -46,11 +46,12 @@ def main():
net_m.set_train(False)
num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1))
inference = ipt.IPT_post(net_m, args)
for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR']
hr = imgs['HR']
hr_np = np.float32(hr.asnumpy())
pred = net_m.infrc(lr)
pred = inference.forward(lr)
pred_np = np.float32(pred.asnumpy())
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save