"""Loader that imports TensorFlow models (frozen graphs or SavedModels) into TVM Relay IR."""
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import tvm
from tvm import relay
from tvm.error import OpError
from tvm.relay.frontend.tensorflow_parser import TFParser

from .abstract_model_loader import AbstractModelLoader
from .convert_layout_mixin import ConvertLayoutMixin
from .helpers.tf_model_helper import TFModelHelper
from ._base import GraphIR

logger = logging.getLogger(__name__)


class TensorflowModelLoader(AbstractModelLoader, ConvertLayoutMixin):
    """Load a TensorFlow model (a single .pb/.pbtxt file or a SavedModel
    directory) and convert it to a Relay module plus parameters.

    The loader first tries the TF 1.x path (``TFParser``); if that fails it
    falls back to the TF 2.x helper routines.
    """

    def __init__(self, model_artifacts: List[str], data_shape: Dict[str, List[int]]) -> None:
        """
        Args:
            model_artifacts: Candidate model files/directories (treated as
                ``Path``-like objects by the lookup helpers below).
            data_shape: Mapping of input tensor name to its shape.
        """
        super().__init__(model_artifacts, data_shape)
        self.__model_path: Optional[Path] = None            # resolved model file or SavedModel dir
        self.__output_tensor_names: Optional[List[str]] = None
        self.__tf_graph = None                              # parsed GraphDef, set by __extract_tf_graph
        self.__tf_model_helper: Optional[TFModelHelper] = None
        self.__is_tf2_model = False                         # flipped when only the v2 extraction succeeds
        self.__model_tf_version: Optional[str] = None

    @property
    def ir_format(self) -> GraphIR:
        """IR produced by this loader (always Relay)."""
        return GraphIR.relay

    @property
    def model_objects(self) -> Tuple[tvm.IRModule, tvm.nd.NDArray]:
        """The converted Relay module and its parameters (valid after load_model)."""
        return self._relay_module_object, self._params

    @property
    def aux_files(self) -> List[Path]:
        """TensorFlow models need no auxiliary files."""
        return []

    def __get_model_dir_from_model_artifacts(self) -> Optional[Path]:
        """Return the single SavedModel directory among the artifacts, if any.

        A SavedModel directory is recognized by its ``variables`` subdirectory.

        Raises:
            RuntimeError: if a TF checkpoint directory is found (unsupported)
                or more than one SavedModel directory is present.
        """
        model_dirs = []
        for path in self.model_artifacts:
            if path.is_dir():
                if path.joinpath('variables').exists():
                    model_dirs.append(path)
                if path.joinpath('checkpoint').exists():
                    raise RuntimeError('InputConfiguration: TF Checkpoints are not supported. '
                                       'Please make sure the framework you select is correct.')
        if len(model_dirs) > 1:
            raise RuntimeError('InputConfiguration: Exactly one saved model is allowed for TensorFlow models.')
        if model_dirs:
            return model_dirs[0]
        return None

    def __get_model_file_from_model_artifacts(self) -> Optional[Path]:
        """Return the single .pb/.pbtxt file among the artifacts, if any.

        Raises:
            RuntimeError: if more than one candidate file is present.
        """
        model_files = self._get_files_from_model_artifacts_with_extensions(["pb", "pbtxt"])
        if len(model_files) > 1:
            raise RuntimeError('InputConfiguration: Exactly one .pb or .pbtxt file is allowed for TensorFlow models.')
        if model_files:
            return model_files[0]
        return None

    def __extract_model_path_from_model_artifacts(self) -> None:
        """Resolve ``__model_path``, preferring a SavedModel dir over a graph file.

        Raises:
            RuntimeError: if neither a SavedModel directory nor a graph file
                is found among the artifacts.
        """
        model_file = self.__get_model_file_from_model_artifacts()
        model_dir = self.__get_model_dir_from_model_artifacts()
        if model_dir:
            # SavedModel directory takes precedence over a loose graph file.
            self.__model_path = model_dir
        elif model_file:
            self.__model_path = model_file
        else:
            raise RuntimeError('InputConfiguration: No valid TensorFlow model found in input files. '
                               'Please make sure the framework you select is correct.')

    def __extract_metadata_and_output_tensor_names_from_model(self) -> None:
        """Extract model metadata and output tensor names via the TF helper.

        Tries the TF 1.x extraction first; on failure retries with the v2
        routine and marks the model as TF2.

        Raises:
            RuntimeError: if both extraction paths fail, or if reading the
                output names / metadata fails afterwards.
        """
        try:
            self.__tf_model_helper.extract_input_and_output_tensors()
        except Exception as e:
            logger.warning("Try to extract input and output tensor for potential TF2 model.")
            try:
                self.__tf_model_helper.extract_input_and_output_tensors_v2()
                self.__is_tf2_model = True
            except Exception as error:
                logger.exception("Framework cannot load model. {}".format(error))
                raise RuntimeError("InputConfiguration: Framework cannot load Tensorflow model: {}".format(e)) from error
        try:
            # Drop the output index suffix ("node:0" -> "node"). The previous
            # rstrip(":0") stripped *characters*, corrupting names that end in
            # '0' (e.g. "conv10:0" -> "conv1").
            self.__output_tensor_names = [name.split(":")[0]
                                          for name in self.__tf_model_helper.output_tensor_names]
            self._metadata = self.__tf_model_helper.get_metadata()
        except Exception as e:
            logger.exception("Framework cannot load model.")
            raise RuntimeError("InputConfiguration: Framework cannot load Tensorflow model: {}".format(e)) from e

    def __extract_tf_graph(self):
        """Parse the TF graph into ``__tf_graph``.

        Uses the v2 loader directly when the model was already identified as
        TF2; otherwise tries ``TFParser`` (TF 1.x) and falls back to the v2
        loader on failure.

        Raises:
            RuntimeError: if no loader can parse the model.
        """
        if self.__is_tf2_model:
            try:
                logger.info("Loading TF model for potential TF 2.x model.")
                self.__tf_graph = self.__tf_model_helper.get_tf_graph_from_graph_model_v2()
            except Exception as e:
                logger.exception("Failed to load TF model. %s", repr(e))
                raise RuntimeError("InputConfiguration: Framework cannot load Tensorflow model: {}".format(e)) from e
        else:
            try:
                logger.info("Loading TF model from TFParser.")
                self.__tf_graph = TFParser(self.__model_path.as_posix(), self.__output_tensor_names).parse()
                self.__model_tf_version = "1.x"
            except Exception as e:
                # Temp workaround for TF2 models, remove the logic when TF2 is introduced
                try:
                    logger.warning("Failed to load TF model from TFParser, will try to load with compat.v2. %s",
                                   repr(e))
                    self.__tf_graph = self.__tf_model_helper.get_tf_graph_from_graph_model_v2()
                except Exception as error:
                    # Log the exception from the fallback path (the original
                    # code logged the outer TFParser exception here by mistake).
                    logger.exception("Failed to load TF model. %s", repr(error))
                    raise RuntimeError(
                        "InputConfiguration: Framework cannot load Tensorflow model: {}".format(e)) from error

    def load_model(self) -> None:
        """Resolve the model path, parse the TF graph, and convert it to Relay.

        Populates ``_relay_module_object``, ``_params`` and ``_metadata``.

        Raises:
            OpError: re-raised unchanged from the Relay frontend.
            RuntimeError: for any other conversion failure, with a version
                hint appended when a TF2 model was selected as TF1.
        """
        self.__extract_model_path_from_model_artifacts()
        self.__tf_model_helper = TFModelHelper(self.__model_path.as_posix(), self.data_shape)
        self.__extract_metadata_and_output_tensor_names_from_model()
        self.__extract_tf_graph()
        self.__model_tf_version = self.__tf_model_helper.get_tensorflow_version()
        logger.info("Model Version tf-{}".format(self.__model_tf_version))
        try:
            self._relay_module_object, self._params = relay.frontend.from_tensorflow(
                self.__tf_graph,
                shape=self.data_shape,
                outputs=self.__output_tensor_names
            )
            self._relay_module_object = self.convert_layout(self._relay_module_object)
            self.update_missing_metadata()
        except OpError:
            # Operator-level errors carry precise diagnostics; let them surface as-is.
            raise
        except Exception as e:
            logger.exception("Failed to convert tensorflow model. %s", repr(e))
            msg = ("InputConfiguration: TVM cannot convert Tensorflow model. "
                   "Please make sure the framework you selected is correct. {}".format(e))
            msg += self.model_version_hint_message()
            raise RuntimeError(msg) from e

    def model_version_hint_message(self) -> str:
        """Return a hint when a TF 2.x model was loaded under the TF 1.x selection.

        Returns an empty string when no version mismatch is detected.
        """
        if self.__model_tf_version:
            model_version = self.__model_tf_version
            if model_version.startswith("2."):
                # Typo fix: "founded" -> "found".
                return "\nTensorflow version selected: {}. Model version found: {}.".format("1.x", model_version)
        return ""