BililiveRecorder/BililiveRecorder.Web/BasicAuthMiddleware.cs

75 lines
2.8 KiB
C#
Raw Normal View History

2022-06-08 00:15:05 +08:00
using System;
using System.IO;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
2022-06-08 01:19:38 +08:00
using Microsoft.Extensions.DependencyInjection;
2022-06-08 00:15:05 +08:00
using Microsoft.Extensions.FileProviders;
using Microsoft.Net.Http.Headers;
namespace BililiveRecorder.Web
{
public class BasicAuthMiddleware
{
    private readonly RequestDelegate next;
    private readonly ManifestEmbeddedFileProvider fileProvider;
    private const string BasicAndSpace = "Basic ";

    // Rendered 401 page. It is built on the first rejected request and reused afterwards.
    // Concurrent first requests may each build it; every build yields the same string, so the race is benign.
    private static string? Html401Page;

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">The next delegate in the request pipeline.</param>
    /// <param name="fileProvider">Embedded file provider used to load the <c>/401.html</c> page template.</param>
    /// <exception cref="ArgumentNullException">Thrown when either argument is <see langword="null"/>.</exception>
    public BasicAuthMiddleware(RequestDelegate next, ManifestEmbeddedFileProvider fileProvider)
    {
        this.next = next ?? throw new ArgumentNullException(nameof(next));
        this.fileProvider = fileProvider ?? throw new ArgumentNullException(nameof(fileProvider));
    }

    /// <summary>
    /// Enforces HTTP Basic authentication when a <c>BasicAuthCredential</c> is registered in the
    /// request services. Otherwise the request is passed through unchanged.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>
    /// A task that either runs the rest of the pipeline (no credential configured, or the credential matches)
    /// or writes a 401 response.
    /// </returns>
    public Task InvokeAsync(HttpContext context)
    {
        if (context.RequestServices.GetService<BasicAuthCredential>() is not { } credential)
        {
            // Authentication is not enabled.
            return this.next(context);
        }

        // The implicit StringValues -> string conversion yields null when the header is absent.
        string? headerValue = context.Request.Headers[HeaderNames.Authorization];
        if (string.IsNullOrEmpty(headerValue) ||
            !headerValue.StartsWith(BasicAndSpace, StringComparison.OrdinalIgnoreCase))
        {
            return this.ResponseWith401Async(context);
        }

        var requestCredential = headerValue[BasicAndSpace.Length..].Trim();
        if (string.IsNullOrEmpty(requestCredential))
        {
            return this.ResponseWith401Async(context);
        }

        // The header value is attacker-controlled. Compare in constant time so that response timing
        // does not reveal how many leading characters of the secret were guessed correctly.
        return FixedTimeEquals(credential.EncoededValue, requestCredential)
            ? this.next(context)
            : this.ResponseWith401Async(context);
    }

    /// <summary>
    /// Ordinal (UTF-16 code unit) string equality whose running time does not depend on where
    /// the first difference occurs. Only the lengths of the inputs may influence timing.
    /// </summary>
    private static bool FixedTimeEquals(string expected, string actual) =>
        System.Security.Cryptography.CryptographicOperations.FixedTimeEquals(
            System.Runtime.InteropServices.MemoryMarshal.AsBytes(expected.AsSpan()),
            System.Runtime.InteropServices.MemoryMarshal.AsBytes(actual.AsSpan()));

    /// <summary>
    /// Writes a 401 response with a Basic-auth challenge and the embedded HTML error page.
    /// The <c>__VERSION__</c> and <c>__FULL_VERSION__</c> placeholders are filled with build version info.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    private async Task ResponseWith401Async(HttpContext context)
    {
        context.Response.StatusCode = StatusCodes.Status401Unauthorized;
        // The body is written as UTF-8 below, so declare that charset explicitly
        // rather than leaving the browser to guess.
        context.Response.ContentType = "text/html; charset=utf-8";
        context.Response.Headers.Append(HeaderNames.WWWAuthenticate, $"{BasicAndSpace}realm=\"BililiveRecorder {GitVersionInformation.FullSemVer}\"");

        var page = Html401Page;
        if (page is null)
        {
            using var file = this.fileProvider.GetFileInfo("/401.html").CreateReadStream();
            using var reader = new StreamReader(file);
            var template = await reader.ReadToEndAsync().ConfigureAwait(false);
            page = template
                .Replace("__VERSION__", GitVersionInformation.FullSemVer, StringComparison.Ordinal)
                .Replace("__FULL_VERSION__", GitVersionInformation.InformationalVersion, StringComparison.Ordinal);
            Html401Page = page;
        }

        await context.Response.WriteAsync(page, System.Text.Encoding.UTF8).ConfigureAwait(false);
    }
}
}