# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """Tests for common micro transports.""" import logging import sys import unittest import pytest import tvm.testing # Implementing as a fixture so that the tvm.micro import doesn't occur # until fixture setup time. This is necessary for pytest's collection # phase to work when USE_MICRO=OFF, while still explicitly listing the # tests as skipped. 
@tvm.testing.fixture
def transport():
    """Return a mock Transport whose read/write results are scripted by the test.

    Set ``exc`` to make the next read()/write() raise that exception, or
    ``to_return`` to make it return that value. Each scripted value is consumed
    by exactly one call; a call with nothing scripted fails the test.
    """
    import tvm.micro

    class MockTransport_Impl(tvm.micro.transport.Transport):
        """Transport stub that replays one scripted result per read()/write()."""

        def __init__(self):
            # Exception to raise on the next read()/write(); cleared once raised.
            self.exc = None
            # Value to return from the next read()/write(); cleared once returned.
            self.to_return = None

        def _raise_or_return(self):
            # A pending exception takes precedence over a pending return value.
            if self.exc is not None:
                to_raise, self.exc = self.exc, None
                raise to_raise
            if self.to_return is not None:
                to_return, self.to_return = self.to_return, None
                return to_return
            # Raise explicitly instead of ``assert False`` so this failure is not
            # stripped (silently returning None) when Python runs with -O.
            raise AssertionError("should not get here")

        def open(self):
            pass

        def close(self):
            pass

        def timeouts(self):
            raise NotImplementedError()

        def read(self, n, timeout_sec):
            return self._raise_or_return()

        def write(self, data, timeout_sec):
            return self._raise_or_return()

    return MockTransport_Impl()


@tvm.testing.fixture
def transport_logger(transport):
    """Wrap the mock ``transport`` in a TransportLogger named "foo".

    The logger writes to the "transport_logger_test" logger, which is the one
    ``get_latest_log`` captures at INFO level.
    """
    # Import explicitly instead of relying on the ``transport`` fixture having
    # imported tvm.micro as a side effect.
    import tvm.micro

    logger = logging.getLogger("transport_logger_test")
    return tvm.micro.transport.TransportLogger("foo", transport, logger=logger)


@tvm.testing.fixture
def get_latest_log(caplog):
    """Yield a callable that returns the message of the most recent log record.

    Records are captured at INFO level on "transport_logger_test" for the whole
    test. The callable raises IndexError when no record has been captured.
    """

    def inner():
        return caplog.records[-1].getMessage()

    with caplog.at_level(logging.INFO, "transport_logger_test"):
        yield inner


@tvm.testing.requires_micro
def test_open(transport_logger, get_latest_log):
    """open() is logged."""
    transport_logger.open()
    assert get_latest_log() == "foo: opening transport"


@tvm.testing.requires_micro
def test_close(transport_logger, get_latest_log):
    """close() is logged."""
    transport_logger.close()
    assert get_latest_log() == "foo: closing transport"


# The expected strings below are compared with ``==``, so they must match
# TransportLogger's output exactly, including its column padding.
@tvm.testing.requires_micro
def test_read_normal(transport, transport_logger, get_latest_log):
    """A short successful read is logged as a single hex/ASCII line."""
    transport.to_return = b"data"
    transport_logger.read(23, 3.0)
    assert get_latest_log() == (
        "foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61"
        " data"
    )


@tvm.testing.requires_micro
def test_read_multiline(transport, transport_logger, get_latest_log):
    """A read longer than 16 bytes is logged as an offset-prefixed hex dump."""
    transport.to_return = b"data" * 6
    transport_logger.read(23, 3.0)
    assert get_latest_log() == (
        "foo: read { 3.00s} 23 B -> [ 24 B]:\n"
        "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
        "0010 64 61 74 61 64 61 74 61 datadata"
    )


@tvm.testing.requires_micro
def test_read_no_timeout_prints(transport, transport_logger, get_latest_log):
    """A read with timeout_sec=None logs "None" in the timeout field."""
    transport.to_return = b"data"
    transport_logger.read(15, None)
    assert get_latest_log() == (
        "foo: read { None } 15 B -> [ 4 B]: 64 61 74 61"
        " data"
    )


@tvm.testing.requires_micro
def test_read_io_timeout(transport, transport_logger, get_latest_log):
    """A read that times out is re-raised and logged."""
    # IoTimeoutError includes the timeout value.
    transport.exc = tvm.micro.transport.IoTimeoutError()
    with pytest.raises(tvm.micro.transport.IoTimeoutError):
        transport_logger.read(23, 0.0)

    assert get_latest_log() == "foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]"


@tvm.testing.requires_micro
def test_read_other_exception(transport, transport_logger, get_latest_log):
    """Any other read error is re-raised and logged."""
    # Other exceptions are logged by name.
    transport.exc = tvm.micro.transport.TransportClosedError()
    with pytest.raises(tvm.micro.transport.TransportClosedError):
        transport_logger.read(8, 0.0)

    assert get_latest_log() == "foo: read { 0.00s} 8 B -> [err: TransportClosedError]"


@tvm.testing.requires_micro
def test_read_keyboard_interrupt(transport, transport_logger, get_latest_log):
    """A KeyboardInterrupt during read propagates without being logged."""
    # KeyboardInterrupt produces no log record.
    transport.exc = KeyboardInterrupt()
    with pytest.raises(KeyboardInterrupt):
        transport_logger.read(8, 0.0)

    # Nothing was captured, so indexing the empty record list fails.
    with pytest.raises(IndexError):
        get_latest_log()


@tvm.testing.requires_micro
def test_write_normal(transport, transport_logger, get_latest_log):
    """A short successful write is logged as a single hex/ASCII line."""
    transport.to_return = 3
    transport_logger.write(b"data", 3.0)
    assert get_latest_log() == (
        "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61"
        " data"
    )


@tvm.testing.requires_micro
def test_write_multiline(transport, transport_logger, get_latest_log):
    """A write longer than 16 bytes is logged as an offset-prefixed hex dump."""
    # Normal log, multi-line data written.
    transport.to_return = 20
    transport_logger.write(b"data" * 6, 3.0)
    assert get_latest_log() == (
        "foo: write { 3.00s} <- [ 24 B]:\n"
        "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
        "0010 64 61 74 61 64 61 74 61 datadata"
    )


@tvm.testing.requires_micro
def test_write_no_timeout_prints(transport, transport_logger, get_latest_log):
    """A write with timeout_sec=None logs "None" in the timeout field."""
    transport.to_return = 3
    transport_logger.write(b"data", None)
    assert get_latest_log() == (
        "foo: write { None } <- [ 4 B]: 64 61 74 61"
        " data"
    )


@tvm.testing.requires_micro
def test_write_io_timeout(transport, transport_logger, get_latest_log):
    """A write that times out is re-raised and logged."""
    # IoTimeoutError includes the timeout value.
    transport.exc = tvm.micro.transport.IoTimeoutError()
    with pytest.raises(tvm.micro.transport.IoTimeoutError):
        transport_logger.write(b"data", 0.0)

    assert get_latest_log() == "foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]"


@tvm.testing.requires_micro
def test_write_other_exception(transport, transport_logger, get_latest_log):
    """Any other write error is re-raised and logged."""
    # Other exceptions are logged by name.
    transport.exc = tvm.micro.transport.TransportClosedError()
    with pytest.raises(tvm.micro.transport.TransportClosedError):
        transport_logger.write(b"data", 0.0)

    assert get_latest_log() == "foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]"


@tvm.testing.requires_micro
def test_write_keyboard_interrupt(transport, transport_logger, get_latest_log):
    """A KeyboardInterrupt during write propagates without being logged."""
    # KeyboardInterrupt produces no log record.
    transport.exc = KeyboardInterrupt()
    with pytest.raises(KeyboardInterrupt):
        transport_logger.write(b"data", 0.0)

    # Nothing was captured, so indexing the empty record list fails.
    with pytest.raises(IndexError):
        get_latest_log()


if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))