|
|
|
@ -17,7 +17,7 @@ import argparse
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
|
|
|
|
|
|
|
|
|
from eval import BuildEvalNetwork
|
|
|
|
|
from src.nets import net_factory
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='checkpoint export')
|
|
|
|
@ -43,6 +43,7 @@ if __name__ == '__main__':
|
|
|
|
|
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
|
|
|
|
|
else:
|
|
|
|
|
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
|
|
|
|
|
network = BuildEvalNetwork(network)
|
|
|
|
|
param_dict = load_checkpoint(args.ckpt_file)
|
|
|
|
|
|
|
|
|
|
# load the parameter into net
|
|
|
|
|