Source code for chop.tools.check_dependency
import logging
from importlib.util import find_spec
from chop.passes.utils import PassFactory
# Module-level logger named after this module; the dependency checks in this
# file report missing or available packages through it.
logger = logging.getLogger(name=__name__)
[docs]
def check_deps_tensorRT_pass(silent: bool = True):
    """Report whether every package required by the TensorRT pass is findable.

    Each package name is looked up with ``importlib.util.find_spec``; it is
    considered available when a module spec is returned for it.

    Args:
        silent: When False, log a warning naming the missing packages, or an
            info message when nothing is missing.

    Returns:
        True if every required package was found, False otherwise.
    """
    required = ("pytorch_quantization", "tensorrt", "pynvml", "pycuda", "cuda")
    missing = [name for name in required if find_spec(name) is None]
    everything_found = not missing

    if silent:
        return everything_found

    if everything_found:
        logger.info("Extension: All dependencies for TensorRT pass are available.")
    else:
        logger.warning(
            f"TensorRT pass is unavailable because the following dependencies are not installed: {', '.join(missing)}."
        )
    return everything_found
[docs]
def find_missing_dependencies(
    pass_name: str,
):
    """Return the registered dependencies of ``pass_name`` that cannot be found.

    Dependencies are read from ``PassFactory._dependencies_dict``; a pass with
    no entry there is treated as having no dependencies. Each dependency is
    looked up with ``importlib.util.find_spec``.

    Args:
        pass_name: Key to look up in ``PassFactory._dependencies_dict``.

    Returns:
        A list of dependency names for which no module spec could be found,
        in registration order. Empty when nothing is missing or when the pass
        has no registered dependencies.
    """
    dependencies = PassFactory._dependencies_dict.get(pass_name, None)
    if dependencies is None:
        return []

    unavailable_deps = []
    for dep in dependencies:
        try:
            found = find_spec(dep) is not None
        except ImportError:
            # For a dotted name such as "pkg.sub", find_spec first imports
            # "pkg" and raises ModuleNotFoundError (an ImportError subclass)
            # when that parent is not installed. That means the dependency is
            # missing; it should be reported, not crash the check.
            found = False
        if not found:
            unavailable_deps.append(dep)
    return unavailable_deps
[docs]
def check_dependencies(
    pass_name: str,
    silent: bool = True,
):
    """Return whether every registered dependency of ``pass_name`` is installed.

    The set of missing packages is obtained from ``find_missing_dependencies``.

    Args:
        pass_name: Name of the pass whose dependencies are checked.
        silent: When False, log a warning naming the missing packages, or an
            info message when nothing is missing.

    Returns:
        True when no dependency of the pass is missing, False otherwise.
    """
    missing = find_missing_dependencies(pass_name)
    all_present = not missing

    if not silent:
        if all_present:
            logger.info(
                f"Extension: All dependencies for the {pass_name} pass are available."
            )
        else:
            logger.warning(
                f"Pass: {pass_name} is unavailable because the following dependencies are not installed: {', '.join(missing)}."
            )
    return all_present