Pre Merge pull request !14748 from yepei6/master_check

pull/14748/MERGE
yepei6 4 years ago committed by Gitee
commit 31346f305a

@ -0,0 +1,102 @@
"""
mindspore.check
The goal is to provide a convenient api to check if the installation is successful or failed.
"""
from abc import abstractmethod
from importlib import import_module
import numpy as np
try:
mp = import_module("mindspore")
logger = import_module("mindspore.log")
except ModuleNotFoundError:
mp = None
logger = None
class Checker(mp.nn.Cell):
"""basic class for check"""
@abstractmethod
def check_mul(self):
pass
@abstractmethod
def check_install(self):
pass
class CPUChecker(Checker):
"""cpu environment check"""
def __init__(self):
super(CPUChecker, self).__init__()
self.mul = mp.ops.operations.Mul()
def construct(self, x, y):
return self.mul(x, y)
def check_mul(self):
"""Define the cpu mul method."""
input_x = mp.Tensor(np.array([1.0, 2.0, 3.0]), mp.float32)
input_y = mp.Tensor(np.array([4.0, 5.0, 6.0]), mp.float32)
mul = CPUChecker()
print(mul(input_x, input_y))
class GPUChecker(Checker):
"""gpu environment check"""
def __init__(self):
super(GPUChecker, self).__init__()
def check_mul(self):
"""Define the gpu mul method."""
input_x = mp.Tensor(np.array([1.0, 2.0, 3.0]), mp.float32)
input_y = mp.Tensor(np.array([4.0, 5.0, 6.0]), mp.float32)
mul = mp.ops.Mul()
output = mul(input_x, input_y)
print(output)
class AscendChecker(Checker):
"""gpu environment check"""
def __init__(self):
super(AscendChecker, self).__init__()
def check_mul(self):
"""Define the gpu mul method."""
input_x = mp.Tensor(np.array([1.0, 2.0, 3.0]), mp.float32)
input_y = mp.Tensor(np.array([4.0, 5.0, 6.0]), mp.float32)
mul = mp.ops.Mul()
output = mul(input_x, input_y)
print(output)
def check_install():
"""Define the check install method."""
logger.info("mindspore version:", mp.__version__)
def check():
try:
context = import_module("mindspore.context")
except ModuleNotFoundError:
context = None
device_target = input("please input device_target['Ascend', 'GPU', 'CPU']):")
valid_targets = ["CPU", "GPU", "Ascend"]
if not device_target in valid_targets:
raise ValueError(f"Target device name {device_target} is invalid! It must be one of {valid_targets}")
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
if context.get_context("device_target") == "GPU":
checker = GPUChecker()
elif context.get_context("device_target") == "Ascend":
checker = AscendChecker()
elif context.get_context("device_target") == "CPU":
checker = CPUChecker()
else:
logger.warning(f"Package version {device_target} does not need to check any environment variable, skipping.")
return
try:
check_install()
logger.info(f"mindspore mul operate {device_target} result:")
checker.check_mul()
except ValueError:
logger.warning(f"Target device name {device_target} is invalid! It must be one of ['CPU', 'GPU', 'Ascend']")
Loading…
Cancel
Save