Pre Merge pull request !14748 from yepei6/master_check
commit
31346f305a
@ -0,0 +1,102 @@
|
||||
"""
|
||||
mindspore.check
|
||||
|
||||
The goal is to provide a convenient api to check if the installation is successful or failed.
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from importlib import import_module
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
mp = import_module("mindspore")
|
||||
logger = import_module("mindspore.log")
|
||||
except ModuleNotFoundError:
|
||||
mp = None
|
||||
logger = None
|
||||
|
||||
class Checker(mp.nn.Cell):
|
||||
"""basic class for check"""
|
||||
@abstractmethod
|
||||
def check_mul(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def check_install(self):
|
||||
pass
|
||||
|
||||
class CPUChecker(Checker):
|
||||
"""cpu environment check"""
|
||||
|
||||
def __init__(self):
|
||||
super(CPUChecker, self).__init__()
|
||||
self.mul = mp.ops.operations.Mul()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.mul(x, y)
|
||||
|
||||
def check_mul(self):
|
||||
"""Define the cpu mul method."""
|
||||
input_x = mp.Tensor(np.array([1.0, 2.0, 3.0]), mp.float32)
|
||||
input_y = mp.Tensor(np.array([4.0, 5.0, 6.0]), mp.float32)
|
||||
mul = CPUChecker()
|
||||
print(mul(input_x, input_y))
|
||||
|
||||
|
||||
class GPUChecker(Checker):
|
||||
"""gpu environment check"""
|
||||
def __init__(self):
|
||||
super(GPUChecker, self).__init__()
|
||||
|
||||
def check_mul(self):
|
||||
"""Define the gpu mul method."""
|
||||
input_x = mp.Tensor(np.array([1.0, 2.0, 3.0]), mp.float32)
|
||||
input_y = mp.Tensor(np.array([4.0, 5.0, 6.0]), mp.float32)
|
||||
mul = mp.ops.Mul()
|
||||
output = mul(input_x, input_y)
|
||||
print(output)
|
||||
|
||||
class AscendChecker(Checker):
|
||||
"""gpu environment check"""
|
||||
def __init__(self):
|
||||
super(AscendChecker, self).__init__()
|
||||
|
||||
def check_mul(self):
|
||||
"""Define the gpu mul method."""
|
||||
input_x = mp.Tensor(np.array([1.0, 2.0, 3.0]), mp.float32)
|
||||
input_y = mp.Tensor(np.array([4.0, 5.0, 6.0]), mp.float32)
|
||||
mul = mp.ops.Mul()
|
||||
output = mul(input_x, input_y)
|
||||
print(output)
|
||||
|
||||
def check_install():
|
||||
"""Define the check install method."""
|
||||
logger.info("mindspore version:", mp.__version__)
|
||||
|
||||
def check():
|
||||
try:
|
||||
context = import_module("mindspore.context")
|
||||
except ModuleNotFoundError:
|
||||
context = None
|
||||
|
||||
device_target = input("please input device_target['Ascend', 'GPU', 'CPU']):")
|
||||
valid_targets = ["CPU", "GPU", "Ascend"]
|
||||
if not device_target in valid_targets:
|
||||
raise ValueError(f"Target device name {device_target} is invalid! It must be one of {valid_targets}")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
||||
if context.get_context("device_target") == "GPU":
|
||||
checker = GPUChecker()
|
||||
elif context.get_context("device_target") == "Ascend":
|
||||
checker = AscendChecker()
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
checker = CPUChecker()
|
||||
else:
|
||||
logger.warning(f"Package version {device_target} does not need to check any environment variable, skipping.")
|
||||
return
|
||||
|
||||
try:
|
||||
check_install()
|
||||
logger.info(f"mindspore mul operate {device_target} result:")
|
||||
checker.check_mul()
|
||||
except ValueError:
|
||||
logger.warning(f"Target device name {device_target} is invalid! It must be one of ['CPU', 'GPU', 'Ascend']")
|
Loading…
Reference in new issue