/* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*
*  Licensed to Elasticsearch B.V. under one or more contributor
*  license agreements. See the NOTICE file distributed with
*  this work for additional information regarding copyright
*  ownership. Elasticsearch B.V. licenses this file to you under
*  the Apache License, Version 2.0 (the "License"); you may
*  not use this file except in compliance with the License.
*  You may obtain a copy of the License at
*
* 	http://www.apache.org/licenses/LICENSE-2.0
*
*  Unless required by applicable law or agreed to in writing,
*  software distributed under the License is distributed on an
*  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*  KIND, either express or implied.  See the License for the
*  specific language governing permissions and limitations
*  under the License.
*/

using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using OpenSearch.Net.Extensions;

namespace OpenSearch.Net
{
	/// <summary>
	/// A request body that can write itself to a stream, synchronously or asynchronously.
	/// </summary>
	// ReSharper disable once UnusedTypeParameter
	public interface IPostData<out T>
	{
		void Write(Stream writableStream, IConnectionConfigurationValues settings);

		Task WriteAsync(Stream writableStream, IConnectionConfigurationValues settings, CancellationToken token);
	}

	/// <summary>The kind of payload wrapped by a <see cref="PostData"/> instance.</summary>
	public enum PostType
	{
		ByteArray,
#if NETSTANDARD2_1
		ReadOnlyMemory,
#endif
		LiteralString,
		EnumerableOfString,
		EnumerableOfObject,
		StreamHandler,
		Serializable
	}

	/// <summary>
	/// Base type for request bodies. Exposes factory methods for each supported payload kind and the
	/// shared logic used to buffer a body in memory when direct streaming is disabled.
	/// </summary>
	public abstract class PostData
	{
		/// <summary>Buffer size passed to <c>CopyTo</c>/<c>CopyToAsync</c> when flushing a buffered body.</summary>
		protected const int BufferSize = 81920;

		protected const string NewLineString = "\n";

		/// <summary>Separator written after every item of a newline-delimited (multi-JSON) body.</summary>
		protected static readonly byte[] NewLineByteArray = { (byte)'\n' };

		//TODO internal set?;
		/// <summary>
		/// Per-body override of <c>settings.DisableDirectStreaming</c>; when <c>null</c> the connection
		/// setting is used. When the effective value is <c>true</c> the body is first written to an
		/// in-memory buffer, then copied to the request stream and captured in <see cref="WrittenBytes"/>.
		/// </summary>
		public bool? DisableDirectStreaming { get; set; }

		/// <summary>The kind of payload this instance wraps.</summary>
		public PostType Type { get; protected set; }

		/// <summary>
		/// The body bytes when known: set up front for byte-array bodies, on write for literal strings,
		/// and after writing when the body was buffered (direct streaming disabled).
		/// </summary>
		public byte[] WrittenBytes { get; protected set; }

		/// <summary>A fresh empty body; writing it produces no bytes.</summary>
		public static PostData Empty => new PostData<object>(string.Empty);

		/// <summary>Writes the body to <paramref name="writableStream"/>.</summary>
		public abstract void Write(Stream writableStream, IConnectionConfigurationValues settings);

		/// <summary>Asynchronously writes the body to <paramref name="writableStream"/>.</summary>
		public abstract Task WriteAsync(Stream writableStream, IConnectionConfigurationValues settings, CancellationToken cancellationToken);

		/// <summary>Wraps a byte array as a body (same as <see cref="Bytes"/>).</summary>
		public static implicit operator PostData(byte[] byteArray) => Bytes(byteArray);

		/// <summary>Wraps an already-serialized string as a body (same as <see cref="String"/>).</summary>
		public static implicit operator PostData(string literalString) => String(literalString);

		/// <summary>Creates a body backed by <c>SerializableData&lt;T&gt;</c> for <paramref name="o"/>.</summary>
		public static SerializableData<T> Serializable<T>(T o) => new SerializableData<T>(o);

		/// <summary>
		/// Creates a newline-delimited body: each string is encoded with <c>Utf8Bytes()</c> and followed by <c>\n</c>.
		/// A <c>null</c> or empty sequence writes nothing.
		/// </summary>
		public static PostData MultiJson(IEnumerable<string> listOfString) => new PostData<object>(listOfString);

		/// <summary>
		/// Creates a newline-delimited body: each object is serialized with <c>settings.RequestResponseSerializer</c>
		/// (<c>SerializationFormatting.None</c>) and followed by <c>\n</c>. A <c>null</c> or empty sequence writes nothing.
		/// </summary>
		public static PostData MultiJson(IEnumerable<object> listOfObjects) => new PostData<object>(listOfObjects);

		/// <summary>Creates a body that writes <paramref name="bytes"/> as-is; a <c>null</c> array writes nothing.</summary>
		public static PostData Bytes(byte[] bytes) => new PostData<object>(bytes);

#if NETSTANDARD2_1
		/// <summary>Creates a body that writes <paramref name="bytes"/> as-is; empty memory writes nothing.</summary>
		public static PostData ReadOnlyMemory(ReadOnlyMemory<byte> bytes) => new PostData<object>(bytes);
#endif

		/// <summary>
		/// Creates a body from an already-serialized string, encoded with <c>Utf8Bytes()</c>.
		/// A <c>null</c> or empty string writes nothing.
		/// </summary>
		public static PostData String(string serializedString) => new PostData<object>(serializedString);

		/// <summary>
		/// Creates a body backed by <c>StreamableData&lt;T&gt;</c> holding <paramref name="state"/> and the given
		/// synchronous and asynchronous writer callbacks.
		/// </summary>
		public static PostData StreamHandler<T>(T state, Action<T, Stream> syncWriter, Func<T, Stream, CancellationToken, Task> asyncWriter) =>
			new StreamableData<T>(state, syncWriter, asyncWriter);

		/// <summary>
		/// When direct streaming is disabled, replaces <paramref name="stream"/> with a new buffer from
		/// <c>settings.MemoryStreamFactory</c> (also returned via <paramref name="buffer"/>). Otherwise does nothing.
		/// </summary>
		protected void BufferIfNeeded(IConnectionConfigurationValues settings, ref MemoryStream buffer, ref Stream stream)
		{
			var disableDirectStreaming = DisableDirectStreaming ?? settings.DisableDirectStreaming;
			if (!disableDirectStreaming) return;

			buffer = settings.MemoryStreamFactory.Create();
			stream = buffer;
		}

		/// <summary>
		/// When direct streaming is disabled and <paramref name="buffer"/> is non-null, copies the buffer from its
		/// start into <paramref name="writableStream"/> and records its contents in <see cref="WrittenBytes"/>
		/// (unless already set). The buffer is not disposed here.
		/// </summary>
		protected void FinishStream(Stream writableStream, MemoryStream buffer, IConnectionConfigurationValues settings)
		{
			var disableDirectStreaming = DisableDirectStreaming ?? settings.DisableDirectStreaming;
			if (buffer == null || !disableDirectStreaming) return;

			buffer.Position = 0;
			buffer.CopyTo(writableStream, BufferSize);
			WrittenBytes ??= buffer.ToArray();
		}

		/// <summary>Asynchronous counterpart of <see cref="FinishStream"/>.</summary>
		protected async
#if NETSTANDARD2_1
			ValueTask
#else
			Task
#endif
			FinishStreamAsync(Stream writableStream, MemoryStream buffer, IConnectionConfigurationValues settings, CancellationToken ctx)
		{
			var disableDirectStreaming = DisableDirectStreaming ?? settings.DisableDirectStreaming;
			if (buffer == null || !disableDirectStreaming) return;

			buffer.Position = 0;
			await buffer.CopyToAsync(writableStream, BufferSize, ctx).ConfigureAwait(false);
			WrittenBytes ??= buffer.ToArray();
		}
	}

	/// <summary>
	/// Concrete body for byte arrays, literal strings and newline-delimited sequences of strings or objects.
	/// <see cref="PostType.StreamHandler"/> and <see cref="PostType.Serializable"/> payloads are rejected here;
	/// they are created through <c>PostData.StreamHandler</c> / <c>PostData.Serializable</c>.
	/// The type parameter is not used by this implementation.
	/// </summary>
	public class PostData<T> : PostData, IPostData<T>
	{
		private readonly IEnumerable<object> _enumerableOfObject;
		private readonly IEnumerable<string> _enumerableOfStrings;
		private readonly string _literalString;
#if NETSTANDARD2_1
		private readonly ReadOnlyMemory<byte> _memoryOfBytes;
#endif

		protected internal PostData(byte[] item)
		{
			WrittenBytes = item;
			Type = PostType.ByteArray;
		}

#if NETSTANDARD2_1
		protected internal PostData(ReadOnlyMemory<byte> item)
		{
			_memoryOfBytes = item;
			Type = PostType.ReadOnlyMemory;
		}
#endif

		protected internal PostData(string item)
		{
			_literalString = item;
			Type = PostType.LiteralString;
		}

		protected internal PostData(IEnumerable<string> item)
		{
			_enumerableOfStrings = item;
			Type = PostType.EnumerableOfString;
		}

		protected internal PostData(IEnumerable<object> item)
		{
			_enumerableOfObject = item;
			Type = PostType.EnumerableOfObject;
		}

		/// <summary>
		/// Writes the body to <paramref name="writableStream"/>, buffering it first when direct streaming is disabled.
		/// </summary>
		/// <exception cref="Exception">The payload is a stream-handler or serializable payload.</exception>
		/// <exception cref="ArgumentOutOfRangeException"><see cref="PostData.Type"/> is not a known value.</exception>
		public override void Write(Stream writableStream, IConnectionConfigurationValues settings)
		{
			MemoryStream buffer = null;
			var stream = writableStream;
			var disableDirectStreaming = DisableDirectStreaming ?? settings.DisableDirectStreaming;
			try
			{
				switch (Type)
				{
					case PostType.ByteArray:
						if (WrittenBytes == null) return;

						if (!disableDirectStreaming)
							stream.Write(WrittenBytes, 0, WrittenBytes.Length);
						else
							buffer = settings.MemoryStreamFactory.Create(WrittenBytes);
						break;
#if NETSTANDARD2_1
					case PostType.ReadOnlyMemory:
						if (_memoryOfBytes.IsEmpty) return;

						if (!disableDirectStreaming)
							stream.Write(_memoryOfBytes.Span);
						else
						{
							WrittenBytes ??= _memoryOfBytes.ToArray();
							buffer = settings.MemoryStreamFactory.Create(WrittenBytes);
						}
						break;
#endif
					case PostType.LiteralString:
						if (string.IsNullOrEmpty(_literalString)) return;

						var stringBytes = WrittenBytes ?? _literalString.Utf8Bytes();
						WrittenBytes ??= stringBytes;
						if (!disableDirectStreaming)
							stream.Write(stringBytes, 0, stringBytes.Length);
						else
							buffer = settings.MemoryStreamFactory.Create(stringBytes);
						break;
					case PostType.EnumerableOfString:
					{
						if (_enumerableOfStrings == null) return;

						using var enumerator = _enumerableOfStrings.GetEnumerator();
						// An empty sequence writes nothing (no trailing newline either).
						if (!enumerator.MoveNext()) return;

						BufferIfNeeded(settings, ref buffer, ref stream);
						do
						{
							var bytes = enumerator.Current.Utf8Bytes();
							stream.Write(bytes, 0, bytes.Length);
							stream.Write(NewLineByteArray, 0, 1);
						} while (enumerator.MoveNext());
						break;
					}
					case PostType.EnumerableOfObject:
					{
						if (_enumerableOfObject == null) return;

						using var enumerator = _enumerableOfObject.GetEnumerator();
						// An empty sequence writes nothing (no trailing newline either).
						if (!enumerator.MoveNext()) return;

						BufferIfNeeded(settings, ref buffer, ref stream);
						do
						{
							var o = enumerator.Current;
							settings.RequestResponseSerializer.Serialize(o, stream, SerializationFormatting.None);
							stream.Write(NewLineByteArray, 0, 1);
						} while (enumerator.MoveNext());
						break;
					}
					case PostType.StreamHandler:
					case PostType.Serializable:
						throw UnsupportedPostType(Type);
					default:
						throw new ArgumentOutOfRangeException();
				}

				FinishStream(writableStream, buffer, settings);
			}
			finally
			{
				// The buffer comes from settings.MemoryStreamFactory, which may hand out pooled streams;
				// release it once its contents have been copied (and captured in WrittenBytes).
				buffer?.Dispose();
			}
		}

		/// <summary>
		/// Asynchronously writes the body to <paramref name="writableStream"/>, buffering it first when direct
		/// streaming is disabled. Behaves like <see cref="Write"/>.
		/// </summary>
		/// <exception cref="Exception">The payload is a stream-handler or serializable payload.</exception>
		/// <exception cref="ArgumentOutOfRangeException"><see cref="PostData.Type"/> is not a known value.</exception>
		public override async Task WriteAsync(Stream writableStream, IConnectionConfigurationValues settings, CancellationToken cancellationToken)
		{
			MemoryStream buffer = null;
			var stream = writableStream;
			var disableDirectStreaming = DisableDirectStreaming ?? settings.DisableDirectStreaming;
			try
			{
				switch (Type)
				{
					case PostType.ByteArray:
						// Same guard as Write: a null array is an empty body, not a NullReferenceException.
						if (WrittenBytes == null) return;

						if (!disableDirectStreaming)
							await stream.WriteAsync(WrittenBytes, 0, WrittenBytes.Length, cancellationToken).ConfigureAwait(false);
						else
							buffer = settings.MemoryStreamFactory.Create(WrittenBytes);
						break;
#if NETSTANDARD2_1
					case PostType.ReadOnlyMemory:
						if (_memoryOfBytes.IsEmpty) return;

						if (!disableDirectStreaming)
							await stream.WriteAsync(_memoryOfBytes, cancellationToken).ConfigureAwait(false);
						else
						{
							WrittenBytes ??= _memoryOfBytes.ToArray();
							buffer = settings.MemoryStreamFactory.Create(WrittenBytes);
						}
						break;
#endif
					case PostType.LiteralString:
						if (string.IsNullOrEmpty(_literalString)) return;

						var stringBytes = WrittenBytes ?? _literalString.Utf8Bytes();
						WrittenBytes ??= stringBytes;
						if (!disableDirectStreaming)
							await stream.WriteAsync(stringBytes, 0, stringBytes.Length, cancellationToken).ConfigureAwait(false);
						else
							buffer = settings.MemoryStreamFactory.Create(stringBytes);
						break;
					case PostType.EnumerableOfString:
					{
						if (_enumerableOfStrings == null) return;

						using var enumerator = _enumerableOfStrings.GetEnumerator();
						if (!enumerator.MoveNext()) return;

						BufferIfNeeded(settings, ref buffer, ref stream);
						do
						{
							var bytes = enumerator.Current.Utf8Bytes();
							await stream.WriteAsync(bytes, 0, bytes.Length, cancellationToken).ConfigureAwait(false);
							await stream.WriteAsync(NewLineByteArray, 0, 1, cancellationToken).ConfigureAwait(false);
						} while (enumerator.MoveNext());
						break;
					}
					case PostType.EnumerableOfObject:
					{
						if (_enumerableOfObject == null) return;

						using var enumerator = _enumerableOfObject.GetEnumerator();
						if (!enumerator.MoveNext()) return;

						BufferIfNeeded(settings, ref buffer, ref stream);
						do
						{
							var o = enumerator.Current;
							await settings.RequestResponseSerializer.SerializeAsync(o, stream, SerializationFormatting.None, cancellationToken)
								.ConfigureAwait(false);
							await stream.WriteAsync(NewLineByteArray, 0, 1, cancellationToken).ConfigureAwait(false);
						} while (enumerator.MoveNext());
						break;
					}
					case PostType.StreamHandler:
					case PostType.Serializable:
						throw UnsupportedPostType(Type);
					default:
						throw new ArgumentOutOfRangeException();
				}

				await FinishStreamAsync(writableStream, buffer, settings, cancellationToken).ConfigureAwait(false);
			}
			finally
			{
				// See Write: release the factory-provided buffer after it has been flushed.
				buffer?.Dispose();
			}
		}

		/// <summary>
		/// Builds the exception raised (by both <see cref="Write"/> and <see cref="WriteAsync"/>) when this type is
		/// asked to write a payload kind that belongs to a dedicated body type.
		/// </summary>
		private static Exception UnsupportedPostType(PostType type) =>
			type == PostType.StreamHandler
				? new Exception(
					$"{nameof(PostData)} cannot handle {nameof(PostType.StreamHandler)} data. "
					+ $"Use {typeof(StreamableData<>).FullName} through {nameof(PostData)}.{nameof(StreamHandler)}() for streamable data")
				: new Exception(
					$"{nameof(PostData)} cannot handle {nameof(PostType.Serializable)} data. "
					+ $"Use {typeof(SerializableData<>).FullName} through {nameof(PostData)}.{nameof(Serializable)}() for serializable data");
	}
}