diff --git a/src/Microsoft.OpenApi/Reader/OpenApiModelFactory.cs b/src/Microsoft.OpenApi/Reader/OpenApiModelFactory.cs index 6f02d9426..d6b57fbbb 100644 --- a/src/Microsoft.OpenApi/Reader/OpenApiModelFactory.cs +++ b/src/Microsoft.OpenApi/Reader/OpenApiModelFactory.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Security; @@ -362,7 +363,23 @@ private static string InspectInputFormat(string input) return input.StartsWith("{", StringComparison.OrdinalIgnoreCase) || input.StartsWith("[", StringComparison.OrdinalIgnoreCase) ? OpenApiConstants.Json : OpenApiConstants.Yaml; } - private static string InspectStreamFormat(Stream stream) + /// + /// Reads the initial bytes of the stream to determine if it is JSON or YAML. + /// + /// + /// It is important NOT TO change the stream type from MemoryStream. + /// In Asp.Net core 3.0+ we could get passed a stream from a request or response body. + /// In such case, we CAN'T use the ReadByte method as it throws NotSupportedException. + /// Therefore, we need to ensure that the stream is a MemoryStream before calling this method. + /// Maintaining this type ensures there won't be any unforeseen wrong usage of the method. + /// + /// The stream to inspect + /// The format of the stream. + private static string InspectStreamFormat(MemoryStream stream) + { + return TryInspectStreamFormat(stream, out var format) ? format! : throw new InvalidOperationException("Could not determine the format of the stream."); + } + private static bool TryInspectStreamFormat(Stream stream, out string? format) { #if NET6_0_OR_GREATER ArgumentNullException.ThrowIfNull(stream); @@ -370,65 +387,69 @@ private static string InspectStreamFormat(Stream stream) if (stream is null) throw new ArgumentNullException(nameof(stream)); #endif - long initialPosition = stream.Position; - int firstByte = stream.ReadByte(); - - // Skip whitespace if present and read the next non-whitespace byte - if (char.IsWhiteSpace((char)firstByte)) + try { - firstByte = stream.ReadByte(); - } + var initialPosition = stream.Position; + var firstByte = (char)stream.ReadByte(); + + // Skip whitespace if present and read the next non-whitespace byte + if (char.IsWhiteSpace(firstByte)) + { + firstByte = (char)stream.ReadByte(); + } - stream.Position = initialPosition; // Reset the stream position to the beginning + stream.Position = initialPosition; // Reset the stream position to the beginning - char firstChar = (char)firstByte; - return firstChar switch + format = firstByte switch + { + '{' or '[' => OpenApiConstants.Json, // If the first character is '{' or '[', assume JSON + _ => OpenApiConstants.Yaml // Otherwise assume YAML + }; + return true; + } + catch (NotSupportedException) + { + // https://github.com/dotnet/aspnetcore/blob/c9d0750396e1d319301255ba61842721ab72ab10/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs#L40 + } +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP || NET5_0_OR_GREATER + catch (InvalidOperationException ex) when (ex.Message.Contains("AllowSynchronousIO", StringComparison.Ordinal)) +#else + catch (InvalidOperationException ex) when (ex.Message.Contains("AllowSynchronousIO")) +#endif { - '{' or '[' => OpenApiConstants.Json, // If the first character is '{' or '[', assume JSON - _ => OpenApiConstants.Yaml // Otherwise assume YAML - }; + // https://github.com/dotnet/aspnetcore/blob/c9d0750396e1d319301255ba61842721ab72ab10/src/Servers/HttpSys/src/RequestProcessing/RequestStream.cs#L100-L108 + // https://github.com/dotnet/aspnetcore/blob/c9d0750396e1d319301255ba61842721ab72ab10/src/Servers/IIS/IIS/src/Core/HttpRequestStream.cs#L24-L30 + // https://github.com/dotnet/aspnetcore/blob/c9d0750396e1d319301255ba61842721ab72ab10/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs#L54-L60 + } + format = null; + return false; } + private static async Task CopyToMemoryStreamAsync(Stream input, CancellationToken token) + { + var bufferStream = new MemoryStream(); +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP || NET5_0_OR_GREATER + await input.CopyToAsync(bufferStream, token).ConfigureAwait(false); +#else + await input.CopyToAsync(bufferStream, 81920, token).ConfigureAwait(false); +#endif + bufferStream.Position = 0; + return bufferStream; + } + private static async Task<(Stream, string)> PrepareStreamForReadingAsync(Stream input, string? format, CancellationToken token = default) { Stream preparedStream = input; - if (!input.CanSeek) + if (input is MemoryStream ms) { - // Use a temporary buffer to read a small portion for format detection - using var bufferStream = new MemoryStream(); - await input.CopyToAsync(bufferStream, 1024, token).ConfigureAwait(false); - bufferStream.Position = 0; - - // Inspect the format from the buffered portion - format ??= InspectStreamFormat(bufferStream); - - // If format is JSON, no need to buffer further — use the original stream. - if (format.Equals(OpenApiConstants.Json, StringComparison.OrdinalIgnoreCase)) - { - preparedStream = input; - } - else - { - // YAML or other non-JSON format; copy remaining input to a new stream. - preparedStream = new MemoryStream(); - bufferStream.Position = 0; - await bufferStream.CopyToAsync(preparedStream, 81920, token).ConfigureAwait(false); // Copy buffered portion - await input.CopyToAsync(preparedStream, 81920, token).ConfigureAwait(false); // Copy remaining data - preparedStream.Position = 0; - } + format ??= InspectStreamFormat(ms); } - else + else if (!input.CanSeek || !TryInspectStreamFormat(input, out format!)) { - format ??= InspectStreamFormat(input); - - if (!format.Equals(OpenApiConstants.Json, StringComparison.OrdinalIgnoreCase)) - { - // Buffer stream for non-JSON formats (e.g., YAML) since they require synchronous reading - preparedStream = new MemoryStream(); - await input.CopyToAsync(preparedStream, 81920, token).ConfigureAwait(false); - preparedStream.Position = 0; - } + // Copy to a MemoryStream to enable seeking and perform format inspection + var bufferStream = await CopyToMemoryStreamAsync(input, token).ConfigureAwait(false); + return await PrepareStreamForReadingAsync(bufferStream, format, token).ConfigureAwait(false); } return (preparedStream, format); diff --git a/test/Microsoft.OpenApi.Tests/Reader/OpenApiModelFactoryTests.cs b/test/Microsoft.OpenApi.Tests/Reader/OpenApiModelFactoryTests.cs index 26bd6d472..464e8fdf3 100644 --- a/test/Microsoft.OpenApi.Tests/Reader/OpenApiModelFactoryTests.cs +++ b/test/Microsoft.OpenApi.Tests/Reader/OpenApiModelFactoryTests.cs @@ -3,6 +3,7 @@ using System.Threading.Tasks; using System.IO; using System; +using System.Threading; namespace Microsoft.OpenApi.Tests.Reader; @@ -119,4 +120,365 @@ await File.WriteAllTextAsync(tempFilePathReferrer, Assert.NotNull(readResult.Document.Components); Assert.Equal(baseUri, readResult.Document.BaseUri); } + private readonly string documentJson = +""" +{ + "openapi": "3.1.0", + "info": { + "title": "Sample API", + "version": "1.0.0" + }, + "paths": {} +} +"""; + private readonly string documentYaml = +""" +openapi: 3.1.0 +info: + title: Sample API + version: 1.0.0 +paths: {} +"""; + [Fact] + public async Task CanLoadANonSeekableStreamInJsonAndDetectFormat() + { + // Given + using var memoryStream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(documentJson)); + using var nonSeekableStream = new NonSeekableStream(memoryStream); + + // When + var (document, _) = await OpenApiDocument.LoadAsync(nonSeekableStream); + + // Then + Assert.NotNull(document); + Assert.Equal("Sample API", document.Info.Title); + } + + [Fact] + public async Task CanLoadANonSeekableStreamInYamlAndDetectFormat() + { + // Given + using var memoryStream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(documentYaml)); + using var nonSeekableStream = new NonSeekableStream(memoryStream); + var settings = new OpenApiReaderSettings(); + settings.AddYamlReader(); + + // When + var (document, _) = await OpenApiDocument.LoadAsync(nonSeekableStream, settings: settings); + + // Then + Assert.NotNull(document); + Assert.Equal("Sample API", document.Info.Title); + } + + [Fact] + public async Task CanLoadAnAsyncOnlyStreamInJsonAndDetectFormat() + { + // Given + await using var memoryStream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(documentJson)); + await using var nonSeekableStream = new AsyncOnlyStream(memoryStream); + + // When + var (document, _) = await OpenApiDocument.LoadAsync(nonSeekableStream); + + // Then + Assert.NotNull(document); + Assert.Equal("Sample API", document.Info.Title); + } + + [Fact] + public async Task CanLoadAnAsyncOnlyStreamInYamlAndDetectFormat() + { + // Given + await using var memoryStream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(documentYaml)); + await using var nonSeekableStream = new AsyncOnlyStream(memoryStream); + var settings = new OpenApiReaderSettings(); + settings.AddYamlReader(); + + // When + var (document, _) = await OpenApiDocument.LoadAsync(nonSeekableStream, settings: settings); + + // Then + Assert.NotNull(document); + Assert.Equal("Sample API", document.Info.Title); + } + + public sealed class AsyncOnlyStream : Stream + { + private readonly Stream _innerStream; + public AsyncOnlyStream(Stream stream) : base() + { + _innerStream = stream; + } + public override bool CanSeek => _innerStream.CanSeek; + + public override long Position { get => _innerStream.Position; set => throw new NotSupportedException("Blocking operations are not supported"); } + + public override bool CanRead => _innerStream.CanRead; + + public override bool CanWrite => _innerStream.CanWrite; + + public override long Length => _innerStream.Length; + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginRead(buffer, offset, count, callback, state); + } + + public override void Flush() + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override void SetLength(long value) + { + _innerStream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + protected override void Dispose(bool disposing) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override async ValueTask DisposeAsync() + { + await _innerStream.DisposeAsync(); + await base.DisposeAsync(); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _innerStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void CopyTo(Stream destination, int bufferSize) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override void Close() + { + _innerStream.Close(); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _innerStream.EndRead(asyncResult); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _innerStream.EndWrite(asyncResult); + } + + public override int ReadByte() + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override void WriteByte(byte value) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _innerStream.FlushAsync(cancellationToken); + } + + public override int Read(Span buffer) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.ReadAsync(buffer, cancellationToken); + } + + public override int ReadTimeout { get => _innerStream.ReadTimeout; set => _innerStream.ReadTimeout = value; } + + public override void Write(ReadOnlySpan buffer) + { + throw new NotSupportedException("Blocking operations are not supported."); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.WriteAsync(buffer, cancellationToken); + } + + public override int WriteTimeout { get => _innerStream.WriteTimeout; set => _innerStream.WriteTimeout = value; } + + } + + public sealed class NonSeekableStream : Stream + { + private readonly Stream _innerStream; + public NonSeekableStream(Stream stream) : base() + { + _innerStream = stream; + } + public override bool CanSeek => false; + + public override long Position { get => _innerStream.Position; set => throw new NotSupportedException("Seeking is not supported."); } + + public override bool CanRead => _innerStream.CanRead; + + public override bool CanWrite => _innerStream.CanWrite; + + public override long Length => _innerStream.Length; + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginRead(buffer, offset, count, callback, state); + } + + public override void Flush() + { + _innerStream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, count); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException("Seeking is not supported."); + } + + public override void SetLength(long value) + { + _innerStream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _innerStream.Write(buffer, offset, count); + } + protected override void Dispose(bool disposing) + { + _innerStream.Dispose(); + base.Dispose(disposing); + } + + public override async ValueTask DisposeAsync() + { + await _innerStream.DisposeAsync(); + await base.DisposeAsync(); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _innerStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void CopyTo(Stream destination, int bufferSize) + { + _innerStream.CopyTo(destination, bufferSize); + } + + public override void Close() + { + _innerStream.Close(); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _innerStream.EndRead(asyncResult); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _innerStream.EndWrite(asyncResult); + } + + public override int ReadByte() + { + return _innerStream.ReadByte(); + } + + public override void WriteByte(byte value) + { + _innerStream.WriteByte(value); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _innerStream.FlushAsync(cancellationToken); + } + + public override int Read(Span buffer) + { + return _innerStream.Read(buffer); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.ReadAsync(buffer, cancellationToken); + } + + public override int ReadTimeout { get => _innerStream.ReadTimeout; set => _innerStream.ReadTimeout = value; } + + public override void Write(ReadOnlySpan buffer) + { + _innerStream.Write(buffer); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.WriteAsync(buffer, cancellationToken); + } + + public override int WriteTimeout { get => _innerStream.WriteTimeout; set => _innerStream.WriteTimeout = value; } + + } }