"""Ray ArrowCSVDatasource Module."""
from typing import Any, Iterator

import pyarrow as pa
from pyarrow import csv
from ray.data.block import BlockAccessor

from awswrangler._arrow import _add_table_partitions
from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource


class ArrowCSVDatasource(PandasFileBasedDatasource):  # pylint: disable=abstract-method
    """CSV datasource, for reading and writing CSV files using PyArrow.

    Reading streams each file through ``pyarrow.csv.open_csv`` and yields one
    ``pa.Table`` per record batch; writing serializes a whole block with
    ``pyarrow.csv.write_csv``.
    """

    _FILE_EXTENSION = "csv"

    def _read_stream(  # type: ignore[override]  # pylint: disable=arguments-differ
        self,
        f: pa.NativeFile,
        path: str,
        path_root: str,
        dataset: bool,
        **reader_args: Any,
    ) -> Iterator[pa.Table]:
        """Stream a CSV file as a sequence of Arrow tables, one per record batch.

        Parameters
        ----------
        f : pa.NativeFile
            Open file handle to read the CSV data from.
        path : str
            Path of the file being read. ``"s3://"`` is prepended to it before it
            is handed to ``_add_table_partitions`` (presumably the caller passes
            the path without a scheme -- TODO confirm).
        path_root : str
            Dataset root forwarded to ``_add_table_partitions``.
        dataset : bool
            When True, partition columns are added to every yielded table via
            ``_add_table_partitions``.
        **reader_args : Any
            Optional ``read_options`` (default: ``csv.ReadOptions(use_threads=False)``),
            ``parse_options`` (default: ``csv.ParseOptions()``) and
            ``convert_options`` (default: ``csv.ConvertOptions()``) forwarded to
            ``pyarrow.csv.open_csv``. Other keys are ignored here.

        Yields
        ------
        pa.Table
            One table per record batch produced by the streaming reader. A file
            that yields no batches produces no tables.
        """
        read_options = reader_args.get("read_options", csv.ReadOptions(use_threads=False))
        parse_options = reader_args.get("parse_options", csv.ParseOptions())
        convert_options = reader_args.get("convert_options", csv.ConvertOptions())

        reader = csv.open_csv(
            f,
            read_options=read_options,
            parse_options=parse_options,
            convert_options=convert_options,
        )
        # All batches from the streaming reader carry the reader's schema; pin it
        # explicitly so every yielded table agrees on it.
        schema = reader.schema

        # Iterating the reader stops cleanly once the stream is exhausted. This
        # avoids a broad ``try/except StopIteration`` around the per-batch work,
        # which would also swallow a StopIteration raised while adding partitions
        # (or thrown in at the ``yield``) and silently truncate the output.
        for batch in reader:
            table = pa.Table.from_batches([batch], schema=schema)

            if dataset:
                table = _add_table_partitions(
                    table=table,
                    path=f"s3://{path}",
                    path_root=path_root,
                )

            yield table

    def _write_block(  # type: ignore[override]  # pylint: disable=arguments-differ
        self,
        f: pa.NativeFile,
        block: BlockAccessor,
        **writer_args: Any,
    ) -> None:
        """Write a single block to ``f`` as CSV.

        Parameters
        ----------
        f : pa.NativeFile
            Destination file handle.
        block : BlockAccessor
            Block to serialize; converted with ``block.to_arrow()``.
        **writer_args : Any
            Optional ``write_options``: either a dict of ``csv.WriteOptions``
            keyword arguments, a ready-made ``csv.WriteOptions`` instance, or
            None/absent for pyarrow's defaults. Other keys are ignored here.
        """
        write_options = writer_args.get("write_options")
        if write_options is None:
            write_options = csv.WriteOptions()
        elif not isinstance(write_options, csv.WriteOptions):
            # Backward-compatible path: a mapping of WriteOptions kwargs.
            write_options = csv.WriteOptions(**write_options)

        csv.write_csv(block.to_arrow(), f, write_options=write_options)