""" Copyright OpenSearch Contributors SPDX-License-Identifier: Apache-2.0 """ import mock import pytest from prompt_toolkit.shortcuts import PromptSession from prompt_toolkit.input.defaults import create_pipe_input from src.opensearch_sql_cli.opensearch_buffer import opensearch_is_multiline from .utils import estest, load_data, TEST_INDEX_NAME, ENDPOINT from src.opensearch_sql_cli.opensearchsql_cli import OpenSearchSqlCli from src.opensearch_sql_cli.opensearch_connection import OpenSearchConnection from src.opensearch_sql_cli.opensearch_style import style_factory AUTH = None QUERY_WITH_CTRL_D = "select * from %s;\r\x04\r" % TEST_INDEX_NAME USE_AWS_CREDENTIALS = False QUERY_LANGUAGE = "sql" RESPONSE_TIMEOUT = 10 @pytest.fixture() def cli(default_config_location): return OpenSearchSqlCli(clirc_file=default_config_location, always_use_pager=False) class TestOpenSearchSqlCli: def test_connect(self, cli): with mock.patch.object( OpenSearchConnection, "__init__", return_value=None ) as mock_OpenSearchConnection, mock.patch.object( OpenSearchConnection, "set_connection" ) as mock_set_connectiuon: cli.connect(endpoint=ENDPOINT) mock_OpenSearchConnection.assert_called_with(ENDPOINT, AUTH, USE_AWS_CREDENTIALS, QUERY_LANGUAGE, RESPONSE_TIMEOUT) mock_set_connectiuon.assert_called() @estest @pytest.mark.skip(reason="due to prompt_toolkit throwing error, no way of currently testing this") def test_run_cli(self, connection, cli, capsys): doc = {"a": "aws"} load_data(connection, doc) # the title is colored by formatter expected = ( "fetched rows / total rows = 1/1" "\n+-----+\n| \x1b[38;5;47;01ma\x1b[39;00m |\n|-----|\n| aws |\n+-----+" ) with mock.patch.object(OpenSearchSqlCli, "echo_via_pager") as mock_pager, mock.patch.object( cli, "build_cli" ) as mock_prompt: inp = create_pipe_input() inp.send_text(QUERY_WITH_CTRL_D) mock_prompt.return_value = PromptSession( input=inp, multiline=opensearch_is_multiline(cli), style=style_factory(cli.syntax_style, cli.cli_style) ) cli.connect(ENDPOINT) cli.run_cli() out, err = capsys.readouterr() inp.close() mock_pager.assert_called_with(expected) assert out.__contains__("Endpoint: %s" % ENDPOINT) assert out.__contains__("See you next search!")