/*
 * Licensed to Elasticsearch B.V. under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch B.V. licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *	http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 *
 * Modifications Copyright OpenSearch Contributors. See
 * GitHub history for details.
 */

//! HTTP transport and connection components

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use crate::auth::ClientCertificate;
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use crate::cert::CertificateValidation;
use crate::{
    auth::Credentials,
    error::Error,
    http::{
        headers::{
            HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE,
            DEFAULT_ACCEPT, DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT,
        },
        request::Body,
        response::Response,
        Method,
    },
};
use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder};
use bytes::BytesMut;
use dyn_clone::clone_trait_object;
use lazy_static::lazy_static;
use serde::Serialize;
use std::{
    error, fmt,
    fmt::Debug,
    io::{self, Write},
    time::Duration,
};
use url::Url;

/// Error that can occur when building a [Transport]
///
/// Values are produced through the `From` conversions implemented for this
/// type, which is what allows `?` to be used inside [TransportBuilder::build].
#[derive(Debug)]
pub enum BuildError {
    /// IO error
    Io(io::Error),

    /// Certificate error
    ///
    /// NOTE: every [reqwest::Error] converts into this variant, so besides
    /// client certificate problems it also carries proxy configuration and
    /// HTTP client construction failures raised while building.
    Cert(reqwest::Error),
}

impl From<io::Error> for BuildError {
    fn from(err: io::Error) -> BuildError {
        BuildError::Io(err)
    }
}

impl From<reqwest::Error> for BuildError {
    fn from(err: reqwest::Error) -> BuildError {
        BuildError::Cert(err)
    }
}

impl error::Error for BuildError {
    /// Returns the description of the wrapped error.
    ///
    /// Retained only for callers that still use this deprecated API;
    /// prefer the `Display` implementation.
    // `description` is deprecated on the wrapped errors too, hence the allow.
    #[allow(deprecated)]
    fn description(&self) -> &str {
        match self {
            BuildError::Io(err) => err.description(),
            BuildError::Cert(err) => err.description(),
        }
    }

    /// Returns the wrapped error as the underlying cause of this error.
    ///
    /// Implemented as `source` rather than the deprecated `cause`, so that
    /// code walking the error chain through `source()` sees the wrapped
    /// error. The default `cause()` delegates to `source()`, so callers of
    /// `cause()` observe the same value as before.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            BuildError::Io(err) => Some(err),
            BuildError::Cert(err) => Some(err),
        }
    }
}

impl fmt::Display for BuildError {
    /// Formats the wrapped error with its own `Display` implementation,
    /// passing the formatter through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner: &dyn fmt::Display = match self {
            BuildError::Io(io_err) => io_err,
            BuildError::Cert(reqwest_err) => reqwest_err,
        };
        // Fully qualified on purpose: `Debug` is also in scope in this file
        // and exposes a method with the same name.
        fmt::Display::fmt(inner, f)
    }
}

/// Default address to OpenSearch running on `http://localhost:9200`
///
/// Used by the [Default] implementation of [SingleNodeConnectionPool].
pub static DEFAULT_ADDRESS: &str = "http://localhost:9200";

lazy_static! {
    /// Client metadata header: service, language, transport, followed by additional information
    ///
    /// The value is produced by `build_meta` the first time it is accessed.
    static ref CLIENT_META: String = build_meta();
}

/// Assembles the client metadata string
/// `es=<version>,rs=<rustc>,t=<version>`, followed by `,tls=n` when the
/// `native-tls` feature is enabled or `,tls=r` when only `rustls-tls` is.
///
/// `<version>` is the first three `.`/`-` separated components of the package
/// version joined with dots, suffixed with `p` when a further component
/// exists (a preliminary version).
fn build_meta() -> String {
    let mut components = env!("CARGO_PKG_VERSION").split(&['.', '-'][..]);

    // assumes the package version always carries major, minor and patch
    let major = components.next().unwrap();
    let minor = components.next().unwrap();
    let patch = components.next().unwrap();

    let mut version = [major, minor, patch].join(".");
    if components.next().is_some() {
        version.push('p');
    }

    let mut meta = format!("es={},rs={},t={}", version, env!("RUSTC_VERSION"), version);

    // native-tls takes precedence when both TLS features are enabled
    match (cfg!(feature = "native-tls"), cfg!(feature = "rustls-tls")) {
        (true, _) => meta.push_str(",tls=n"),
        (false, true) => meta.push_str(",tls=r"),
        (false, false) => {}
    }

    meta
}

/// Builds a HTTP transport to make API calls to OpenSearch
///
/// Settings are collected through the chained builder methods and applied to
/// the underlying [reqwest::ClientBuilder] when [TransportBuilder::build] is
/// called.
pub struct TransportBuilder {
    /// Underlying HTTP client builder, configured and consumed by `build`
    client_builder: reqwest::ClientBuilder,
    /// Pool handed over to the built [Transport]
    conn_pool: Box<dyn ConnectionPool>,
    /// Credentials handed over to the built [Transport]; certificate
    /// credentials are additionally installed on the HTTP client by `build`
    credentials: Option<Credentials>,
    /// Certificate validation to apply; `None` leaves the client builder untouched
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    cert_validation: Option<CertificateValidation>,
    /// Explicit proxy URL, ignored when `disable_proxy` is set
    proxy: Option<Url>,
    /// Credentials for the proxy; only basic credentials are applied by `build`
    proxy_credentials: Option<Credentials>,
    /// When `true`, `build` disables all proxies on the HTTP client
    disable_proxy: bool,
    /// Headers installed as the HTTP client's default headers
    headers: HeaderMap,
    /// Global request timeout, if any
    timeout: Option<Duration>,
    /// AWS service name handed over to the built [Transport]
    #[cfg(feature = "aws-auth")]
    service_name: String,
}

impl TransportBuilder {
    /// Creates a [TransportBuilder] whose [Connection]s to OpenSearch are
    /// supplied by the given [ConnectionPool].
    ///
    /// No credentials, certificate validation, proxy, headers or timeout are
    /// configured initially.
    pub fn new<P>(conn_pool: P) -> Self
    where
        P: ConnectionPool + Debug + Clone + Send + 'static,
    {
        let conn_pool: Box<dyn ConnectionPool> = Box::new(conn_pool);
        let client_builder = reqwest::ClientBuilder::new();

        Self {
            client_builder,
            conn_pool,
            credentials: None,
            #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
            cert_validation: None,
            proxy: None,
            proxy_credentials: None,
            disable_proxy: false,
            headers: HeaderMap::new(),
            timeout: None,
            #[cfg(feature = "aws-auth")]
            service_name: String::from("es"),
        }
    }

    /// Routes requests through the proxy at `url`.
    ///
    /// When a username is given, it is used together with the password (an
    /// empty password if none is given) to set the `Proxy-Authorization`
    /// header using Basic Authentication. Without a username, previously
    /// configured proxy credentials are left as they are.
    pub fn proxy(mut self, url: Url, username: Option<&str>, password: Option<&str>) -> Self {
        self.proxy = Some(url);

        match username {
            Some(user) => {
                let pass = password.unwrap_or("");
                self.proxy_credentials = Some(Credentials::Basic(user.into(), pass.into()));
                self
            }
            None => self,
        }
    }

    /// Disables all proxies, including system proxies.
    ///
    /// NOTE: System proxies are enabled by default.
    pub fn disable_proxy(self) -> Self {
        Self {
            disable_proxy: true,
            ..self
        }
    }

    /// Sets the credentials the client authenticates to OpenSearch with.
    pub fn auth(self, credentials: Credentials) -> Self {
        Self {
            credentials: Some(credentials),
            ..self
        }
    }

    /// Sets the validation applied to the certificate presented when
    /// establishing a HTTPS connection.
    ///
    /// Full validation is applied by default; a different validation can be
    /// chosen, for example when using a self-signed certificate.
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    pub fn cert_validation(self, validation: CertificateValidation) -> Self {
        Self {
            cert_validation: Some(validation),
            ..self
        }
    }

    /// Adds a HTTP header that is sent with every client API call, replacing
    /// any value previously registered under the same name.
    ///
    /// A default HTTP header can be overridden on a per API call basis.
    pub fn header(mut self, key: HeaderName, value: HeaderValue) -> Self {
        let _replaced = self.headers.insert(key, value);
        self
    }

    /// Adds HTTP headers that are sent with every client API call. Each entry
    /// replaces any value previously registered under the same name.
    ///
    /// Default HTTP headers can be overridden on a per API call basis.
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        headers.iter().for_each(|(name, value)| {
            self.headers.insert(name, value.clone());
        });
        self
    }

    /// Sets a global request timeout for the client.
    ///
    /// The timeout is applied from when the request starts connecting until the response body has finished.
    /// Default is no timeout.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets a global AWS service name.
    ///
    /// Default is "es". Other supported services are "aoss" for OpenSearch Serverless.
    #[cfg(feature = "aws-auth")]
    pub fn service_name(mut self, service_name: &str) -> Self {
        self.service_name = service_name.to_owned();
        self
    }

    /// Builds a [Transport] to use to send API calls to OpenSearch.
    ///
    /// # Errors
    ///
    /// Returns a [BuildError] when a client certificate cannot be loaded,
    /// the proxy cannot be configured, or the underlying HTTP client cannot
    /// be constructed.
    pub fn build(self) -> Result<Transport, BuildError> {
        let mut http = self.client_builder;

        if !self.headers.is_empty() {
            http = http.default_headers(self.headers);
        }

        if let Some(limit) = self.timeout {
            http = http.timeout(limit);
        }

        // Certificate credentials are installed on the client as its identity
        #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
        {
            if let Some(Credentials::Certificate(client_cert)) = &self.credentials {
                http = match client_cert {
                    #[cfg(feature = "native-tls")]
                    ClientCertificate::Pkcs12(der, passphrase) => {
                        let secret = passphrase.as_ref().map_or("", |pass| pass.as_str());
                        let identity = reqwest::Identity::from_pkcs12_der(der, secret)?;
                        http.identity(identity)
                    }
                    #[cfg(feature = "rustls-tls")]
                    ClientCertificate::Pem(encoded) => {
                        let identity = reqwest::Identity::from_pem(encoded)?;
                        http.identity(identity)
                    }
                }
            }
        }

        #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
        if let Some(validation) = self.cert_validation {
            match validation {
                CertificateValidation::Default => {}
                CertificateValidation::Full(roots) => {
                    for root in roots {
                        http = http.add_root_certificate(root);
                    }
                }
                #[cfg(feature = "native-tls")]
                CertificateValidation::Certificate(roots) => {
                    for root in roots {
                        http = http.add_root_certificate(root);
                    }
                    http = http.danger_accept_invalid_hostnames(true);
                }
                CertificateValidation::None => {
                    http = http.danger_accept_invalid_certs(true);
                }
            }
        }

        // Disabling proxies takes precedence over an explicitly configured one
        if self.disable_proxy {
            http = http.no_proxy();
        } else if let Some(proxy_url) = self.proxy {
            let unauthenticated = reqwest::Proxy::all(proxy_url)?;
            let proxy = match self.proxy_credentials {
                Some(Credentials::Basic(user, pass)) => unauthenticated.basic_auth(&user, &pass),
                _ => unauthenticated,
            };
            http = http.proxy(proxy);
        }

        let client = http.build()?;
        let transport = Transport {
            client,
            conn_pool: self.conn_pool,
            credentials: self.credentials,
            #[cfg(feature = "aws-auth")]
            service_name: self.service_name,
        };
        Ok(transport)
    }
}

impl Default for TransportBuilder {
    /// Creates a default implementation using the default implementation of [SingleNodeConnectionPool].
    fn default() -> Self {
        Self::new(SingleNodeConnectionPool::default())
    }
}

/// A connection to an OpenSearch node, used to send an API request
#[derive(Debug, Clone)]
pub struct Connection {
    /// Base URL of the node; [Connection::new] stores it with a path that
    /// ends in a forward slash
    url: Url,
}

impl Connection {
    /// Wraps the given [url::Url] in a [Connection].
    ///
    /// A path that lacks a trailing forward slash gets one appended; a path
    /// that already ends with one is stored unchanged.
    pub fn new(url: Url) -> Self {
        if url.path().ends_with('/') {
            return Self { url };
        }

        let mut url = url;
        let slashed = format!("{}/", url.path());
        url.set_path(&slashed);
        Self { url }
    }
}

/// A HTTP transport responsible for making the API requests to OpenSearch,
/// using a [Connection] selected from a [ConnectionPool]
#[derive(Debug, Clone)]
pub struct Transport {
    /// HTTP client that builds and executes every request
    client: reqwest::Client,
    /// Credentials applied to each request, if any
    credentials: Option<Credentials>,
    /// Pool from which the [Connection] for each request is taken
    conn_pool: Box<dyn ConnectionPool>,
    /// AWS service name passed to request signing when AWS SigV4 credentials are used
    #[cfg(feature = "aws-auth")]
    service_name: String,
}

impl Transport {
    /// Converts the client's [Method] into the equivalent [reqwest::Method].
    fn method(&self, method: Method) -> reqwest::Method {
        match method {
            Method::Get => reqwest::Method::GET,
            Method::Put => reqwest::Method::PUT,
            Method::Post => reqwest::Method::POST,
            Method::Delete => reqwest::Method::DELETE,
            Method::Head => reqwest::Method::HEAD,
        }
    }

    /// Allocates the buffer a request body is written into when the body
    /// cannot hand over its bytes directly.
    fn bytes_mut(&self) -> BytesMut {
        // NOTE: These could be pooled or re-used
        BytesMut::with_capacity(1024)
    }

    /// Creates a new instance of a [Transport] configured with a
    /// [SingleNodeConnectionPool].
    ///
    /// # Errors
    ///
    /// Returns an error when `url` cannot be parsed or the transport cannot
    /// be built.
    pub fn single_node(url: &str) -> Result<Transport, Error> {
        let u = Url::parse(url)?;
        let conn_pool = SingleNodeConnectionPool::new(u);
        let transport = TransportBuilder::new(conn_pool).build()?;
        Ok(transport)
    }

    /// Creates an asynchronous request that can be awaited
    ///
    /// The request targets the next [Connection] of the pool, with `path`
    /// (leading slashes removed) joined onto the connection's URL.
    ///
    /// * `headers` are applied on top of the default `Content-Type`, `Accept`
    ///   and `User-Agent` headers. A header given with several values is sent
    ///   with all of them.
    /// * `query_string`, when given, is serialized into the URL query.
    /// * `body`, when given, is sent as the request body.
    /// * `timeout`, when given, is set as the timeout of this request.
    ///
    /// # Errors
    ///
    /// Returns an error when the URL cannot be joined, the body cannot be
    /// written, the request cannot be built or signed, or executing the
    /// request fails.
    pub async fn send<B, Q>(
        &self,
        method: Method,
        path: &str,
        headers: HeaderMap,
        query_string: Option<&Q>,
        body: Option<B>,
        timeout: Option<Duration>,
    ) -> Result<Response, Error>
    where
        B: Body,
        Q: Serialize + ?Sized,
    {
        let connection = self.conn_pool.next();
        let url = connection.url.join(path.trim_start_matches('/'))?;
        let reqwest_method = self.method(method);
        let mut request_builder = self.client.request(reqwest_method, url);

        if let Some(t) = timeout {
            request_builder = request_builder.timeout(t);
        }

        // set credentials before any headers, as credentials append to existing headers in reqwest,
        // whilst setting headers() overwrites, so if an Authorization header has been specified
        // on a specific request, we want it to overwrite.
        if let Some(c) = &self.credentials {
            request_builder = match c {
                Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)),
                Credentials::Bearer(t) => request_builder.bearer_auth(t),
                // certificate credentials are configured on the HTTP client when it is built
                #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
                Credentials::Certificate(_) => request_builder,
                Credentials::ApiKey(i, k) => {
                    let mut header_value = b"ApiKey ".to_vec();
                    {
                        // the encoder is scoped so that it is dropped, and thereby
                        // finishes its output, before the header value is read
                        let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD);
                        write!(encoder, "{}:", i)
                            .expect("writing to an in-memory buffer cannot fail");
                        write!(encoder, "{}", k)
                            .expect("writing to an in-memory buffer cannot fail");
                    }
                    request_builder.header(
                        AUTHORIZATION,
                        HeaderValue::from_bytes(&header_value)
                            .expect("base64 output is always a valid header value"),
                    )
                }
                // the request is signed once it has been fully built, see below
                #[cfg(feature = "aws-auth")]
                Credentials::AwsSigV4(_, _) => request_builder,
            }
        }

        // default headers first, overwrite with any provided
        let mut request_headers = HeaderMap::with_capacity(4 + headers.len());
        request_headers.insert(CONTENT_TYPE, HeaderValue::from_static(DEFAULT_CONTENT_TYPE));
        request_headers.insert(ACCEPT, HeaderValue::from_static(DEFAULT_ACCEPT));
        request_headers.insert(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
        // Consuming a `HeaderMap` yields `None` as the name of the second and
        // later values of a multi-valued header, so the name must not be
        // unwrapped per item. `extend` replaces the existing value of each
        // provided name and keeps every value provided for it.
        request_headers.extend(headers);

        request_builder = request_builder.headers(request_headers);

        if let Some(b) = body {
            let bytes = if let Some(bytes) = b.bytes() {
                bytes
            } else {
                let mut bytes_mut = self.bytes_mut();
                b.write(&mut bytes_mut)?;
                bytes_mut.split().freeze()
            };

            request_builder = request_builder.body(bytes);
        }

        if let Some(q) = query_string {
            request_builder = request_builder.query(q);
        }

        #[cfg_attr(not(feature = "aws-auth"), allow(unused_mut))]
        let mut request = request_builder.build()?;

        #[cfg(feature = "aws-auth")]
        if let Some(Credentials::AwsSigV4(credentials_provider, region)) = &self.credentials {
            super::aws_auth::sign_request(
                &mut request,
                credentials_provider,
                &self.service_name,
                region,
            )
            .await
            .map_err(|e| crate::error::lib(format!("AWSV4 Signing Failed: {}", e)))?;
        }

        let response = self.client.execute(request).await?;
        Ok(Response::new(response, method))
    }
}

impl Default for Transport {
    /// Builds a [Transport] from the default [TransportBuilder].
    ///
    /// # Panics
    ///
    /// Panics if building the default transport fails.
    fn default() -> Self {
        let built = TransportBuilder::default().build();
        built.unwrap()
    }
}

/// A pool of [Connection]s, used to make API calls to OpenSearch.
///
/// A [ConnectionPool] manages the connections, with different implementations determining how
/// to get the next [Connection]. The simplest type of [ConnectionPool] is [SingleNodeConnectionPool],
/// which manages only a single connection, but other implementations may manage connections more
/// dynamically at runtime, based upon the response to API calls.
pub trait ConnectionPool: Debug + dyn_clone::DynClone + Sync + Send {
    /// Gets a reference to the next [Connection]
    fn next(&self) -> &Connection;
}

// Makes `Box<dyn ConnectionPool>` cloneable, which the derived `Clone` of
// `Transport` relies on for its boxed pool.
clone_trait_object!(ConnectionPool);

/// A connection pool that manages the single connection to an OpenSearch cluster.
#[derive(Debug, Clone)]
pub struct SingleNodeConnectionPool {
    /// The only connection of the pool, returned by every call to `next`
    connection: Connection,
}

impl SingleNodeConnectionPool {
    /// Creates a [SingleNodeConnectionPool] whose only [Connection] targets `url`.
    pub fn new(url: Url) -> Self {
        let connection = Connection::new(url);
        Self { connection }
    }
}

impl Default for SingleNodeConnectionPool {
    /// Creates a default instance of [SingleNodeConnectionPool], using [DEFAULT_ADDRESS].
    fn default() -> Self {
        Self::new(Url::parse(DEFAULT_ADDRESS).unwrap())
    }
}

impl ConnectionPool for SingleNodeConnectionPool {
    /// Returns the pool's single [Connection] on every call.
    fn next(&self) -> &Connection {
        let Self { connection } = self;
        connection
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    use crate::auth::ClientCertificate;
    use url::Url;

    /// Parses `raw`, wraps it in a [Connection] and returns the stored URL as text.
    fn connection_url(raw: &str) -> String {
        let parsed = Url::parse(raw).unwrap();
        Connection::new(parsed).url.as_str().to_owned()
    }

    // Building must fail when the PKCS#12 client certificate is not valid
    #[test]
    #[cfg(feature = "native-tls")]
    fn invalid_pkcs12_cert_credentials() {
        let bogus = ClientCertificate::Pkcs12(b"Nonsense".to_vec(), None);
        let outcome = TransportBuilder::new(SingleNodeConnectionPool::default())
            .auth(bogus.into())
            .build();

        assert!(outcome.is_err());
    }

    // Building must fail when the PEM client certificate is not valid
    #[test]
    #[cfg(feature = "rustls-tls")]
    fn invalid_pem_cert_credentials() {
        let bogus = ClientCertificate::Pem(b"Nonsense".to_vec());
        let outcome = TransportBuilder::new(SingleNodeConnectionPool::default())
            .auth(bogus.into())
            .build();

        assert!(outcome.is_err());
    }

    #[test]
    fn connection_url_with_no_trailing_slash() {
        assert_eq!(
            connection_url("http://10.1.2.3/path_with_no_trailing_slash"),
            "http://10.1.2.3/path_with_no_trailing_slash/"
        );
    }

    #[test]
    fn connection_url_with_trailing_slash() {
        assert_eq!(
            connection_url("http://10.1.2.3/path_with_trailing_slash/"),
            "http://10.1.2.3/path_with_trailing_slash/"
        );
    }

    #[test]
    fn connection_url_with_no_path_and_no_trailing_slash() {
        assert_eq!(connection_url("http://10.1.2.3"), "http://10.1.2.3/");
    }

    #[test]
    fn connection_url_with_no_path_and_trailing_slash() {
        assert_eq!(connection_url("http://10.1.2.3/"), "http://10.1.2.3/");
    }
}