#!/usr/bin/env python
"""Sanity-check that TensorFlow can run a small matmul on the local device.

The script probes for an NVIDIA GPU with ``nvidia-smi``, pins a 2x3 by 3x2
matrix multiplication to the matching TensorFlow device, and fails loudly if
the computed product is wrong.
"""
import logging
import subprocess
import sys
from subprocess import PIPE

import numpy as np
import tensorflow as tf
from packaging.specifiers import SpecifierSet
from packaging.version import Version

LOGGER = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)


def main() -> None:
    """Run the TensorFlow sanity check; any failure propagates as an exception."""
    check_tensorflow()


def _gpu_available() -> bool:
    """Return True when ``nvidia-smi`` can be launched and exits with status 0."""
    try:
        # stdout is captured (and discarded) so the nvidia-smi table does not
        # end up interleaved with this script's own log output.
        completed = subprocess.run(["nvidia-smi"], stdout=PIPE, check=False)
    except OSError:
        # The binary is missing or cannot be executed: treat as "no GPU".
        return False
    return completed.returncode == 0


def check_tensorflow() -> None:
    """Multiply two small constant matrices on the detected device.

    The device tag is ``GPU`` when ``nvidia-smi`` succeeds and ``CPU``
    otherwise, prefixed with ``XLA_`` when the installed TensorFlow version
    matches the specifier ``>2,<2.4``.

    Raises:
        AssertionError: if the computed product differs from the expected one.
        Exception: any error raised while building or running the graph is
            logged and re-raised unchanged.
    """
    # Known product of the 2x3 and 3x2 matrices built below.
    expected_product = [[22.0, 28.0], [49.0, 64.0]]
    try:
        is_gpu = _gpu_available()

        # Creates a session with log_device_placement set to False.
        session_config = tf.compat.v1.ConfigProto(log_device_placement=False)
        with tf.compat.v1.Session(config=session_config) as sess:
            tf_version = Version(tf.__version__)
            xla_prefix = "XLA_" if tf_version in SpecifierSet(">2,<2.4") else ""
            device_tag = xla_prefix + ("GPU" if is_gpu else "CPU")

            with tf.device(f"/job:localhost/replica:0/task:0/device:{device_tag}:0"):
                x_mat = tf.constant(
                    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='x_mat'
                )
                y_mat = tf.constant(
                    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='y_mat'
                )
                z_mat = tf.matmul(x_mat, y_mat)

            result = sess.run(z_mat)
            # Raise explicitly instead of using ``assert`` so the check is not
            # stripped when Python runs with -O, and compare against the real
            # product rather than accepting any non-zero output.
            if not np.allclose(result, expected_product):
                raise AssertionError(
                    f"Error: {device_tag} sanity test for TensorFlow failed."
                )
    except Exception as excp:
        LOGGER.error("Error: check_tensorflow test failed.")
        LOGGER.error("Exception: %s", excp)
        raise


if __name__ == "__main__":
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        pass