Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Move auth into middleware #3133

Merged
merged 21 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/ApiService/ApiService/Auth/AuthenticationItems.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using Microsoft.Azure.Functions.Worker;

namespace Microsoft.OneFuzz.Service.Auth;

public static class AuthenticationItems {
private const string Key = "ONEFUZZ_USER_INFO";

public static void SetUserAuthInfo(this FunctionContext context, UserAuthInfo info)
=> context.Items[Key] = info;

public static UserAuthInfo GetUserAuthInfo(this FunctionContext context)
=> (UserAuthInfo)context.Items[Key];

public static UserAuthInfo? TryGetUserAuthInfo(this FunctionContext context)
=> context.Items.TryGetValue(Key, out var result) ? (UserAuthInfo)result : null;
}
111 changes: 111 additions & 0 deletions src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
using System.IdentityModel.Tokens.Jwt;
using System.Net;
using System.Net.Http.Headers;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.Azure.Functions.Worker.Middleware;

namespace Microsoft.OneFuzz.Service.Auth;

public sealed class AuthenticationMiddleware : IFunctionsWorkerMiddleware {
private readonly IConfigOperations _config;
private readonly ILogTracer _log;

public AuthenticationMiddleware(IConfigOperations config, ILogTracer log) {
_config = config;
_log = log;
}

public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) {
var requestData = await context.GetHttpRequestDataAsync();
if (requestData is not null) {
var authToken = GetAuthToken(requestData);
if (authToken is not null) {
// note that no validation of the token is performed here
// this is done globally by Azure Functions; see the configuration in
// 'function.bicep'
var token = new JwtSecurityToken(authToken);
var allowedTenants = await AllowedTenants();
if (!allowedTenants.Contains(token.Issuer)) {
await BadIssuer(requestData, context, token, allowedTenants);
return;
}

context.SetUserAuthInfo(UserInfoFromAuthToken(token));
}
}

await next(context);
}

private static UserAuthInfo UserInfoFromAuthToken(JwtSecurityToken token)
=> token.Payload.Claims.Aggregate(
seed: new UserAuthInfo(new UserInfo(null, null, null), new List<string>()),
(acc, claim) => {
switch (claim.Type) {
case "oid":
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
case "appid":
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
case "upn":
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
case "roles":
acc.Roles.Add(claim.Value);
return acc;
default:
return acc;
}
});

private async Async.ValueTask BadIssuer(
HttpRequestData request,
FunctionContext context,
JwtSecurityToken token,
IEnumerable<string> allowedTenants) {

var tenantsStr = string.Join("; ", allowedTenants);
_log.Error($"issuer not from allowed tenant. issuer: {token.Issuer:Tag:Issuer} - tenants: {tenantsStr:Tag:Tenants}");

var response = HttpResponseData.CreateResponse(request);
var status = HttpStatusCode.BadRequest;
await response.WriteAsJsonAsync(
new ProblemDetails(
status,
new Error(
ErrorCode.INVALID_REQUEST,
new List<string> {
"unauthorized AAD issuer. If multi-tenant auth is failing, make sure to include all tenant_ids in the `allowed_aad_tenants` list in the instance_config. To see the current instance_config, run `onefuzz instance_config get`. "
}
)),
"application/problem+json",
status);

context.GetInvocationResult().Value = response;
}

private async Async.Task<IEnumerable<string>> AllowedTenants() {
var config = await _config.Fetch();
return config.AllowedAadTenants.Select(t => $"https://sts.windows.net/{t}/");
}

private static string? GetAuthToken(HttpRequestData requestData)
=> GetBearerToken(requestData) ?? GetAadIdToken(requestData);

private static string? GetAadIdToken(HttpRequestData requestData) {
if (!requestData.Headers.TryGetValues("x-ms-token-aad-id-token", out var values)) {
return null;
}

return values.First();
}

private static string? GetBearerToken(HttpRequestData requestData) {
if (!requestData.Headers.TryGetValues("Authorization", out var values)
|| !AuthenticationHeaderValue.TryParse(values.First(), out var headerValue)
|| !string.Equals(headerValue.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase)) {
return null;
}

return headerValue.Parameter;
}
}
103 changes: 103 additions & 0 deletions src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Net;
using System.Reflection;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.Azure.Functions.Worker.Middleware;

namespace Microsoft.OneFuzz.Service.Auth;

public sealed class AuthorizationMiddleware : IFunctionsWorkerMiddleware {
private readonly IEndpointAuthorization _auth;
private readonly ILogTracer _log;

public AuthorizationMiddleware(IEndpointAuthorization auth, ILogTracer log) {
_auth = auth;
_log = log;
}

public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) {
var attribute = GetAuthorizeAttribute(context);
if (attribute is not null) {
var req = await context.GetHttpRequestDataAsync() ?? throw new NotSupportedException("no HTTP request data found");
tevoinea marked this conversation as resolved.
Show resolved Hide resolved
var user = context.TryGetUserAuthInfo();
if (user is null) {
await Reject(req, context, "no authentication");
return;
}

var (isAgent, _) = await _auth.IsAgent(user);
if (isAgent) {
if (attribute.Allow != Allow.Agent) {
await Reject(req, context, "endpoint not allowed for agents");
return;
}
} else {
if (attribute.Allow == Allow.Agent) {
await Reject(req, context, "endpoint not allowed for users");
return;
}

Debug.Assert(attribute.Allow is Allow.User or Allow.Admin);

// check access control first
var access = await _auth.CheckAccess(req);
if (!access.IsOk) {
await Reject(req, context, "access control rejected request");
return;
}

// check admin next
if (attribute.Allow == Allow.Admin) {
var adminAccess = await _auth.CheckRequireAdmins(user);
if (!adminAccess.IsOk) {
await Reject(req, context, "must be admin to use this endpoint");
return;
}
}
}
}

await next(context);
}

private static async Async.ValueTask Reject(HttpRequestData request, FunctionContext context, string reason) {
var response = HttpResponseData.CreateResponse(request);
var status = HttpStatusCode.Unauthorized;
await response.WriteAsJsonAsync(
new ProblemDetails(
status,
Error.Create(ErrorCode.UNAUTHORIZED, reason)),
"application/problem+json",
status);

context.GetInvocationResult().Value = response;
}

// use ImmutableDictionary to prevent needing to lock and without the overhead
// of ConcurrentDictionary
private static ImmutableDictionary<string, AuthorizeAttribute?> _authorizeCache =
ImmutableDictionary.Create<string, AuthorizeAttribute?>();

private static AuthorizeAttribute? GetAuthorizeAttribute(FunctionContext context) {
// fully-qualified name of the method
var entryPoint = context.FunctionDefinition.EntryPoint;
if (_authorizeCache.TryGetValue(entryPoint, out var cached)) {
return cached;
}

var lastDot = entryPoint.LastIndexOf('.');
var (typeName, methodName) = (entryPoint[..lastDot], entryPoint[(lastDot + 1)..]);
var assemblyPath = context.FunctionDefinition.PathToAssembly;
Porges marked this conversation as resolved.
Show resolved Hide resolved
var assembly = Assembly.LoadFrom(assemblyPath); // should already be loaded
var type = assembly.GetType(typeName)!;
var method = type.GetMethod(methodName)!;
var result =
method.GetCustomAttribute<AuthorizeAttribute>()
?? type.GetCustomAttribute<AuthorizeAttribute>();

_authorizeCache = _authorizeCache.SetItem(entryPoint, result);
return result;
}
}
17 changes: 17 additions & 0 deletions src/ApiService/ApiService/Auth/AuthorizeAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace Microsoft.OneFuzz.Service.Auth;

[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)]
public sealed class AuthorizeAttribute : Attribute {
public AuthorizeAttribute(Allow allow) {
Allow = allow;
}

public Allow Allow { get; set; }
}

public enum Allow {
Porges marked this conversation as resolved.
Show resolved Hide resolved
Agent,
User,
Admin,

}
14 changes: 5 additions & 9 deletions src/ApiService/ApiService/Functions/AgentCanSchedule.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service.Auth;

namespace Microsoft.OneFuzz.Service.Functions;

public class AgentCanSchedule {
private readonly ILogTracer _log;
private readonly IEndpointAuthorization _auth;
private readonly IOnefuzzContext _context;


public AgentCanSchedule(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) {
_log = log;
_auth = auth;
_context = context;
}

[Function("AgentCanSchedule")]
public Async.Task<HttpResponseData> Run(
[Authorize(Allow.Agent)]
public async Async.Task<HttpResponseData> Run(
[HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/can_schedule")]
HttpRequestData req)
=> _auth.CallIfAgent(req, Post);

private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
HttpRequestData req) {
var request = await RequestHandling.ParseRequest<CanScheduleRequest>(req);
if (!request.IsOk) {
_log.Warning($"Cannot schedule due to {request.ErrorV}");
Expand Down
10 changes: 5 additions & 5 deletions src/ApiService/ApiService/Functions/AgentCommands.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service.Auth;

namespace Microsoft.OneFuzz.Service.Functions;

public class AgentCommands {
private readonly ILogTracer _log;
private readonly IEndpointAuthorization _auth;
private readonly IOnefuzzContext _context;

public AgentCommands(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
public AgentCommands(ILogTracer log, IOnefuzzContext context) {
_log = log;
_auth = auth;
_context = context;
}

[Function("AgentCommands")]
[Authorize(Allow.Agent)]
public Async.Task<HttpResponseData> Run(
[HttpTrigger(AuthorizationLevel.Anonymous, "GET", "DELETE", Route="agents/commands")]
HttpRequestData req)
=> _auth.CallIfAgent(req, r => r.Method switch {
=> req.Method switch {
"GET" => Get(req),
"DELETE" => Delete(req),
_ => throw new NotSupportedException($"HTTP Method {req.Method} is not supported for this method")
});
};

private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandGet>(req);
Expand Down
13 changes: 5 additions & 8 deletions src/ApiService/ApiService/Functions/AgentEvents.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service.Auth;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;

namespace Microsoft.OneFuzz.Service.Functions;

public class AgentEvents {
private readonly ILogTracer _log;
private readonly IEndpointAuthorization _auth;
private readonly IOnefuzzContext _context;

public AgentEvents(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
public AgentEvents(ILogTracer log, IOnefuzzContext context) {
_log = log;
_auth = auth;
_context = context;
}

[Function("AgentEvents")]
public Async.Task<HttpResponseData> Run(
[Authorize(Allow.Agent)]
public async Async.Task<HttpResponseData> Run(
[HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/events")]
HttpRequestData req)
=> _auth.CallIfAgent(req, Post);

private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeStateEnvelope>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event");
Expand Down
18 changes: 8 additions & 10 deletions src/ApiService/ApiService/Functions/AgentRegistration.cs
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
using Azure.Storage.Sas;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service.Auth;

namespace Microsoft.OneFuzz.Service.Functions;

public class AgentRegistration {
private readonly ILogTracer _log;
private readonly IEndpointAuthorization _auth;
private readonly IOnefuzzContext _context;

public AgentRegistration(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
public AgentRegistration(ILogTracer log, IOnefuzzContext context) {
_log = log;
_auth = auth;
_context = context;
}

[Function("AgentRegistration")]
[Authorize(Allow.Agent)]
public Async.Task<HttpResponseData> Run(
[HttpTrigger(
AuthorizationLevel.Anonymous,
"GET", "POST",
Route="agents/registration")] HttpRequestData req)
=> _auth.CallIfAgent(
req,
r => r.Method switch {
"GET" => Get(r),
"POST" => Post(r),
var m => throw new InvalidOperationException($"method {m} not supported"),
});
=> req.Method switch {
"GET" => Get(req),
"POST" => Post(req),
var m => throw new InvalidOperationException($"method {m} not supported"),
};

private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseUri<AgentRegistrationGet>(req);
Expand Down
Loading