Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,29 @@

namespace Microsoft.AspNetCore.Routing
{
public class ClientIdentification
{
public string IpAddress { get; set; } = string.Empty;
public string UserAgent { get; set; } = string.Empty;
public string Host { get; set; } = string.Empty;
public string Protocol { get; set; } = string.Empty;
public Dictionary<string, string> AdditionalHeaders { get; set; } = new();
}

public static class IdentityComponentsEndpointRouteBuilderExtensions
{
public delegate Task LoginHandler(string username, string group, IList<string> roles, ClientIdentification clientInfo);
public delegate Task LogoutHandler(string username);

// These endpoints are required by the Identity Razor components defined in the /Components/Account/Pages directory of this project.
public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpointRouteBuilder endpoints)
public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpointRouteBuilder endpoints, LoginHandler? loginHandler = null, LogoutHandler? logoutHandler = null)
{
ArgumentNullException.ThrowIfNull(endpoints);

endpoints.MapPost("/Login", async (
HttpContext context,
[FromServices] SignInManager<User> signInManager) =>
[FromServices] SignInManager<User> signInManager,
[FromServices] UserManager<User> userManager) =>
{
try
{
Expand All @@ -42,7 +55,18 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
var result = await signInManager.PasswordSignInAsync(username, password, false, lockoutOnFailure: false);

if (result.Succeeded)
{
// Get user details for the login handler
var user = await userManager.FindByNameAsync(username);
if (user != null && loginHandler != null)
{
var roles = await userManager.GetRolesAsync(user);
var clientInfo = GetClientIdentification(context);
await loginHandler(username, user.Group, roles, clientInfo);
}

return TypedResults.LocalRedirect(string.IsNullOrEmpty(returnUrl) ? "/" : !returnUrl.StartsWith("/") ? "/" + returnUrl : returnUrl.StartsWith("//") ? "/" + returnUrl.TrimStart('/') : returnUrl);
}
else // Redirect back to login with error
return TypedResults.LocalRedirect($"/Security/Login?error=invalid&returnUrl={Uri.EscapeDataString(returnUrl ?? "/")}");
}
Expand All @@ -60,6 +84,14 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
{
var formCollection = await context.Request.ReadFormAsync();
var returnUrl = formCollection["ReturnUrl"].ToString();

// Get username before signing out
string? username = context.User.Identity?.Name;

if (!string.IsNullOrEmpty(username) && logoutHandler != null)
{
await logoutHandler(username);
}

await signInManager.SignOutAsync();
return TypedResults.LocalRedirect(string.IsNullOrEmpty(returnUrl) ? "/" : !returnUrl.StartsWith("/") ? "/" + returnUrl : returnUrl.StartsWith("//") ? "/" + returnUrl.TrimStart('/') : returnUrl);
Expand Down Expand Up @@ -97,19 +129,37 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin
if (TokenHasher.VerifyToken(externalAuthId, currentUser.ExternalAuthId))
{
// Same user - log out
if (logoutHandler != null)
{
var clientInfo = GetClientIdentification(context);
await logoutHandler(currentUser.UserName!);
}
await signInManager.SignOutAsync();
return TypedResults.LocalRedirect(string.IsNullOrEmpty(returnUrl) ? "/" : !returnUrl.StartsWith("/") ? "/" + returnUrl : returnUrl.StartsWith("//") ? "/" + returnUrl.TrimStart('/') : returnUrl);
}
else
{
// Different user - log out current and log in new
if (logoutHandler != null)
{
var clientInfo = GetClientIdentification(context);
await logoutHandler(currentUser.UserName!);
}
await signInManager.SignOutAsync();
}
}

// Sign in the user
await signInManager.SignInAsync(user, isPersistent: false);

// Invoke the login handler
if (loginHandler != null)
{
var roles = await userManager.GetRolesAsync(user);
var clientInfo = GetClientIdentification(context);
await loginHandler(user.UserName!, user.Group, roles, clientInfo);
}

if (!string.IsNullOrEmpty(returnUrl))
{
string[] split = returnUrl.Split('?');
Expand All @@ -132,5 +182,50 @@ public static IEndpointRouteBuilder MapAdditionalIdentityEndpoints(this IEndpoin

return endpoints;
}

private static ClientIdentification GetClientIdentification(HttpContext context)
{
var clientInfo = new ClientIdentification
{
IpAddress = GetClientIpAddress(context),
UserAgent = context.Request.Headers["User-Agent"].FirstOrDefault() ?? "Unknown",
Host = context.Request.Host.ToString(),
Protocol = context.Request.Protocol
};

// Add additional headers that might be useful for client identification
var headersToCapture = new[] { "Referer", "Accept-Language", "X-Requested-With", "Origin" };
foreach (var header in headersToCapture)
{
var value = context.Request.Headers[header].FirstOrDefault();
if (!string.IsNullOrEmpty(value))
{
clientInfo.AdditionalHeaders[header] = value;
}
}

return clientInfo;
}

private static string GetClientIpAddress(HttpContext context)
{
// Try to get IP from X-Forwarded-For header (for reverse proxy scenarios)
var forwardedFor = context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
if (!string.IsNullOrEmpty(forwardedFor))
{
// X-Forwarded-For can contain multiple IPs, take the first one
var ips = forwardedFor.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (ips.Length > 0)
return ips[0];
}

// Try X-Real-IP header
var realIp = context.Request.Headers["X-Real-IP"].FirstOrDefault();
if (!string.IsNullOrEmpty(realIp))
return realIp;

// Fall back to RemoteIpAddress
return context.Connection.RemoteIpAddress?.ToString() ?? "Unknown";
}
}
}
Loading