From e448947abe7d2b36fd461633dc546e5548d30226 Mon Sep 17 00:00:00 2001 From: George Pollard Date: Wed, 7 Jun 2023 13:57:22 +1200 Subject: [PATCH] Move auth into middleware (#3133) Closes #2098. This cleans up the authentication a bit; after this change we have two stages in the middleware pipeline: - `AuthenticationMiddleware` reads the JWT token (it does not validate it, this is done by the Azure Functions service) and stores it in `FunctionContext.Items["ONEFUZZ_USER_INFO"]` - `AuthorizationMiddleware` checks the user info against the `[Authorize]` attribute to see if the user has the required permissions - Functions can read the user info from the `FunctionContext` if needed The authorize attribute can be `[Authorize(Allow.User)]` or `Allow.Agent` or `Allow.Admin`. The `Admin` case is new and allows this to be declaratively specified rather than being checked in code. We have several functions which could be changed to use this (e.g. Pool POST/DELETE/PATCH, Scaleset POST/DELETE/PATCH), but I have only changed one so far (JinjaToScriban). One of the benefits here is that this simplifies the test code a lot: we can set the desired user info directly onto our `(Test)FunctionContext` rather than having to supply a fake that pretends to parse the token from the HTTP request. This will also have benefits when running the service locally for testing purposes (refer to internal issue). The other benefit is the ability to programmatically read the required authentication for each function, which may help with Swagger generation. --- .../ApiService/Auth/AuthenticationItems.cs | 16 ++ .../Auth/AuthenticationMiddleware.cs | 111 ++++++++++++ .../Auth/AuthorizationMiddleware.cs | 103 +++++++++++ .../ApiService/Auth/AuthorizeAttribute.cs | 17 ++ .../ApiService/Functions/AgentCanSchedule.cs | 14 +- .../ApiService/Functions/AgentCommands.cs | 10 +- .../ApiService/Functions/AgentEvents.cs | 13 +- .../ApiService/Functions/AgentRegistration.cs | 18 +- src/ApiService/ApiService/Functions/Config.cs | 9 +- .../ApiService/Functions/Containers.cs | 16 +- .../ApiService/Functions/Download.cs | 11 +- src/ApiService/ApiService/Functions/Events.cs | 14 +- src/ApiService/ApiService/Functions/Info.cs | 13 +- .../ApiService/Functions/InstanceConfig.cs | 36 ++-- src/ApiService/ApiService/Functions/Jobs.cs | 30 ++-- .../Functions/Migrations/JinjaToScriban.cs | 17 +- .../ApiService/Functions/Negotiate.cs | 28 ++- src/ApiService/ApiService/Functions/Node.cs | 49 +++-- .../ApiService/Functions/NodeAddSshKey.cs | 23 +-- .../ApiService/Functions/Notifications.cs | 19 +- .../ApiService/Functions/NotificationsTest.cs | 18 +- src/ApiService/ApiService/Functions/Pool.cs | 52 +++--- src/ApiService/ApiService/Functions/Proxy.cs | 24 +-- .../ApiService/Functions/ReproVmss.cs | 36 ++-- .../ApiService/Functions/Scaleset.cs | 51 +++--- src/ApiService/ApiService/Functions/Tasks.cs | 33 ++-- src/ApiService/ApiService/Functions/Tool.cs | 16 +- .../ApiService/Functions/ValidateScriban.cs | 16 +- .../ApiService/Functions/WebhookLogs.cs | 18 +- .../ApiService/Functions/WebhookPing.cs | 18 +- .../ApiService/Functions/Webhooks.cs | 24 ++- src/ApiService/ApiService/HttpClient.cs | 2 +- src/ApiService/ApiService/Program.cs | 16 +- src/ApiService/ApiService/UserCredentials.cs | 103 ----------- src/ApiService/ApiService/onefuzzlib/Auth.cs | 2 +- .../onefuzzlib/EndpointAuthorization.cs | 87 ++------- .../ApiService/onefuzzlib/OnefuzzContext.cs | 2 - .../ApiService/onefuzzlib/ProxyOperations.cs | 12 +- .../ApiService/onefuzzlib/ReproOperations.cs | 2 +- .../IntegrationTests/AgentCanScheduleTests.cs | 33 +--- .../IntegrationTests/AgentCommandsTests.cs | 30 +--- .../IntegrationTests/AgentEventsTests.cs | 51 ++---- .../AgentRegistrationTests.cs | 51 +----- src/ApiService/IntegrationTests/AuthTests.cs | 2 +- .../IntegrationTests/ContainersTests.cs | 37 +--- .../IntegrationTests/DownloadTests.cs | 21 +-- .../IntegrationTests/EndpointAuthTests.cs | 93 ++++++++++ .../IntegrationTests/EventsTests.cs | 3 +- .../IntegrationTests/Fakes/TestContext.cs | 2 - .../Fakes/TestEndpointAuthorization.cs | 47 ----- .../Fakes/TestFunctionContext.cs | 32 ++++ .../Fakes/TestUserCredentials.cs | 20 --- src/ApiService/IntegrationTests/InfoTests.cs | 23 +-- .../JinjaToScribanMigrationTests.cs | 53 +----- src/ApiService/IntegrationTests/JobsTests.cs | 61 +++---- src/ApiService/IntegrationTests/NodeTests.cs | 168 +----------------- src/ApiService/IntegrationTests/PoolTests.cs | 88 ++------- .../IntegrationTests/ReproVmssTests.cs | 87 +++------ .../IntegrationTests/ScalesetTests.cs | 58 +----- src/ApiService/IntegrationTests/TasksTests.cs | 24 ++- src/ApiService/IntegrationTests/ToolsTests.cs | 3 +- src/ApiService/Tests/AuthTests.cs | 79 ++++++++ 62 files changed, 881 insertions(+), 1284 deletions(-) create mode 100644 src/ApiService/ApiService/Auth/AuthenticationItems.cs create mode 100644 src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs create mode 100644 src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs create mode 100644 src/ApiService/ApiService/Auth/AuthorizeAttribute.cs delete mode 100644 src/ApiService/ApiService/UserCredentials.cs create mode 100644 src/ApiService/IntegrationTests/EndpointAuthTests.cs delete mode 100644 src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs create mode 100644 src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs delete mode 100644 src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs create mode 100644 src/ApiService/Tests/AuthTests.cs diff --git a/src/ApiService/ApiService/Auth/AuthenticationItems.cs b/src/ApiService/ApiService/Auth/AuthenticationItems.cs new file mode 100644 index 0000000000..b546a40889 --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthenticationItems.cs @@ -0,0 +1,16 @@ +using Microsoft.Azure.Functions.Worker; + +namespace Microsoft.OneFuzz.Service.Auth; + +public static class AuthenticationItems { + private const string Key = "ONEFUZZ_USER_INFO"; + + public static void SetUserAuthInfo(this FunctionContext context, UserAuthInfo info) + => context.Items[Key] = info; + + public static UserAuthInfo GetUserAuthInfo(this FunctionContext context) + => (UserAuthInfo)context.Items[Key]; + + public static UserAuthInfo? TryGetUserAuthInfo(this FunctionContext context) + => context.Items.TryGetValue(Key, out var result) ? (UserAuthInfo)result : null; +} diff --git a/src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs b/src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs new file mode 100644 index 0000000000..2cec43994b --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs @@ -0,0 +1,111 @@ +using System.IdentityModel.Tokens.Jwt; +using System.Net; +using System.Net.Http.Headers; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.Azure.Functions.Worker.Middleware; + +namespace Microsoft.OneFuzz.Service.Auth; + +public sealed class AuthenticationMiddleware : IFunctionsWorkerMiddleware { + private readonly IConfigOperations _config; + private readonly ILogTracer _log; + + public AuthenticationMiddleware(IConfigOperations config, ILogTracer log) { + _config = config; + _log = log; + } + + public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) { + var requestData = await context.GetHttpRequestDataAsync(); + if (requestData is not null) { + var authToken = GetAuthToken(requestData); + if (authToken is not null) { + // note that no validation of the token is performed here + // this is done globally by Azure Functions; see the configuration in + // 'function.bicep' + var token = new JwtSecurityToken(authToken); + var allowedTenants = await AllowedTenants(); + if (!allowedTenants.Contains(token.Issuer)) { + await BadIssuer(requestData, context, token, allowedTenants); + return; + } + + context.SetUserAuthInfo(UserInfoFromAuthToken(token)); + } + } + + await next(context); + } + + private static UserAuthInfo UserInfoFromAuthToken(JwtSecurityToken token) + => token.Payload.Claims.Aggregate( + seed: new UserAuthInfo(new UserInfo(null, null, null), new List()), + (acc, claim) => { + switch (claim.Type) { + case "oid": + return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } }; + case "appid": + return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } }; + case "upn": + return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } }; + case "roles": + acc.Roles.Add(claim.Value); + return acc; + default: + return acc; + } + }); + + private async Async.ValueTask BadIssuer( + HttpRequestData request, + FunctionContext context, + JwtSecurityToken token, + IEnumerable allowedTenants) { + + var tenantsStr = string.Join("; ", allowedTenants); + _log.Error($"issuer not from allowed tenant. issuer: {token.Issuer:Tag:Issuer} - tenants: {tenantsStr:Tag:Tenants}"); + + var response = HttpResponseData.CreateResponse(request); + var status = HttpStatusCode.BadRequest; + await response.WriteAsJsonAsync( + new ProblemDetails( + status, + new Error( + ErrorCode.INVALID_REQUEST, + new List { + "unauthorized AAD issuer. If multi-tenant auth is failing, make sure to include all tenant_ids in the `allowed_aad_tenants` list in the instance_config. To see the current instance_config, run `onefuzz instance_config get`. " + } + )), + "application/problem+json", + status); + + context.GetInvocationResult().Value = response; + } + + private async Async.Task> AllowedTenants() { + var config = await _config.Fetch(); + return config.AllowedAadTenants.Select(t => $"https://sts.windows.net/{t}/"); + } + + private static string? GetAuthToken(HttpRequestData requestData) + => GetBearerToken(requestData) ?? GetAadIdToken(requestData); + + private static string? GetAadIdToken(HttpRequestData requestData) { + if (!requestData.Headers.TryGetValues("x-ms-token-aad-id-token", out var values)) { + return null; + } + + return values.First(); + } + + private static string? GetBearerToken(HttpRequestData requestData) { + if (!requestData.Headers.TryGetValues("Authorization", out var values) + || !AuthenticationHeaderValue.TryParse(values.First(), out var headerValue) + || !string.Equals(headerValue.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase)) { + return null; + } + + return headerValue.Parameter; + } +} diff --git a/src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs b/src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs new file mode 100644 index 0000000000..34ca16ab28 --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs @@ -0,0 +1,103 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using System.Net; +using System.Reflection; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.Azure.Functions.Worker.Middleware; + +namespace Microsoft.OneFuzz.Service.Auth; + +public sealed class AuthorizationMiddleware : IFunctionsWorkerMiddleware { + private readonly IEndpointAuthorization _auth; + private readonly ILogTracer _log; + + public AuthorizationMiddleware(IEndpointAuthorization auth, ILogTracer log) { + _auth = auth; + _log = log; + } + + public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) { + var attribute = GetAuthorizeAttribute(context); + if (attribute is not null) { + var req = await context.GetHttpRequestDataAsync() ?? throw new NotSupportedException("no HTTP request data found"); + var user = context.TryGetUserAuthInfo(); + if (user is null) { + await Reject(req, context, "no authentication"); + return; + } + + var (isAgent, _) = await _auth.IsAgent(user); + if (isAgent) { + if (attribute.Allow != Allow.Agent) { + await Reject(req, context, "endpoint not allowed for agents"); + return; + } + } else { + if (attribute.Allow == Allow.Agent) { + await Reject(req, context, "endpoint not allowed for users"); + return; + } + + Debug.Assert(attribute.Allow is Allow.User or Allow.Admin); + + // check access control first + var access = await _auth.CheckAccess(req); + if (!access.IsOk) { + await Reject(req, context, "access control rejected request"); + return; + } + + // check admin next + if (attribute.Allow == Allow.Admin) { + var adminAccess = await _auth.CheckRequireAdmins(user); + if (!adminAccess.IsOk) { + await Reject(req, context, "must be admin to use this endpoint"); + return; + } + } + } + } + + await next(context); + } + + private static async Async.ValueTask Reject(HttpRequestData request, FunctionContext context, string reason) { + var response = HttpResponseData.CreateResponse(request); + var status = HttpStatusCode.Unauthorized; + await response.WriteAsJsonAsync( + new ProblemDetails( + status, + Error.Create(ErrorCode.UNAUTHORIZED, reason)), + "application/problem+json", + status); + + context.GetInvocationResult().Value = response; + } + + // use ImmutableDictionary to prevent needing to lock and without the overhead + // of ConcurrentDictionary + private static ImmutableDictionary _authorizeCache = + ImmutableDictionary.Create(); + + private static AuthorizeAttribute? GetAuthorizeAttribute(FunctionContext context) { + // fully-qualified name of the method + var entryPoint = context.FunctionDefinition.EntryPoint; + if (_authorizeCache.TryGetValue(entryPoint, out var cached)) { + return cached; + } + + var lastDot = entryPoint.LastIndexOf('.'); + var (typeName, methodName) = (entryPoint[..lastDot], entryPoint[(lastDot + 1)..]); + var assemblyPath = context.FunctionDefinition.PathToAssembly; + var assembly = Assembly.LoadFrom(assemblyPath); // should already be loaded + var type = assembly.GetType(typeName)!; + var method = type.GetMethod(methodName)!; + var result = + method.GetCustomAttribute() + ?? type.GetCustomAttribute(); + + _authorizeCache = _authorizeCache.SetItem(entryPoint, result); + return result; + } +} diff --git a/src/ApiService/ApiService/Auth/AuthorizeAttribute.cs b/src/ApiService/ApiService/Auth/AuthorizeAttribute.cs new file mode 100644 index 0000000000..c195e33e0f --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthorizeAttribute.cs @@ -0,0 +1,17 @@ +namespace Microsoft.OneFuzz.Service.Auth; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] +public sealed class AuthorizeAttribute : Attribute { + public AuthorizeAttribute(Allow allow) { + Allow = allow; + } + + public Allow Allow { get; set; } +} + +public enum Allow { + Agent, + User, + Admin, + +} diff --git a/src/ApiService/ApiService/Functions/AgentCanSchedule.cs b/src/ApiService/ApiService/Functions/AgentCanSchedule.cs index 8e83a015b8..b4d54c9477 100644 --- a/src/ApiService/ApiService/Functions/AgentCanSchedule.cs +++ b/src/ApiService/ApiService/Functions/AgentCanSchedule.cs @@ -1,27 +1,23 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class AgentCanSchedule { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - - public AgentCanSchedule(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentCanSchedule")] - public Async.Task Run( + [Authorize(Allow.Agent)] + public async Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/can_schedule")] - HttpRequestData req) - => _auth.CallIfAgent(req, Post); - - private async Async.Task Post(HttpRequestData req) { + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { _log.Warning($"Cannot schedule due to {request.ErrorV}"); diff --git a/src/ApiService/ApiService/Functions/AgentCommands.cs b/src/ApiService/ApiService/Functions/AgentCommands.cs index 5c1d7721a4..2ba68e9e64 100644 --- a/src/ApiService/ApiService/Functions/AgentCommands.cs +++ b/src/ApiService/ApiService/Functions/AgentCommands.cs @@ -1,28 +1,28 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class AgentCommands { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentCommands(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentCommands(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentCommands")] + [Authorize(Allow.Agent)] public Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "DELETE", Route="agents/commands")] HttpRequestData req) - => _auth.CallIfAgent(req, r => r.Method switch { + => req.Method switch { "GET" => Get(req), "DELETE" => Delete(req), _ => throw new NotSupportedException($"HTTP Method {req.Method} is not supported for this method") - }); + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); diff --git a/src/ApiService/ApiService/Functions/AgentEvents.cs b/src/ApiService/ApiService/Functions/AgentEvents.cs index 46f329856d..63559dd102 100644 --- a/src/ApiService/ApiService/Functions/AgentEvents.cs +++ b/src/ApiService/ApiService/Functions/AgentEvents.cs @@ -1,28 +1,25 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; namespace Microsoft.OneFuzz.Service.Functions; public class AgentEvents { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentEvents(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentEvents(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentEvents")] - public Async.Task Run( + [Authorize(Allow.Agent)] + public async Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/events")] - HttpRequestData req) - => _auth.CallIfAgent(req, Post); - - private async Async.Task Post(HttpRequestData req) { + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event"); diff --git a/src/ApiService/ApiService/Functions/AgentRegistration.cs b/src/ApiService/ApiService/Functions/AgentRegistration.cs index 6b715c2245..31134c78cb 100644 --- a/src/ApiService/ApiService/Functions/AgentRegistration.cs +++ b/src/ApiService/ApiService/Functions/AgentRegistration.cs @@ -1,33 +1,31 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class AgentRegistration { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentRegistration(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentRegistration(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentRegistration")] + [Authorize(Allow.Agent)] public Async.Task Run( [HttpTrigger( AuthorizationLevel.Anonymous, "GET", "POST", Route="agents/registration")] HttpRequestData req) - => _auth.CallIfAgent( - req, - r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - var m => throw new InvalidOperationException($"method {m} not supported"), - }); + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + var m => throw new InvalidOperationException($"method {m} not supported"), + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseUri(req); diff --git a/src/ApiService/ApiService/Functions/Config.cs b/src/ApiService/ApiService/Functions/Config.cs index 2704ab6b00..6f4037584b 100644 --- a/src/ApiService/ApiService/Functions/Config.cs +++ b/src/ApiService/ApiService/Functions/Config.cs @@ -14,11 +14,10 @@ public Config(ILogTracer log, IOnefuzzContext context) { } [Function("Config")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { - return Get(req); - } - public async Async.Task Get(HttpRequestData req) { + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET")] + HttpRequestData req) { + _log.Info($"getting endpoint config parameters"); var endpointParams = new ConfigResponse( diff --git a/src/ApiService/ApiService/Functions/Containers.cs b/src/ApiService/ApiService/Functions/Containers.cs index c3bedcfd66..8eed38e23b 100644 --- a/src/ApiService/ApiService/Functions/Containers.cs +++ b/src/ApiService/ApiService/Functions/Containers.cs @@ -1,28 +1,28 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class ContainersFunction { private readonly ILogTracer _logger; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public ContainersFunction(ILogTracer logger, IEndpointAuthorization auth, IOnefuzzContext context) { + public ContainersFunction(ILogTracer logger, IOnefuzzContext context) { _logger = logger; - _auth = auth; _context = context; } [Function("Containers")] + [Authorize(Allow.User)] public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + "DELETE" => Delete(req), _ => throw new NotSupportedException(), - }); + }; private async Async.Task Get(HttpRequestData req) { diff --git a/src/ApiService/ApiService/Functions/Download.cs b/src/ApiService/ApiService/Functions/Download.cs index dac128b08b..27202afad1 100644 --- a/src/ApiService/ApiService/Functions/Download.cs +++ b/src/ApiService/ApiService/Functions/Download.cs @@ -2,23 +2,20 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Download { - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Download(IEndpointAuthorization auth, IOnefuzzContext context) { - _auth = auth; + public Download(IOnefuzzContext context) { _context = context; } [Function("Download")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, Get); - - private async Async.Task Get(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { var query = HttpUtility.ParseQueryString(req.Url.Query); var queryContainer = query["container"]; diff --git a/src/ApiService/ApiService/Functions/Events.cs b/src/ApiService/ApiService/Functions/Events.cs index 83cb585a19..67c25febec 100644 --- a/src/ApiService/ApiService/Functions/Events.cs +++ b/src/ApiService/ApiService/Functions/Events.cs @@ -1,27 +1,21 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class EventsFunction { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public EventsFunction(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _auth = auth; + public EventsFunction(ILogTracer log, IOnefuzzContext context) { _context = context; _log = log; } [Function("Events")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - _ => throw new NotSupportedException(), - }); - - private async Async.Task Get(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "events get"); diff --git a/src/ApiService/ApiService/Functions/Info.cs b/src/ApiService/ApiService/Functions/Info.cs index 1c5aef1033..6dd32a12a7 100644 --- a/src/ApiService/ApiService/Functions/Info.cs +++ b/src/ApiService/ApiService/Functions/Info.cs @@ -4,17 +4,16 @@ using System.Threading; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Info { private readonly IOnefuzzContext _context; - private readonly IEndpointAuthorization _auth; private readonly Lazy> _response; - public Info(IEndpointAuthorization auth, IOnefuzzContext context) { + public Info(IOnefuzzContext context) { _context = context; - _auth = auth; // TODO: this isn’t actually shared between calls at the moment, // this needs to be placed into a class that can be registered into the @@ -60,10 +59,8 @@ private static string ReadResource(Assembly asm, string resourceName) { return sr.ReadToEnd().Trim(); } - private async Async.Task GetResponse(HttpRequestData req) - => await RequestHandling.Ok(req, await _response.Value); - [Function("Info")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, GetResponse); + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) + => await RequestHandling.Ok(req, await _response.Value); } diff --git a/src/ApiService/ApiService/Functions/InstanceConfig.cs b/src/ApiService/ApiService/Functions/InstanceConfig.cs index 59ae6d316c..4718c8e3f5 100644 --- a/src/ApiService/ApiService/Functions/InstanceConfig.cs +++ b/src/ApiService/ApiService/Functions/InstanceConfig.cs @@ -2,30 +2,26 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class InstanceConfig { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public InstanceConfig(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public InstanceConfig(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } + public const string Route = "instance_config"; + [Function("InstanceConfig")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", Route = "instance_config")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - public async Async.Task Get(HttpRequestData req) { + [Authorize(Allow.User)] + public async Task Get( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) { _log.Info($"getting instance_config"); var config = await _context.ConfigOperations.Fetch(); @@ -34,7 +30,11 @@ public async Async.Task Get(HttpRequestData req) { return response; } - public async Async.Task Post(HttpRequestData req) { + [Function("InstanceConfig_Admin")] + [Authorize(Allow.Admin)] + public async Task Post( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route=Route)] + HttpRequestData req) { _log.Info($"attempting instance_config update"); var request = await RequestHandling.ParseRequest(req); @@ -44,12 +44,8 @@ public async Async.Task Post(HttpRequestData req) { request.ErrorV, context: "instance_config update"); } - var (config, answer) = await ( - _context.ConfigOperations.Fetch(), - _auth.CheckRequireAdmins(req)); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "instance_config update"); - } + + var config = await _context.ConfigOperations.Fetch(); var updateNsg = false; if (request.OkV.config.ProxyNsgConfig is NetworkSecurityGroupConfig requestConfig && config.ProxyNsgConfig is NetworkSecurityGroupConfig currentConfig) { @@ -58,7 +54,9 @@ public async Async.Task Post(HttpRequestData req) { updateNsg = true; } } + await _context.ConfigOperations.Save(request.OkV.config, false, false); + if (updateNsg) { await foreach (var nsg in _context.NsgOperations.ListNsgs()) { _log.Info($"Checking if nsg: {nsg.Data.Location!:Tag:Location} ({nsg.Data.Name:Tag:NsgName}) owned by OneFuzz"); diff --git a/src/ApiService/ApiService/Functions/Jobs.cs b/src/ApiService/ApiService/Functions/Jobs.cs index c7ee52bb5f..9326f773f1 100644 --- a/src/ApiService/ApiService/Functions/Jobs.cs +++ b/src/ApiService/ApiService/Functions/Jobs.cs @@ -1,39 +1,39 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Jobs { private readonly IOnefuzzContext _context; - private readonly IEndpointAuthorization _auth; private readonly ILogTracer _logTracer; - public Jobs(IEndpointAuthorization auth, IOnefuzzContext context, ILogTracer logTracer) { + public Jobs(IOnefuzzContext context, ILogTracer logTracer) { _context = context; - _auth = auth; _logTracer = logTracer; } [Function("Jobs")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "DELETE" => Delete(r), - "POST" => Post(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] + HttpRequestData req, + FunctionContext context) + => req.Method switch { + "GET" => Get(req), + "DELETE" => Delete(req), + "POST" => Post(req, context), var m => throw new NotSupportedException($"Unsupported HTTP method {m}"), - }); + }; - private async Task Post(HttpRequestData req) { + private async Task Post(HttpRequestData req, FunctionContext context) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "jobs create"); } - var userInfo = await _context.UserCredentials.ParseJwtToken(req); - if (!userInfo.IsOk) { - return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "jobs create"); - } + var userInfo = context.GetUserAuthInfo(); var create = request.OkV; var cfg = new JobConfig( @@ -47,7 +47,7 @@ private async Task Post(HttpRequestData req) { JobId: Guid.NewGuid(), State: JobState.Init, Config: cfg) { - UserInfo = userInfo.OkV.UserInfo, + UserInfo = userInfo.UserInfo, }; // create the job logs container diff --git a/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs b/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs index 197c90ecae..0e69ab5103 100644 --- a/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs +++ b/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs @@ -1,28 +1,26 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class JinjaToScriban { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public JinjaToScriban(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public JinjaToScriban(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("JinjaToScriban")] - public Async.Task Run( + [Authorize(Allow.Admin)] + public async Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="migrations/jinja_to_scriban")] - HttpRequestData req) - => _auth.CallIfUser(req, Post); + HttpRequestData req) { - private async Async.Task Post(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -31,11 +29,6 @@ private async Async.Task Post(HttpRequestData req) { "JinjaToScriban"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "JinjaToScriban"); - } - _log.Info($"Finding notifications to migrate"); var notifications = _context.NotificationOperations.SearchAll() diff --git a/src/ApiService/ApiService/Functions/Negotiate.cs b/src/ApiService/ApiService/Functions/Negotiate.cs index 21e5f05c9a..16f9496ec1 100644 --- a/src/ApiService/ApiService/Functions/Negotiate.cs +++ b/src/ApiService/ApiService/Functions/Negotiate.cs @@ -2,32 +2,24 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Negotiate { - private readonly IEndpointAuthorization _auth; - public Negotiate(IEndpointAuthorization auth) { - _auth = auth; - } - [Function("Negotiate")] - public Task Run( + [Authorize(Allow.User)] + public static async Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req, - [SignalRConnectionInfoInput(HubName = "dashboard")] string info) - => _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r, info), - var m => throw new InvalidOperationException($"Unsupported HTTP method {m}"), - }); + [SignalRConnectionInfoInput(HubName = "dashboard")] string info) { - // This endpoint handles the signalr negotation - // As we do not differentiate from clients at this time, we pass the Functions runtime - // provided connection straight to the client - // - // For more info: - // https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals + // This endpoint handles the signalr negotation + // As we do not differentiate from clients at this time, we pass the Functions runtime + // provided connection straight to the client + // + // For more info: + // https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals - private static async Task Post(HttpRequestData req, string info) { var resp = req.CreateResponse(HttpStatusCode.OK); resp.Headers.Add("Content-Type", "application/json"); await resp.WriteStringAsync(info); diff --git a/src/ApiService/ApiService/Functions/Node.cs b/src/ApiService/ApiService/Functions/Node.cs index 70231b90d2..b7f033ce22 100644 --- a/src/ApiService/ApiService/Functions/Node.cs +++ b/src/ApiService/ApiService/Functions/Node.cs @@ -1,30 +1,42 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Node { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Node(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Node(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } + public const string Route = "node"; + [Function("Node")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "PATCH" => Patch(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; + + [Function("Node_Admin")] + [Authorize(Allow.Admin)] + public Task Admin( + [HttpTrigger(AuthorizationLevel.Anonymous, "PATCH", "POST", "DELETE", Route=Route)] + HttpRequestData req) + => req.Method switch { + "PATCH" => Patch(req), + "POST" => Post(req), + "DELETE" => Delete(req), + _ => throw new InvalidOperationException("Unsupported HTTP method"), + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -82,11 +94,6 @@ private async Async.Task Patch(HttpRequestData req) { "NodeReimage"); } - var authCheck = await _auth.CheckRequireAdmins(req); - if (!authCheck.IsOk) { - return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeReimage"); - } - var patch = request.OkV; var node = await _context.NodeOperations.GetByMachineId(patch.MachineId); if (node is null) { @@ -116,11 +123,6 @@ private async Async.Task Post(HttpRequestData req) { "NodeUpdate"); } - var authCheck = await _auth.CheckRequireAdmins(req); - if (!authCheck.IsOk) { - return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeUpdate"); - } - var post = request.OkV; var node = await _context.NodeOperations.GetByMachineId(post.MachineId); if (node is null) { @@ -150,11 +152,6 @@ private async Async.Task Delete(HttpRequestData req) { context: "NodeDelete"); } - var authCheck = await _auth.CheckRequireAdmins(req); - if (!authCheck.IsOk) { - return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeDelete"); - } - var delete = request.OkV; var node = await _context.NodeOperations.GetByMachineId(delete.MachineId); if (node is null) { diff --git a/src/ApiService/ApiService/Functions/NodeAddSshKey.cs b/src/ApiService/ApiService/Functions/NodeAddSshKey.cs index 1bce7c9fb2..524acaf8fd 100644 --- a/src/ApiService/ApiService/Functions/NodeAddSshKey.cs +++ b/src/ApiService/ApiService/Functions/NodeAddSshKey.cs @@ -1,22 +1,20 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class NodeAddSshKey { - - private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public NodeAddSshKey(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _log = log; - _auth = auth; + public NodeAddSshKey(IOnefuzzContext context) { _context = context; } - private async Async.Task Post(HttpRequestData req) { + [Function("NodeAddSshKey")] + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "node/add_ssh_key")] HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -42,16 +40,5 @@ private async Async.Task Post(HttpRequestData req) { var response = req.CreateResponse(HttpStatusCode.OK); await response.WriteAsJsonAsync(new BoolResult(true)); return response; - - } - - [Function("NodeAddSshKey")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "node/add_ssh_key")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - } diff --git a/src/ApiService/ApiService/Functions/Notifications.cs b/src/ApiService/ApiService/Functions/Notifications.cs index 5013521a26..323ef65c61 100644 --- a/src/ApiService/ApiService/Functions/Notifications.cs +++ b/src/ApiService/ApiService/Functions/Notifications.cs @@ -1,17 +1,16 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Notifications { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Notifications(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Notifications(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } @@ -82,12 +81,12 @@ private async Async.Task Delete(HttpRequestData req) { [Function("Notifications")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; } diff --git a/src/ApiService/ApiService/Functions/NotificationsTest.cs b/src/ApiService/ApiService/Functions/NotificationsTest.cs index 16a1a5c982..11beedcf70 100644 --- a/src/ApiService/ApiService/Functions/NotificationsTest.cs +++ b/src/ApiService/ApiService/Functions/NotificationsTest.cs @@ -1,21 +1,22 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class NotificationsTest { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public NotificationsTest(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public NotificationsTest(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } - private async Async.Task Post(HttpRequestData req) { + [Function("NotificationsTest")] + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "notifications/test")] HttpRequestData req) { _log.WithTag("HttpRequest", "GET").Info($"Notification test"); var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { @@ -29,13 +30,4 @@ private async Async.Task Post(HttpRequestData req) { await response.WriteAsJsonAsync(new NotificationTestResponse(result.IsOk, result.ErrorV?.ToString())); return response; } - - - [Function("NotificationsTest")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "notifications/test")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } } diff --git a/src/ApiService/ApiService/Functions/Pool.cs b/src/ApiService/ApiService/Functions/Pool.cs index 9d4ee7f5df..c160e3b163 100644 --- a/src/ApiService/ApiService/Functions/Pool.cs +++ b/src/ApiService/ApiService/Functions/Pool.cs @@ -2,29 +2,40 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Pool { - private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Pool(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _log = log; - _auth = auth; + public Pool(IOnefuzzContext context) { _context = context; } + public const string Route = "pool"; + [Function("Pool")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", "PATCH")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), - "PATCH" => Patch(r), - var m => throw new InvalidOperationException("Unsupported HTTP method {m}"), - }); + [Authorize(Allow.User)] + public Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + _ => throw new InvalidOperationException("Unsupported HTTP method {m}"), + }; + + [Function("Pool_Admin")] + [Authorize(Allow.Admin)] + public Task Admin( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", "DELETE", "PATCH", Route=Route)] + HttpRequestData req) + => req.Method switch { + "POST" => Post(req), + "DELETE" => Delete(req), + "PATCH" => Patch(req), + _ => throw new InvalidOperationException("Unsupported HTTP method {m}"), + }; private async Task Delete(HttpRequestData r) { var request = await RequestHandling.ParseRequest(r); @@ -32,11 +43,6 @@ private async Task Delete(HttpRequestData r) { return await _context.RequestHandling.NotOk(r, request.ErrorV, "PoolDelete"); } - var answer = await _auth.CheckRequireAdmins(r); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(r, answer.ErrorV, "PoolDelete"); - } - var poolResult = await _context.PoolOperations.GetByName(request.OkV.Name); if (!poolResult.IsOk) { return await _context.RequestHandling.NotOk(r, poolResult.ErrorV, "pool stop"); @@ -53,11 +59,6 @@ private async Task Post(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "PoolCreate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "PoolCreate"); - } - var create = request.OkV; var pool = await _context.PoolOperations.GetByName(create.Name); if (pool.IsOk) { @@ -77,11 +78,6 @@ private async Task Patch(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "PoolUpdate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "PoolUpdate"); - } - var update = request.OkV; var pool = await _context.PoolOperations.GetByName(update.Name); if (!pool.IsOk) { diff --git a/src/ApiService/ApiService/Functions/Proxy.cs b/src/ApiService/ApiService/Functions/Proxy.cs index ff56dd5fa9..85e8b4221e 100644 --- a/src/ApiService/ApiService/Functions/Proxy.cs +++ b/src/ApiService/ApiService/Functions/Proxy.cs @@ -1,32 +1,32 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; using VmProxy = Microsoft.OneFuzz.Service.Proxy; namespace Microsoft.OneFuzz.Service.Functions; public class Proxy { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Proxy(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Proxy(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("Proxy")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "PATCH" => Patch(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + "PATCH" => Patch(req), + "POST" => Post(req), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - + }; private ProxyGetResult GetResult(ProxyForward proxyForward, VmProxy? proxy) { var forward = _context.ProxyForwardOperations.ToForward(proxyForward); diff --git a/src/ApiService/ApiService/Functions/ReproVmss.cs b/src/ApiService/ApiService/Functions/ReproVmss.cs index f32be74614..6061c44098 100644 --- a/src/ApiService/ApiService/Functions/ReproVmss.cs +++ b/src/ApiService/ApiService/Functions/ReproVmss.cs @@ -3,29 +3,31 @@ using System.Text; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class ReproVmss { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public ReproVmss(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public ReproVmss(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("ReproVms")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", Route = "repro_vms")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", Route = "repro_vms")] + HttpRequestData req, + FunctionContext context) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req, context), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -56,7 +58,7 @@ private async Async.Task Get(HttpRequestData req) { } - private async Async.Task Post(HttpRequestData req) { + private async Async.Task Post(HttpRequestData req, FunctionContext context) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -65,13 +67,7 @@ private async Async.Task Post(HttpRequestData req) { "repro_vm create"); } - var userInfo = await _context.UserCredentials.ParseJwtToken(req); - if (!userInfo.IsOk) { - return await _context.RequestHandling.NotOk( - req, - userInfo.ErrorV, - "repro_vm create"); - } + var userInfo = context.GetUserAuthInfo(); var create = request.OkV; var cfg = new ReproConfig( @@ -79,7 +75,7 @@ private async Async.Task Post(HttpRequestData req) { Path: create.Path, Duration: create.Duration); - var vm = await _context.ReproOperations.Create(cfg, userInfo.OkV.UserInfo); + var vm = await _context.ReproOperations.Create(cfg, userInfo.UserInfo); if (!vm.IsOk) { return await _context.RequestHandling.NotOk( req, @@ -98,7 +94,7 @@ private async Async.Task Post(HttpRequestData req) { // we’d like to track the usage of this feature; // anonymize the user ID so we can distinguish multiple requests { - var data = userInfo.OkV.UserInfo.ToString(); // rely on record ToString + var data = userInfo.UserInfo.ToString(); // rely on record ToString var hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(data))); _log.Event($"created repro VM, user distinguisher: {hash:Tag:UserHash}"); } diff --git a/src/ApiService/ApiService/Functions/Scaleset.cs b/src/ApiService/ApiService/Functions/Scaleset.cs index 403ee54dfa..c01383fa83 100644 --- a/src/ApiService/ApiService/Functions/Scaleset.cs +++ b/src/ApiService/ApiService/Functions/Scaleset.cs @@ -1,30 +1,42 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Scaleset { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Scaleset(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Scaleset(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } + public const string Route = "scaleset"; + [Function("Scaleset")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "PATCH" => Patch(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; + + [Function("Scaleset_Admin")] + [Authorize(Allow.Admin)] + public Async.Task Admin( + [HttpTrigger(AuthorizationLevel.Anonymous, "PATCH", "POST", "DELETE", Route=Route)] + HttpRequestData req) + => req.Method switch { + "PATCH" => Patch(req), + "POST" => Post(req), + "DELETE" => Delete(req), + _ => throw new InvalidOperationException("Unsupported HTTP method"), + }; private async Task Delete(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -32,11 +44,6 @@ private async Task Delete(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ScalesetDelete"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "ScalesetDelete"); - } - var scalesetResult = await _context.ScalesetOperations.GetById(request.OkV.ScalesetId); if (!scalesetResult.IsOk) { return await _context.RequestHandling.NotOk(req, scalesetResult.ErrorV, "ScalesetDelete"); @@ -54,11 +61,6 @@ private async Task Post(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ScalesetCreate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "ScalesetCreate"); - } - var create = request.OkV; // verify the pool exists var poolResult = await _context.PoolOperations.GetByName(create.PoolName); @@ -121,7 +123,7 @@ private async Task Post(HttpRequestData req) { ScalesetId: Service.Scaleset.GenerateNewScalesetId(create.PoolName), State: ScalesetState.Init, NeedsConfigUpdate: false, - Auth: new SecretValue(await Auth.BuildAuth(_log)), + Auth: new SecretValue(await AuthHelpers.BuildAuth(_log)), PoolName: create.PoolName, VmSku: create.VmSku, Image: image, @@ -172,11 +174,6 @@ private async Task Patch(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ScalesetUpdate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "ScalesetUpdate"); - } - var scalesetResult = await _context.ScalesetOperations.GetById(request.OkV.ScalesetId); if (!scalesetResult.IsOk) { return await _context.RequestHandling.NotOk(req, scalesetResult.ErrorV, "ScalesetUpdate"); diff --git a/src/ApiService/ApiService/Functions/Tasks.cs b/src/ApiService/ApiService/Functions/Tasks.cs index 5c8edbc786..c2b759f527 100644 --- a/src/ApiService/ApiService/Functions/Tasks.cs +++ b/src/ApiService/ApiService/Functions/Tasks.cs @@ -2,29 +2,29 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Tasks { - private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Tasks(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _log = log; - _auth = auth; + public Tasks(IOnefuzzContext context) { _context = context; } [Function("Tasks")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] + HttpRequestData req, + FunctionContext context) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req, context), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -73,7 +73,7 @@ private async Async.Task Get(HttpRequestData req) { } - private async Async.Task Post(HttpRequestData req) { + private async Async.Task Post(HttpRequestData req, FunctionContext context) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -82,10 +82,7 @@ private async Async.Task Post(HttpRequestData req) { "task create"); } - var userInfo = await _context.UserCredentials.ParseJwtToken(req); - if (!userInfo.IsOk) { - return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "task create"); - } + var userInfo = context.GetUserAuthInfo(); var create = request.OkV; var cfg = new TaskConfig( @@ -141,7 +138,7 @@ private async Async.Task Post(HttpRequestData req) { } } - var task = await _context.TaskOperations.Create(cfg, cfg.JobId, userInfo.OkV.UserInfo); + var task = await _context.TaskOperations.Create(cfg, cfg.JobId, userInfo.UserInfo); if (!task.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/Functions/Tool.cs b/src/ApiService/ApiService/Functions/Tool.cs index 1bf4a3f910..b4b7936843 100644 --- a/src/ApiService/ApiService/Functions/Tool.cs +++ b/src/ApiService/ApiService/Functions/Tool.cs @@ -1,19 +1,22 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Tools { private readonly IOnefuzzContext _context; - private readonly IEndpointAuthorization _auth; - public Tools(IEndpointAuthorization auth, IOnefuzzContext context) { + public Tools(IOnefuzzContext context) { _context = context; - _auth = auth; } - public async Async.Task GetResponse(HttpRequestData req) { + [Function("Tools")] + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { + //Note: streaming response are not currently supported by in isolated functions // https://github.com/Azure/azure-functions-dotnet-worker/issues/958 var response = req.CreateResponse(HttpStatusCode.OK); @@ -23,9 +26,4 @@ public async Async.Task GetResponse(HttpRequestData req) { } return response; } - - - [Function("Tools")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, GetResponse); } diff --git a/src/ApiService/ApiService/Functions/ValidateScriban.cs b/src/ApiService/ApiService/Functions/ValidateScriban.cs index 4e8b003354..90bd90b723 100644 --- a/src/ApiService/ApiService/Functions/ValidateScriban.cs +++ b/src/ApiService/ApiService/Functions/ValidateScriban.cs @@ -1,5 +1,6 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; @@ -11,7 +12,11 @@ public ValidateScriban(ILogTracer log, IOnefuzzContext context) { _context = context; } - private async Async.Task Post(HttpRequestData req) { + [Function("ValidateScriban")] + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST")] + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ValidateTemplate"); @@ -27,14 +32,5 @@ private async Async.Task Post(HttpRequestData req) { $"Template failed to render due to: `{e.Message}`" ); } - - } - - [Function("ValidateScriban")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req) { - return req.Method switch { - "POST" => Post(req), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }; } } diff --git a/src/ApiService/ApiService/Functions/WebhookLogs.cs b/src/ApiService/ApiService/Functions/WebhookLogs.cs index 140d16784e..1899fa1623 100644 --- a/src/ApiService/ApiService/Functions/WebhookLogs.cs +++ b/src/ApiService/ApiService/Functions/WebhookLogs.cs @@ -3,28 +3,22 @@ namespace Microsoft.OneFuzz.Service.Functions; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; public class WebhookLogs { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public WebhookLogs(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public WebhookLogs(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("WebhookLogs")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/logs")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - - private async Async.Task Post(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/logs")] + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/Functions/WebhookPing.cs b/src/ApiService/ApiService/Functions/WebhookPing.cs index 623c0bed5c..e3e4574d85 100644 --- a/src/ApiService/ApiService/Functions/WebhookPing.cs +++ b/src/ApiService/ApiService/Functions/WebhookPing.cs @@ -3,28 +3,22 @@ namespace Microsoft.OneFuzz.Service.Functions; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; public class WebhookPing { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public WebhookPing(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public WebhookPing(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("WebhookPing")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/ping")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - - private async Async.Task Post(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/ping")] + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/Functions/Webhooks.cs b/src/ApiService/ApiService/Functions/Webhooks.cs index 1ac1d30cb7..538e49217b 100644 --- a/src/ApiService/ApiService/Functions/Webhooks.cs +++ b/src/ApiService/ApiService/Functions/Webhooks.cs @@ -3,31 +3,29 @@ namespace Microsoft.OneFuzz.Service.Functions; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; - +using Microsoft.OneFuzz.Service.Auth; public class Webhooks { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Webhooks(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Webhooks(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("Webhooks")] + [Authorize(Allow.User)] public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", "PATCH")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), - "PATCH" => Patch(r), + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", "PATCH")] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + "DELETE" => Delete(req), + "PATCH" => Patch(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); diff --git a/src/ApiService/ApiService/HttpClient.cs b/src/ApiService/ApiService/HttpClient.cs index 89de11da38..74d8007eb7 100644 --- a/src/ApiService/ApiService/HttpClient.cs +++ b/src/ApiService/ApiService/HttpClient.cs @@ -9,7 +9,7 @@ namespace Microsoft.OneFuzz.Service; public class Request { private readonly HttpClient _httpClient; - Func>? _auth; + private readonly Func>? _auth; public Request(HttpClient httpClient, Func>? auth = null) { _auth = auth; diff --git a/src/ApiService/ApiService/Program.cs b/src/ApiService/ApiService/Program.cs index a9a71d48ed..f3fe03dc66 100644 --- a/src/ApiService/ApiService/Program.cs +++ b/src/ApiService/ApiService/Program.cs @@ -1,11 +1,8 @@ -// to avoid collision with Task in model.cs -global using System; -global -using System.Collections.Generic; -global -using System.Linq; -global -using Async = System.Threading.Tasks; +global using System; +global using System.Collections.Generic; +global using System.Linq; +// to avoid collision with Task in model.cs +global using Async = System.Threading.Tasks; using System.Text.Json; using ApiService.OneFuzzLib.Orm; using Azure.Core.Serialization; @@ -99,7 +96,6 @@ public static async Async.Task Main() { .AddScoped() .AddScoped() .AddScoped() - .AddScoped() .AddScoped() .AddScoped() .AddScoped() @@ -140,6 +136,8 @@ public static async Async.Task Main() { .ConfigureFunctionsWorkerDefaults(builder => { builder.UseAzureAppConfiguration(); builder.UseMiddleware(); + builder.UseMiddleware(); + builder.UseMiddleware(); builder.AddApplicationInsights(options => { options.ConnectionString = $"InstrumentationKey={configuration.ApplicationInsightsInstrumentationKey}"; }); diff --git a/src/ApiService/ApiService/UserCredentials.cs b/src/ApiService/ApiService/UserCredentials.cs deleted file mode 100644 index 8072471951..0000000000 --- a/src/ApiService/ApiService/UserCredentials.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System.IdentityModel.Tokens.Jwt; -using System.Net.Http.Headers; -using System.Threading.Tasks; -using Microsoft.Azure.Functions.Worker.Http; -using Microsoft.IdentityModel.Tokens; - - -namespace Microsoft.OneFuzz.Service; - -public interface IUserCredentials { - public string? GetBearerToken(HttpRequestData req); - public string? GetAuthToken(HttpRequestData req); - public Task> ParseJwtToken(HttpRequestData req); -} - -public record UserAuthInfo(UserInfo UserInfo, List Roles); - -public class UserCredentials : IUserCredentials { - ILogTracer _log; - IConfigOperations _instanceConfig; - private JwtSecurityTokenHandler _tokenHandler; - - public UserCredentials(ILogTracer log, IConfigOperations instanceConfig) { - _log = log; - _instanceConfig = instanceConfig; - _tokenHandler = new JwtSecurityTokenHandler(); - } - - public string? GetBearerToken(HttpRequestData req) { - if (!req.Headers.TryGetValues("Authorization", out var authHeader) || authHeader.IsNullOrEmpty()) { - return null; - } else { - var auth = AuthenticationHeaderValue.Parse(authHeader.First()); - return auth.Scheme.ToLower() switch { - "bearer" => auth.Parameter, - _ => null, - }; - } - } - - public string? GetAuthToken(HttpRequestData req) { - var token = GetBearerToken(req); - if (token is not null) { - return token; - } else { - if (!req.Headers.TryGetValues("x-ms-token-aad-id-token", out var tokenHeader) || tokenHeader.IsNullOrEmpty()) { - return null; - } else { - return tokenHeader.First(); - } - } - } - - async Task> GetAllowedTenants() { - var r = await _instanceConfig.Fetch(); - var allowedAddTenantsQuery = - from t in r.AllowedAadTenants - select $"https://sts.windows.net/{t}/"; - - return OneFuzzResult.Ok(allowedAddTenantsQuery.ToArray()); - } - - public virtual async Task> ParseJwtToken(HttpRequestData req) { - - - var authToken = GetAuthToken(req); - if (authToken is null) { - return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" }); - } else { - var token = new System.IdentityModel.Tokens.Jwt.JwtSecurityToken(authToken); - var allowedTenants = await GetAllowedTenants(); - if (allowedTenants.IsOk) { - if (allowedTenants.OkV is not null && allowedTenants.OkV.Contains(token.Issuer)) { - var userAuthInfo = new UserAuthInfo(new UserInfo(null, null, null), new List()); - var userInfo = - token.Payload.Claims.Aggregate(userAuthInfo, (acc, claim) => { - switch (claim.Type) { - case "oid": - return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } }; - case "appid": - return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } }; - case "upn": - return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } }; - case "roles": - acc.Roles.Add(claim.Value); - return acc; - default: - return acc; - } - }); - return OneFuzzResult.Ok(userInfo); - } else { - var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!); - _log.Error($"issuer not from allowed tenant. issuer: {token.Issuer:Tag:Issuer} - tenants: {tenantsStr:Tag:Tenants}"); - return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, new[] { "unauthorized AAD issuer. If multi-tenant auth is failing, make sure to include all tenant_ids in the `allowed_aad_tenants` list in the instance_config. To see the current instance_config, run `onefuzz instance_config get`. " }); - } - } else { - _log.Error($"Failed to get allowed tenants due to {allowedTenants.ErrorV:Tag:Error}"); - return OneFuzzResult.Error(allowedTenants.ErrorV); - } - } - } -} diff --git a/src/ApiService/ApiService/onefuzzlib/Auth.cs b/src/ApiService/ApiService/onefuzzlib/Auth.cs index 2869f78c25..c2a6395254 100644 --- a/src/ApiService/ApiService/onefuzzlib/Auth.cs +++ b/src/ApiService/ApiService/onefuzzlib/Auth.cs @@ -2,7 +2,7 @@ using System.Diagnostics; using System.IO; -public class Auth { +public static class AuthHelpers { private static ProcessStartInfo SshKeyGenProcConfig(string tempFile) { diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 011e522310..cc88d1d924 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -1,36 +1,23 @@ -using System.Net; -using System.Net.Http; +using System.Net.Http; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.Graph; namespace Microsoft.OneFuzz.Service; -public interface IEndpointAuthorization { - - Async.Task CallIfAgent( - HttpRequestData req, - Func> method) - => CallIf(req, method, allowAgent: true); - - Async.Task CallIfUser( - HttpRequestData req, - Func> method) - => CallIf(req, method, allowUser: true); - - Async.Task CallIf( - HttpRequestData req, - Func> method, - bool allowUser = false, - bool allowAgent = false); +public record UserAuthInfo(UserInfo UserInfo, List Roles); - Async.Task CheckRequireAdmins(HttpRequestData req); +public interface IEndpointAuthorization { + Async.Task CheckRequireAdmins(UserAuthInfo authInfo); + Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo); + Async.Task CheckAccess(HttpRequestData req); } + public class EndpointAuthorization : IEndpointAuthorization { private readonly IOnefuzzContext _context; private readonly ILogTracer _log; private readonly GraphServiceClient _graphClient; - private static readonly HashSet AgentRoles = new HashSet { "UnmanagedNode", "ManagedNode" }; + private static readonly IReadOnlySet _agentRoles = new HashSet() { "UnmanagedNode", "ManagedNode" }; public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) { _context = context; @@ -38,58 +25,7 @@ public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServi _graphClient = graphClient; } - public virtual async Async.Task CallIf(HttpRequestData req, Func> method, bool allowUser = false, bool allowAgent = false) { - var tokenResult = await _context.UserCredentials.ParseJwtToken(req); - - if (!tokenResult.IsOk) { - return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized); - } - - var token = tokenResult.OkV.UserInfo; - - var (isAgent, reason) = await IsAgent(tokenResult.OkV); - - if (!isAgent) { - if (!allowUser) { - return await Reject(req, token, "endpoint not allowed for users"); - } - - var access = await CheckAccess(req); - if (!access.IsOk) { - return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized); - } - } - - - if (isAgent && !allowAgent) { - return await Reject(req, token, reason); - } - - return await method(req); - } - - - public async Async.Task Reject(HttpRequestData req, UserInfo token, String? reason = null) { - var body = await req.ReadAsStringAsync(); - _log.Error($"reject token. reason:{reason} url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}"); - - return await _context.RequestHandling.NotOk( - req, - Error.Create( - ErrorCode.UNAUTHORIZED, - reason ?? "Unrecognized agent" - ), - "token verification", - HttpStatusCode.Unauthorized - ); - } - - public async Async.Task CheckRequireAdmins(HttpRequestData req) { - var tokenResult = await _context.UserCredentials.ParseJwtToken(req); - if (!tokenResult.IsOk) { - return tokenResult.ErrorV; - } - + public async Async.Task CheckRequireAdmins(UserAuthInfo authInfo) { var config = await _context.ConfigOperations.Fetch(); if (config is null) { return Error.Create( @@ -97,7 +33,7 @@ public async Async.Task CheckRequireAdmins(HttpRequestData re "no instance configuration found "); } - return CheckRequireAdminsImpl(config, tokenResult.OkV.UserInfo); + return CheckRequireAdminsImpl(config, authInfo.UserInfo); } private static OneFuzzResultVoid CheckRequireAdminsImpl(InstanceConfig config, UserInfo userInfo) { @@ -177,9 +113,8 @@ private GroupMembershipChecker CreateGroupMembershipChecker(InstanceConfig confi return null; } - public async Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo) { - if (!AgentRoles.Overlaps(authInfo.Roles)) { + if (!_agentRoles.Overlaps(authInfo.Roles)) { return (false, "no agent role"); } diff --git a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs index 95a599452d..d877bfddbb 100644 --- a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs +++ b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs @@ -37,7 +37,6 @@ public interface IOnefuzzContext { IStorage Storage { get; } ITaskOperations TaskOperations { get; } ITaskEventOperations TaskEventOperations { get; } - IUserCredentials UserCredentials { get; } IVmOperations VmOperations { get; } IVmssOperations VmssOperations { get; } IWebhookMessageLogOperations WebhookMessageLogOperations { get; } @@ -77,7 +76,6 @@ public OnefuzzContext(IServiceProvider serviceProvider) { public IContainers Containers => _serviceProvider.GetRequiredService(); public IReports Reports => _serviceProvider.GetRequiredService(); public INotificationOperations NotificationOperations => _serviceProvider.GetRequiredService(); - public IUserCredentials UserCredentials => _serviceProvider.GetRequiredService(); public IReproOperations ReproOperations => _serviceProvider.GetRequiredService(); public IPoolOperations PoolOperations => _serviceProvider.GetRequiredService(); public IIpOperations IpOperations => _serviceProvider.GetRequiredService(); diff --git a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs index e5cba14e6c..8f192235ef 100644 --- a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs @@ -59,7 +59,17 @@ public async Async.Task GetOrCreate(Region region) { } _logTracer.Info($"creating proxy: region:{region:Tag:Region}"); - var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, new SecretValue(await Auth.BuildAuth(_logTracer)), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false); + var newProxy = new Proxy( + region, + Guid.NewGuid(), + DateTimeOffset.UtcNow, + VmState.Init, + new SecretValue(await AuthHelpers.BuildAuth(_logTracer)), + null, + null, + _context.ServiceConfiguration.OneFuzzVersion, + null, + false); var r = await Replace(newProxy); if (!r.IsOk) { diff --git a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs index 42fbe1b4c7..bd03824321 100644 --- a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs @@ -335,7 +335,7 @@ public async Task> Create(ReproConfig config, UserInfo user return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, "unable to find task"); } - var auth = await _context.SecretsOperations.StoreSecret(new SecretValue(await Auth.BuildAuth(_logTracer))); + var auth = await _context.SecretsOperations.StoreSecret(new SecretValue(await AuthHelpers.BuildAuth(_logTracer))); var vm = new Repro( VmId: Guid.NewGuid(), diff --git a/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs b/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs index f7a3c8aa1f..e6b663e87d 100644 --- a/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs +++ b/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs @@ -1,10 +1,6 @@ -using System.Net; -using IntegrationTests.Fakes; -using Microsoft.OneFuzz.Service; -using Microsoft.OneFuzz.Service.Functions; +using Microsoft.OneFuzz.Service; using Xunit; using Xunit.Abstractions; -using Async = System.Threading.Tasks; namespace IntegrationTests; @@ -23,31 +19,4 @@ public abstract class AgentCanScheduleTestsBase : FunctionTestBase { public AgentCanScheduleTestsBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentCanSchedule(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentCanSchedule(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task AgentAuthorization_IsAccepted() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentCanSchedule(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized - } } diff --git a/src/ApiService/IntegrationTests/AgentCommandsTests.cs b/src/ApiService/IntegrationTests/AgentCommandsTests.cs index b18b32e979..47b951fc15 100644 --- a/src/ApiService/IntegrationTests/AgentCommandsTests.cs +++ b/src/ApiService/IntegrationTests/AgentCommandsTests.cs @@ -26,33 +26,6 @@ public AgentCommandsTestsBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task AgentAuthorization_IsAccepted() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized - } - [Fact] public async Async.Task AgentCommand_GetsCommand() { var machineId = Guid.NewGuid(); @@ -69,8 +42,7 @@ await Context.InsertAll(new[] { }); var commandRequest = new NodeCommandGet(machineId); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); + var func = new AgentCommands(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/AgentEventsTests.cs b/src/ApiService/IntegrationTests/AgentEventsTests.cs index 00bc0a1740..a662b4890a 100644 --- a/src/ApiService/IntegrationTests/AgentEventsTests.cs +++ b/src/ApiService/IntegrationTests/AgentEventsTests.cs @@ -35,28 +35,9 @@ public AgentEventsTestsBase(ITestOutputHelper output, IStorage storage) readonly Guid _poolId = Guid.NewGuid(); readonly string _poolVersion = $"version-{Guid.NewGuid()}"; - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task WorkerEventMustHaveDoneOrRunningSet() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: Guid.NewGuid(), @@ -75,8 +56,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, @@ -103,8 +83,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, @@ -130,8 +109,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Scheduled, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, @@ -156,8 +134,7 @@ public async Async.Task WorkerRunning_ForMissingTask_ReturnsError() { await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); @@ -173,8 +150,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 0)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); @@ -191,8 +167,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 0)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); @@ -232,8 +207,7 @@ await Async.Task.WhenAll( public async Async.Task NodeStateUpdate_ForMissingNode_IgnoresEvent() { // nothing present in storage - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Init)); @@ -248,8 +222,7 @@ public async Async.Task NodeStateUpdate_CanTransitionFromInitToReady() { await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion, State: NodeState.Init)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Ready)); @@ -266,8 +239,7 @@ public async Async.Task NodeStateUpdate_BecomingFree_StopsNode_IfMarkedForReimag await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion, ReimageRequested: true)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Free)); @@ -295,8 +267,7 @@ public async Async.Task NodeStateUpdate_BecomingFree_StopsNode_IfMarkedForDeleti await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion, DeleteRequested: true)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Free)); diff --git a/src/ApiService/IntegrationTests/AgentRegistrationTests.cs b/src/ApiService/IntegrationTests/AgentRegistrationTests.cs index 618616bbdc..202a66628b 100644 --- a/src/ApiService/IntegrationTests/AgentRegistrationTests.cs +++ b/src/ApiService/IntegrationTests/AgentRegistrationTests.cs @@ -33,37 +33,9 @@ public AgentRegistrationTestsBase(ITestOutputHelper output, IStorage storage) private readonly ScalesetId _scalesetId = ScalesetId.Parse($"scaleset-{Guid.NewGuid()}"); private readonly PoolName _poolName = PoolName.Parse($"pool-{Guid.NewGuid()}"); - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task AgentAuthorization_IsAccepted() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to missing parameters, not Unauthorized - } - [Fact] public async Async.Task Get_UrlParameterRequired() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); var result = await func.Run(req); @@ -76,8 +48,7 @@ public async Async.Task Get_UrlParameterRequired() { [Fact] public async Async.Task Get_MissingNode() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); req.SetUrlParameter("machine_id", _machineId); @@ -95,8 +66,7 @@ public async Async.Task Get_MissingPool() { await Context.InsertAll( new Node(_poolName, _machineId, _poolId, "1.0.0")); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); req.SetUrlParameter("machine_id", _machineId); @@ -115,8 +85,7 @@ await Context.InsertAll( new Node(_poolName, _machineId, _poolId, "1.0.0"), new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); req.SetUrlParameter("machine_id", _machineId); @@ -135,8 +104,7 @@ public async Async.Task Post_SetsDefaultVersion_IfNotSupplied() { await Context.InsertAll( new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); req.SetUrlParameter("machine_id", _machineId); @@ -157,8 +125,7 @@ public async Async.Task Post_SetsCorrectVersion() { await Context.InsertAll( new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); req.SetUrlParameter("machine_id", _machineId); @@ -181,8 +148,7 @@ await Context.InsertAll( new Node(PoolName.Parse("another-pool"), _machineId, _poolId, "1.0.0"), new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); req.SetUrlParameter("machine_id", _machineId); @@ -205,8 +171,7 @@ public async Async.Task Post_ChecksRequiredParameters(string parameterToSkip) { await Context.InsertAll( new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); if (parameterToSkip != "machine_id") { diff --git a/src/ApiService/IntegrationTests/AuthTests.cs b/src/ApiService/IntegrationTests/AuthTests.cs index 5441c49b29..3632658fbc 100644 --- a/src/ApiService/IntegrationTests/AuthTests.cs +++ b/src/ApiService/IntegrationTests/AuthTests.cs @@ -14,7 +14,7 @@ public AuthTests(ITestOutputHelper output) { [Fact] public async System.Threading.Tasks.Task TestAuth() { - var auth = await Microsoft.OneFuzz.Service.Auth.BuildAuth(Logger); + var auth = await AuthHelpers.BuildAuth(Logger); auth.Should().NotBeNull(); auth.PrivateKey.StartsWith("-----BEGIN OPENSSH PRIVATE KEY-----").Should().BeTrue(); diff --git a/src/ApiService/IntegrationTests/ContainersTests.cs b/src/ApiService/IntegrationTests/ContainersTests.cs index 8b122a52a2..70339edf16 100644 --- a/src/ApiService/IntegrationTests/ContainersTests.cs +++ b/src/ApiService/IntegrationTests/ContainersTests.cs @@ -31,22 +31,6 @@ public abstract class ContainersTestBase : FunctionTestBase { public ContainersTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Theory] - [InlineData("GET")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task WithoutAuthorization_IsRejected(string method) { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - } - - [Fact] public async Async.Task CanDelete() { var containerName = Container.Parse("test"); @@ -55,8 +39,7 @@ public async Async.Task CanDelete() { var msg = TestHttpRequestData.FromJson("DELETE", new ContainerDelete(containerName)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -71,8 +54,7 @@ public async Async.Task CanPost_New() { var containerName = Container.Parse("test"); var msg = TestHttpRequestData.FromJson("POST", new ContainerCreate(containerName, meta)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -95,8 +77,7 @@ public async Async.Task CanPost_Existing() { var metadata = new Dictionary { { "some", "value" } }; var msg = TestHttpRequestData.FromJson("POST", new ContainerCreate(containerName, metadata)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -119,8 +100,7 @@ public async Async.Task Get_Existing() { var msg = TestHttpRequestData.FromJson("GET", new ContainerGet(containerName)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -134,8 +114,7 @@ public async Async.Task Get_Missing_Fails() { var container = Container.Parse("container"); var msg = TestHttpRequestData.FromJson("GET", new ContainerGet(container)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -149,8 +128,7 @@ public async Async.Task List_Existing() { var msg = TestHttpRequestData.Empty("GET"); // this means list all - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -188,8 +166,7 @@ public async Async.Task BadContainerNameProducesGoodErrorMessage() { // use anonymous type so we can send an invalid name var msg = TestHttpRequestData.FromJson("POST", new { Name = "AbCd" }); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/DownloadTests.cs b/src/ApiService/IntegrationTests/DownloadTests.cs index 0feae4db95..1912071caf 100644 --- a/src/ApiService/IntegrationTests/DownloadTests.cs +++ b/src/ApiService/IntegrationTests/DownloadTests.cs @@ -26,26 +26,13 @@ public abstract class DownloadTestBase : FunctionTestBase { public DownloadTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Fact] - public async Async.Task Download_WithoutAuthorization_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new Download(auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - } - [Fact] public async Async.Task Download_WithoutContainer_IsRejected() { var req = TestHttpRequestData.Empty("GET"); var url = new UriBuilder(req.Url) { Query = "filename=xxx" }.Uri; req.SetUrl(url); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Download(auth, Context); + var func = new Download(Context); var result = await func.Run(req); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -59,8 +46,7 @@ public async Async.Task Download_WithoutFilename_IsRejected() { var url = new UriBuilder(req.Url) { Query = "container=xxx" }.Uri; req.SetUrl(url); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Download(auth, Context); + var func = new Download(Context); var result = await func.Run(req); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -81,8 +67,7 @@ public async Async.Task Download_RedirectsToResult_WithLocationHeader() { var url = new UriBuilder(req.Url) { Query = "container=xxx&filename=yyy" }.Uri; req.SetUrl(url); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Download(auth, Context); + var func = new Download(Context); var result = await func.Run(req); Assert.Equal(HttpStatusCode.Found, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/EndpointAuthTests.cs b/src/ApiService/IntegrationTests/EndpointAuthTests.cs new file mode 100644 index 0000000000..6a1355362e --- /dev/null +++ b/src/ApiService/IntegrationTests/EndpointAuthTests.cs @@ -0,0 +1,93 @@ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.OneFuzz.Service; +using Xunit; +using Xunit.Abstractions; +using Async = System.Threading.Tasks; + +namespace IntegrationTests; + + +[Trait("Category", "Live")] +public class AzureStorageEndpointAuthTest : EndpointAuthTestBase { + public AzureStorageEndpointAuthTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteEndpointAuthTest : EndpointAuthTestBase { + public AzuriteEndpointAuthTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class EndpointAuthTestBase : FunctionTestBase { + public EndpointAuthTestBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { + } + + private readonly Guid _applicationId = Guid.NewGuid(); + private readonly Guid _userObjectId = Guid.NewGuid(); + + private Task CheckUserAdmin() { + var userAuthInfo = new UserAuthInfo( + new UserInfo(ApplicationId: _applicationId, ObjectId: _userObjectId, "upn"), + new List()); + + var auth = new EndpointAuthorization(Context, Logger, null!); + + return auth.CheckRequireAdmins(userAuthInfo); + } + + [Fact] + public async Async.Task IfRequireAdminPrivilegesIsEnabled_UserIsNotPermitted() { + // config specifies that a different user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + RequireAdminPrivileges = true, + }); + + var result = await CheckUserAdmin(); + Assert.False(result.IsOk, "should not be admin"); + } + + [Fact] + public async Async.Task IfRequireAdminPrivilegesIsDisabled_UserIsPermitted() { + // disable requiring admin privileges + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + RequireAdminPrivileges = false, + }); + + var result = await CheckUserAdmin(); + Assert.True(result.IsOk, "should be admin"); + } + + [Fact] + public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser() { + var otherUserObjectId = Guid.NewGuid(); + + // config specifies that a different user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + Admins = new[] { otherUserObjectId }, + RequireAdminPrivileges = true, + }); + + var result = await CheckUserAdmin(); + Assert.False(result.IsOk, "should not be admin"); + } + + [Fact] + public async Async.Task UserCanBeAdmin() { + // config specifies that user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + Admins = new[] { _userObjectId }, + RequireAdminPrivileges = true, + }); + + var result = await CheckUserAdmin(); + Assert.True(result.IsOk, "should be admin"); + } +} diff --git a/src/ApiService/IntegrationTests/EventsTests.cs b/src/ApiService/IntegrationTests/EventsTests.cs index 7f9d4b11f0..b5b4e03a87 100644 --- a/src/ApiService/IntegrationTests/EventsTests.cs +++ b/src/ApiService/IntegrationTests/EventsTests.cs @@ -46,8 +46,7 @@ public async Async.Task BlobIsCreatedAndIsAccessible() { ping.Should().NotBeNull(); var msg = TestHttpRequestData.FromJson("GET", new EventsGet(ping.PingId)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new EventsFunction(Logger, auth, Context); + var func = new EventsFunction(Logger, Context); var result = await func.Run(msg); result.StatusCode.Should().Be(HttpStatusCode.OK); diff --git a/src/ApiService/IntegrationTests/Fakes/TestContext.cs b/src/ApiService/IntegrationTests/Fakes/TestContext.cs index f8fa1ca5fc..41d2219924 100644 --- a/src/ApiService/IntegrationTests/Fakes/TestContext.cs +++ b/src/ApiService/IntegrationTests/Fakes/TestContext.cs @@ -41,7 +41,6 @@ public TestContext(IHttpClientFactory httpClientFactory, ILogTracer logTracer, I ScalesetOperations = new ScalesetOperations(logTracer, cache, this); ReproOperations = new ReproOperations(logTracer, this); Reports = new Reports(logTracer, Containers); - UserCredentials = new UserCredentials(logTracer, ConfigOperations); NotificationOperations = new NotificationOperations(logTracer, this); FeatureManagerSnapshot = new TestFeatureManagerSnapshot(); @@ -79,7 +78,6 @@ public Async.Task InsertAll(params EntityBase[] objs) public ICreds Creds { get; } public IContainers Containers { get; set; } public IQueue Queue { get; } - public IUserCredentials UserCredentials { get; set; } public IRequestHandling RequestHandling { get; } diff --git a/src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs b/src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs deleted file mode 100644 index 0d20c42bf6..0000000000 --- a/src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs +++ /dev/null @@ -1,47 +0,0 @@ - -using System; -using System.Net; -using System.Threading.Tasks; -using Microsoft.Azure.Functions.Worker.Http; -using Microsoft.OneFuzz.Service; - -namespace IntegrationTests.Fakes; - -public enum RequestType { - NoAuthorization, - User, - Agent, -} - -sealed class TestEndpointAuthorization : EndpointAuthorization { - private readonly RequestType _type; - private readonly IOnefuzzContext _context; - - public TestEndpointAuthorization(RequestType type, ILogTracer log, IOnefuzzContext context) - : base(context, log, null! /* not needed for test */) { - _type = type; - _context = context; - } - - public override Task CallIf( - HttpRequestData req, - Func> method, - bool allowUser = false, - bool allowAgent = false) { - - if ((_type == RequestType.User && allowUser) || - (_type == RequestType.Agent && allowAgent)) { - return method(req); - } - - return _context.RequestHandling.NotOk( - req, - Error.Create( - ErrorCode.UNAUTHORIZED, - "Unrecognized agent" - ), - "token verification", - HttpStatusCode.Unauthorized - ); - } -} diff --git a/src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs b/src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs new file mode 100644 index 0000000000..f27f863d29 --- /dev/null +++ b/src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using Microsoft.Azure.Functions.Worker; +using Microsoft.OneFuzz.Service; +using Microsoft.OneFuzz.Service.Auth; + +namespace IntegrationTests.Fakes; + +public class TestFunctionContext : FunctionContext { + public override IDictionary Items { get; set; } = new Dictionary(); + + public void SetUserAuthInfo(UserInfo userInfo) + => this.SetUserAuthInfo(new UserAuthInfo(userInfo, new List())); + + // everything else unsupported + + public override string InvocationId => throw new NotSupportedException(); + + public override string FunctionId => throw new NotSupportedException(); + + public override TraceContext TraceContext => throw new NotSupportedException(); + + public override BindingContext BindingContext => throw new NotSupportedException(); + + public override RetryContext RetryContext => throw new NotSupportedException(); + + public override IServiceProvider InstanceServices { get => throw new NotSupportedException(); set => throw new NotImplementedException(); } + + public override FunctionDefinition FunctionDefinition => throw new NotSupportedException(); + + public override IInvocationFeatures Features => throw new NotSupportedException(); +} diff --git a/src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs b/src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs deleted file mode 100644 index 4ee641880f..0000000000 --- a/src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System.Collections.Generic; -using System.Threading.Tasks; -using Microsoft.Azure.Functions.Worker.Http; -using Microsoft.OneFuzz.Service; - -using Async = System.Threading.Tasks; - -namespace IntegrationTests.Fakes; - -sealed class TestUserCredentials : UserCredentials { - - private readonly OneFuzzResult _tokenResult; - - public TestUserCredentials(ILogTracer log, IConfigOperations instanceConfig, OneFuzzResult tokenResult) - : base(log, instanceConfig) { - _tokenResult = tokenResult.IsOk ? OneFuzzResult.Ok(new UserAuthInfo(tokenResult.OkV, new List())) : OneFuzzResult.Error(tokenResult.ErrorV); - } - - public override Task> ParseJwtToken(HttpRequestData req) => Async.Task.FromResult(_tokenResult); -} diff --git a/src/ApiService/IntegrationTests/InfoTests.cs b/src/ApiService/IntegrationTests/InfoTests.cs index 1bd7ef339e..b2ccafb225 100644 --- a/src/ApiService/IntegrationTests/InfoTests.cs +++ b/src/ApiService/IntegrationTests/InfoTests.cs @@ -26,27 +26,8 @@ public InfoTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } [Fact] - public async Async.Task TestInfo_WithoutAuthorization_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new Info(auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task TestInfo_WithAgentCredentials_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new Info(auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task TestInfo_WithUserCredentials_Succeeds() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Info(auth, Context); + public async Async.Task TestInfo_Succeeds() { + var func = new Info(Context); var result = await func.Run(TestHttpRequestData.Empty("GET")); Assert.Equal(HttpStatusCode.OK, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs b/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs index 43e90f547b..09387877d7 100644 --- a/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs +++ b/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs @@ -24,8 +24,6 @@ protected JinjaToScribanMigrationTestBase(ITestOutputHelper output, IStorage sto [Fact] public async Async.Task Dry_Run_Does_Not_Make_Changes() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); var r = await Context.NotificationOperations.Create( @@ -39,8 +37,7 @@ public async Async.Task Dry_Run_Does_Not_Make_Changes() { var notificationBefore = r.OkV!; var adoTemplateBefore = (notificationBefore.Config as AdoTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(DryRun: true); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -62,8 +59,6 @@ public async Async.Task Dry_Run_Does_Not_Make_Changes() { [Fact] public async Async.Task Migration_Happens_When_Not_Dry_run() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); var r = await Context.NotificationOperations.Create( @@ -77,8 +72,7 @@ public async Async.Task Migration_Happens_When_Not_Dry_run() { var notificationBefore = r.OkV!; var adoTemplateBefore = (notificationBefore.Config as AdoTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -99,8 +93,6 @@ public async Async.Task Migration_Happens_When_Not_Dry_run() { [Fact] public async Async.Task OptionalFieldsAreSupported() { - await ConfigureAuth(); - var adoTemplate = new AdoTemplate( new Uri("http://example.com"), new SecretData(new SecretValue("some secret")), @@ -126,8 +118,6 @@ public async Async.Task OptionalFieldsAreSupported() { [Fact] public async Async.Task All_ADO_Fields_Are_Migrated() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var adoTemplate = new AdoTemplate( new Uri("http://example.com"), @@ -161,8 +151,7 @@ public async Async.Task All_ADO_Fields_Are_Migrated() { var notificationBefore = r.OkV!; var adoTemplateBefore = (notificationBefore.Config as AdoTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -196,8 +185,6 @@ public async Async.Task All_ADO_Fields_Are_Migrated() { [Fact] public async Async.Task All_Github_Fields_Are_Migrated() { - await ConfigureAuth(); - var githubTemplate = MigratableGithubTemplate(); var notificationContainer = Container.Parse("abc123"); @@ -213,8 +200,7 @@ public async Async.Task All_Github_Fields_Are_Migrated() { var notificationBefore = r.OkV!; var githubTemplateBefore = (notificationBefore.Config as GithubIssuesTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -252,8 +238,6 @@ public async Async.Task All_Github_Fields_Are_Migrated() { [Fact] public async Async.Task Teams_Template_Not_Migrated() { - await ConfigureAuth(); - var teamsTemplate = GetTeamsTemplate(); var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); @@ -268,8 +252,7 @@ public async Async.Task Teams_Template_Not_Migrated() { var notificationBefore = r.OkV!; var teamsTemplateBefore = (notificationBefore.Config as TeamsTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -288,8 +271,6 @@ public async Async.Task Teams_Template_Not_Migrated() { // Multiple notification configs can be migrated [Fact] public async Async.Task Can_Migrate_Multiple_Notification_Configs() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); @@ -323,8 +304,7 @@ public async Async.Task Can_Migrate_Multiple_Notification_Configs() { var githubNotificationBefore = r.OkV!; var githubTemplateBefore = (githubNotificationBefore.Config as GithubIssuesTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -351,33 +331,12 @@ public async Async.Task Can_Migrate_Multiple_Notification_Configs() { githubTemplateAfter.Organization.Should().BeEquivalentTo(JinjaTemplateAdapter.AdaptForScriban(githubTemplateBefore.Organization)); } - [Fact] - public async Async.Task Access_WithoutAuthorization_IsRejected() { - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); - var req = new JinjaToScribanMigrationPost(); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); - - result.StatusCode.Should().Be(System.Net.HttpStatusCode.BadRequest); - } - [Fact] public async Async.Task Do_Not_Enforce_Key_Exists_In_Strict_Validation() { (await JinjaTemplateAdapter.IsValidScribanNotificationTemplate(Context, Logger, ValidScribanAdoTemplate())) .Should().BeTrue(); } - private async Async.Task ConfigureAuth() { - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } } // needed for admin check - ); - - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - } - private static AdoTemplate MigratableAdoTemplate() { return new AdoTemplate( new Uri("http://example.com"), diff --git a/src/ApiService/IntegrationTests/JobsTests.cs b/src/ApiService/IntegrationTests/JobsTests.cs index 41b0b4b1c5..365df56d88 100644 --- a/src/ApiService/IntegrationTests/JobsTests.cs +++ b/src/ApiService/IntegrationTests/JobsTests.cs @@ -29,27 +29,12 @@ public JobsTestBase(ITestOutputHelper output, IStorage storage) private readonly Guid _jobId = Guid.NewGuid(); private readonly JobConfig _config = new("project", "name", "build", 1000, null); - [Theory] - [InlineData("POST")] - [InlineData("GET")] - [InlineData("DELETE")] - public async Async.Task Access_WithoutAuthorization_IsRejected(string method) { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new Jobs(auth, Context, Logger); - - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - } - [Fact] public async Async.Task Delete_NonExistentJob_Fails() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId)), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); @@ -61,10 +46,10 @@ public async Async.Task Delete_ExistingJob_SetsStoppingState() { await Context.InsertAll( new Job(_jobId, JobState.Enabled, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId)), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -80,10 +65,10 @@ public async Async.Task Delete_ExistingStoppedJob_DoesNotSetStoppingState() { await Context.InsertAll( new Job(_jobId, JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId)), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -100,10 +85,10 @@ public async Async.Task Get_CanFindSpecificJob() { await Context.InsertAll( new Job(_jobId, JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("GET", new JobSearch(JobId: _jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", new JobSearch(JobId: _jobId)), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -119,11 +104,11 @@ await Context.InsertAll( new Job(Guid.NewGuid(), JobState.Enabled, _config), new Job(Guid.NewGuid(), JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); var req = new JobSearch(State: new List { JobState.Enabled }); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -138,11 +123,11 @@ await Context.InsertAll( new Job(Guid.NewGuid(), JobState.Enabled, _config), new Job(Guid.NewGuid(), JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); var req = new JobSearch(State: new List { JobState.Enabled, JobState.Stopping }); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -153,14 +138,12 @@ await Context.InsertAll( [Fact] public async Async.Task Post_CreatesJob_AndContainer() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); // need user credentials to put into the job object - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var result = await func.Run(TestHttpRequestData.FromJson("POST", _config)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); + var result = await func.Run(TestHttpRequestData.FromJson("POST", _config), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var job = Assert.Single(await Context.JobOperations.SearchAll().ToListAsync()); diff --git a/src/ApiService/IntegrationTests/NodeTests.cs b/src/ApiService/IntegrationTests/NodeTests.cs index 8902a01ab3..8900fb9d6f 100644 --- a/src/ApiService/IntegrationTests/NodeTests.cs +++ b/src/ApiService/IntegrationTests/NodeTests.cs @@ -35,10 +35,8 @@ public NodeTestBase(ITestOutputHelper output, IStorage storage) [Fact] public async Async.Task Search_SpecificNode_NotFound_ReturnsNotFound() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new NodeSearch(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -48,10 +46,8 @@ public async Async.Task Search_SpecificNode_Found_ReturnsOk() { await Context.InsertAll( new Node(_poolName, _machineId, null, _version)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new NodeSearch(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -62,10 +58,8 @@ await Context.InsertAll( [Fact] public async Async.Task Search_MultipleNodes_CanFindNone() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new NodeSearch(); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Equal("[]", BodyAsString(result)); @@ -80,8 +74,7 @@ await Context.InsertAll( var req = new NodeSearch(PoolName: _poolName); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -99,8 +92,7 @@ await Context.InsertAll( var req = new NodeSearch(ScalesetId: _scalesetId); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -109,7 +101,6 @@ await Context.InsertAll( Assert.Equal(2, deserialized.Length); } - [Fact] public async Async.Task Search_MultipleNodes_ByState() { await Context.InsertAll( @@ -119,8 +110,7 @@ await Context.InsertAll( var req = new NodeSearch(State: new List { NodeState.Busy }); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -139,8 +129,7 @@ await Context.InsertAll( var req = new NodeSearch(State: new List { NodeState.Free, NodeState.Busy }); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -148,147 +137,4 @@ await Context.InsertAll( var deserialized = BodyAs(result); Assert.Equal(3, deserialized.Length); } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task RequiresAdmin(string method) { - // config must be found - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - RequireAdminPrivileges = true - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - Assert.Contains("pool modification disabled", err.Detail); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task RequiresAdmin_CanBeDisabled(string method) { - // disable requiring admin privileges - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - RequireAdminPrivileges = false - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - - // we will fail with BadRequest but due to not being able to find the Node, - // not because of UNAUTHORIZED - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - Assert.Equal(ErrorCode.UNABLE_TO_FIND.ToString(), BodyAs(result).Title); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task UserCanBeAdmin(string method) { - var userObjectId = Guid.NewGuid(); - - // config specifies that user is admin - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - Admins = new[] { userObjectId } - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - - // we will fail with BadRequest but due to not being able to find the Node, - // not because of UNAUTHORIZED - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - Assert.Equal(ErrorCode.UNABLE_TO_FIND.ToString(), BodyAs(result).Title); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser(string method) { - var userObjectId = Guid.NewGuid(); - var otherObjectId = Guid.NewGuid(); - - // config specifies that a different user is admin - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - Admins = new[] { otherObjectId }, RequireAdminPrivileges = true - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - Assert.Contains("not authorized to manage instance", err.Detail); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task CanPerformOperation(string method) { - // disable requiring admin privileges - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - RequireAdminPrivileges = false - }, - new Node(_poolName, _machineId, null, _version)); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - // all of these operations use NodeGet - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - Assert.Equal(HttpStatusCode.OK, result.StatusCode); - } } diff --git a/src/ApiService/IntegrationTests/PoolTests.cs b/src/ApiService/IntegrationTests/PoolTests.cs index f861053ffb..437120291a 100644 --- a/src/ApiService/IntegrationTests/PoolTests.cs +++ b/src/ApiService/IntegrationTests/PoolTests.cs @@ -31,26 +31,10 @@ public PoolTestBase(ITestOutputHelper output, IStorage storage) private readonly PoolName _poolName = PoolName.Parse("pool-" + Guid.NewGuid()); - [Theory] - [InlineData("POST", RequestType.Agent)] - [InlineData("POST", RequestType.NoAuthorization)] - [InlineData("GET", RequestType.Agent)] - [InlineData("GET", RequestType.NoAuthorization)] - [InlineData("DELETE", RequestType.Agent)] - [InlineData("DELETE", RequestType.NoAuthorization)] - public async Async.Task UserAuthorization_IsRequired(string method, RequestType authType) { - var auth = new TestEndpointAuthorization(authType, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task Search_SpecificPool_ById_NotFound_ReturnsBadRequest() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(PoolId: _poolId); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -66,10 +50,8 @@ await Context.InsertAll( // use test class to override instance ID Context.Containers = new TestContainers(Logger, Context.Storage, Context.ServiceConfiguration); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(PoolId: _poolId); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -79,10 +61,8 @@ await Context.InsertAll( [Fact] public async Async.Task Search_SpecificPool_ByName_NotFound_ReturnsBadRequest() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(Name: _poolName); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -98,10 +78,8 @@ await Context.InsertAll( // use test class to override instance ID Context.Containers = new TestContainers(Logger, Context.Storage, Context.ServiceConfiguration); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(Name: _poolName); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -111,10 +89,8 @@ await Context.InsertAll( [Fact] public async Async.Task Search_SpecificPool_ByState_NotFound_ReturnsEmptyResult() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(State: new List { PoolState.Init }); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -127,9 +103,7 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", new PoolSearch())); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -144,15 +118,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolStop(Name: _poolName, Now: false); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("DELETE", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var pool = await Context.PoolOperations.GetByName(_poolName); @@ -166,15 +134,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Halt, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolStop(Name: _poolName, Now: false); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("DELETE", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var pool = await Context.PoolOperations.GetByName(_poolName); @@ -188,15 +150,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolStop(Name: _poolName, Now: true); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("DELETE", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var pool = await Context.PoolOperations.GetByName(_poolName); @@ -209,18 +165,12 @@ public async Async.Task Post_CreatesNewPool() { await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }); // needed for admin check - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - // need to override instance id Context.Containers = new TestContainers(Logger, Context.Storage, Context.ServiceConfiguration); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolCreate(Name: _poolName, Os.Linux, Architecture.x86_64, true); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); // should get a pool back @@ -240,15 +190,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolCreate(Name: _poolName, Os.Linux, Architecture.x86_64, true); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // should get an error back diff --git a/src/ApiService/IntegrationTests/ReproVmssTests.cs b/src/ApiService/IntegrationTests/ReproVmssTests.cs index 29168df2be..83b274f4cc 100644 --- a/src/ApiService/IntegrationTests/ReproVmssTests.cs +++ b/src/ApiService/IntegrationTests/ReproVmssTests.cs @@ -28,27 +28,12 @@ public abstract class ReproVmssTestBase : FunctionTestBase { public ReproVmssTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - - [Theory] - [InlineData("POST", RequestType.Agent)] - [InlineData("POST", RequestType.NoAuthorization)] - [InlineData("GET", RequestType.Agent)] - [InlineData("GET", RequestType.NoAuthorization)] - [InlineData("DELETE", RequestType.Agent)] - [InlineData("DELETE", RequestType.NoAuthorization)] - public async Async.Task UserAuthorization_IsRequired(string method, RequestType authType) { - var auth = new TestEndpointAuthorization(authType, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task GetMissingVmFails() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: Guid.NewGuid()); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); // TODO: should this be 404? Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); @@ -57,10 +42,10 @@ public async Async.Task GetMissingVmFails() { [Fact] public async Async.Task GetAvailableVMsCanReturnEmpty() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: null); // this means "all available" - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Empty(BodyAs(result)); } @@ -77,10 +62,10 @@ await Context.InsertAll( Auth: new SecretValue(new Authentication("", "", "")), Os: Os.Linux)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: null); // this means "all available" - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var repro = Assert.Single(BodyAs(result)); Assert.Equal(vmId, repro.VmId); @@ -101,10 +86,10 @@ await Context.InsertAll( Auth: new SecretAddress(secretUri), Os: Os.Linux)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: vmId); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Equal(vmId, BodyAs(result).VmId); } @@ -128,37 +113,23 @@ await Context.InsertAll( Os: Os.Linux, State: VmState.Stopped)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: null); // this means "all available" - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Empty(BodyAs(result)); } - [Fact] - public async Async.Task CannotCreateVMWithoutCredentials() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - var func = new ReproVmss(Logger, auth, Context); - var req = new ReproCreate(Container.Parse("abcd"), "/", 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - var err = BodyAs(result); - Assert.Equal(new ProblemDetails(400, "INVALID_REQUEST", "unable to find authorization token"), err); - } - [Fact] public async Async.Task CannotCreateVMForMissingReport() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - // setup fake user - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproCreate(Container.Parse("abcd"), "/", 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal(new ProblemDetails(400, "UNABLE_TO_FIND", "unable to find report"), err); @@ -209,15 +180,13 @@ public async Async.Task CannotCreateVMForMissingReport() { public async Async.Task CannotCreateVMForMissingTask() { var (container, filename) = await CreateContainerWithReport(Guid.NewGuid(), Guid.NewGuid()); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - // setup fake user - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproCreate(container, filename, 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal(new ProblemDetails(400, "INVALID_REQUEST", "unable to find task"), err); @@ -241,15 +210,13 @@ await Context.InsertAll( null, new TaskDetails(TaskType.LibfuzzerFuzz, 12345)))); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - // setup fake user - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproCreate(container, filename, 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var repro = BodyAs(result); Assert.Equal(taskId, repro.TaskId); diff --git a/src/ApiService/IntegrationTests/ScalesetTests.cs b/src/ApiService/IntegrationTests/ScalesetTests.cs index 6ee51271ae..441ed9dc8f 100644 --- a/src/ApiService/IntegrationTests/ScalesetTests.cs +++ b/src/ApiService/IntegrationTests/ScalesetTests.cs @@ -25,28 +25,10 @@ public abstract class ScalesetTestBase : FunctionTestBase { public ScalesetTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Theory] - [InlineData("POST", RequestType.Agent)] - [InlineData("POST", RequestType.NoAuthorization)] - [InlineData("PATCH", RequestType.Agent)] - [InlineData("PATCH", RequestType.NoAuthorization)] - [InlineData("GET", RequestType.Agent)] - [InlineData("GET", RequestType.NoAuthorization)] - [InlineData("DELETE", RequestType.Agent)] - [InlineData("DELETE", RequestType.NoAuthorization)] - public async Async.Task UserAuthorization_IsRequired(string method, RequestType authType) { - var auth = new TestEndpointAuthorization(authType, Logger, Context); - var func = new ScalesetFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task Search_SpecificScaleset_ReturnsErrorIfNoneFound() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new ScalesetSearch(ScalesetId: ScalesetId.Parse(Guid.NewGuid().ToString())); - var func = new ScalesetFunction(Logger, auth, Context); + var func = new ScalesetFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -56,10 +38,8 @@ public async Async.Task Search_SpecificScaleset_ReturnsErrorIfNoneFound() { [Fact] public async Async.Task Search_AllScalesets_ReturnsEmptyIfNoneFound() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new ScalesetSearch(); - var func = new ScalesetFunction(Logger, auth, Context); + var func = new ScalesetFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -81,10 +61,8 @@ await Context.InsertAll( new Node(poolName, Guid.NewGuid(), poolId, "version", ScalesetId: scalesetId) ); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new ScalesetSearch(ScalesetId: scalesetId); - var func = new ScalesetFunction(Logger, auth, Context); + var func = new ScalesetFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -95,17 +73,10 @@ await Context.InsertAll( [Fact] public async Async.Task Create_Scaleset() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userObjectId = Guid.NewGuid(); - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - var poolName = PoolName.Parse("mypool"); await Context.InsertAll( - // user must be admin - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { userObjectId } }, + // config must exist + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!), // pool must exist and be managed new Pool(poolName, Guid.NewGuid(), Os.Linux, Managed: true, Architecture.x86_64, PoolState.Running)); @@ -118,8 +89,8 @@ await Context.InsertAll( SpotInstances: false, Tags: new Dictionary()); - var func = new ScalesetFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var func = new ScalesetFunction(Logger, Context); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -132,17 +103,6 @@ await Context.InsertAll( [Fact] public async Async.Task Create_Scaleset_Under_NonExistent_Pool_Provides_Error() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userObjectId = Guid.NewGuid(); - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - await Context.InsertAll( - // user must be admin - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { userObjectId } }); - var poolName = PoolName.Parse("nosuchpool"); // pool not created var req = new ScalesetCreate( @@ -154,8 +114,8 @@ await Context.InsertAll( SpotInstances: false, Tags: new Dictionary()); - var func = new ScalesetFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var func = new ScalesetFunction(Logger, Context); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/TasksTests.cs b/src/ApiService/IntegrationTests/TasksTests.cs index 847a712153..139c3784ee 100644 --- a/src/ApiService/IntegrationTests/TasksTests.cs +++ b/src/ApiService/IntegrationTests/TasksTests.cs @@ -30,8 +30,7 @@ public TasksTestBase(ITestOutputHelper output, IStorage storage) [Fact] public async Async.Task SpecifyingVmIsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Tasks(Logger, auth, Context); + var func = new Tasks(Context); var req = new TaskCreate( Guid.NewGuid(), @@ -43,7 +42,8 @@ public async Async.Task SpecifyingVmIsNotPermitted() { var serialized = (JsonObject?)JsonSerializer.SerializeToNode(req, EntityConverter.GetJsonSerializerOptions()); serialized!["vm"] = new JsonObject { { "fake", 1 } }; var testData = new TestHttpRequestData("POST", new BinaryData(JsonSerializer.SerializeToUtf8Bytes(serialized, EntityConverter.GetJsonSerializerOptions()))); - var result = await func.Run(testData); + var ctx = new TestFunctionContext(); + var result = await func.Run(testData, ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal("Unexpected property: \"vm\"", err.Detail); @@ -51,12 +51,11 @@ public async Async.Task SpecifyingVmIsNotPermitted() { [Fact] public async Async.Task PoolIsRequired() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Tasks(Logger, auth, Context); + var func = new Tasks(Context); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + // override the found user credentials - need these to store user + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn")); var req = new TaskCreate( Guid.NewGuid(), @@ -64,7 +63,7 @@ public async Async.Task PoolIsRequired() { new TaskDetails(TaskType.DotnetCoverage, 100), null! /* <- here */); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal("The Pool field is required.", err.Detail); @@ -72,15 +71,14 @@ public async Async.Task PoolIsRequired() { [Fact] public async Async.Task CanSearchWithJobIdAndEmptyListOfStates() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new TaskSearch( JobId: Guid.NewGuid(), TaskId: null, State: new List()); - var func = new Tasks(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var func = new Tasks(Context); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); } } diff --git a/src/ApiService/IntegrationTests/ToolsTests.cs b/src/ApiService/IntegrationTests/ToolsTests.cs index 9bb3c3ee0b..9d86db7cfe 100644 --- a/src/ApiService/IntegrationTests/ToolsTests.cs +++ b/src/ApiService/IntegrationTests/ToolsTests.cs @@ -47,8 +47,7 @@ public async Async.Task CanDownload() { var r = await toolsContainerClient.UploadBlobAsync(path.ToString(), BinaryData.FromString(content.ToString())); Assert.False(r.GetRawResponse().IsError); } - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Tools(auth, Context); + var func = new Tools(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", "")); Assert.Equal(HttpStatusCode.OK, result.StatusCode); diff --git a/src/ApiService/Tests/AuthTests.cs b/src/ApiService/Tests/AuthTests.cs new file mode 100644 index 0000000000..6b45ddc87b --- /dev/null +++ b/src/ApiService/Tests/AuthTests.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Castle.Core.Internal; +using Microsoft.Azure.Functions.Worker; +using Microsoft.OneFuzz.Service.Auth; +using Xunit; + +namespace Tests; + +public class AuthTests { + + public static IEnumerable AllFunctionEntryPoints() { + var asm = typeof(AuthorizeAttribute).Assembly; + foreach (var type in asm.GetTypes()) { + if (type.Namespace == "ApiService.TestHooks" + || type.Name == "TestHooks") { + // skip test hooks + continue; + } + + foreach (var method in type.GetMethods()) { + if (method.GetCustomAttribute() is not null) { + // it's a function entrypoint + yield return new object[] { type, method }; + } + } + } + } + + + [Theory] + [MemberData(nameof(AllFunctionEntryPoints))] + public void AllFunctionsHaveAuthAttributes(Type type, MethodInfo methodInfo) { + var trigger = methodInfo.GetParameters().First().GetCustomAttribute(); + if (trigger is null) { + return; // not an HTTP function + } + + // built-in auth level should be anonymous - we are implementing our own authorization + Assert.Equal(AuthorizationLevel.Anonymous, trigger.AuthLevel); + + if (type.Name == "Config" && methodInfo.Name == "Run") { + // this method alone is allowed to be anonymous + Assert.Null(methodInfo.GetAttribute()); + return; + } + + // authorize attribute can be on class or method + var authAttribute = methodInfo.GetAttribute() + ?? type.GetAttribute(); + Assert.NotNull(authAttribute); + + // naming convention: check that Agent* functions have Allow.Agent, and none other + var functionAttribute = methodInfo.GetCustomAttribute()!; + if (functionAttribute.Name.StartsWith("Agent")) { + Assert.Equal(Allow.Agent, authAttribute.Allow); + } else { + Assert.NotEqual(Allow.Agent, authAttribute.Allow); + } + + // naming convention: all *_Admin functions should be ALlow.Admin + // (some that aren't _Admin also require it) + if (functionAttribute.Name.EndsWith("_Admin")) { + Assert.Equal(Allow.Admin, authAttribute.Allow); + } + + // make sure other methods that _aren't_ function entry points don't have it, + // because it won't do anything there, and having it present would be misleading + foreach (var otherMethod in type.GetMethods()) { + if (otherMethod.GetCustomAttribute() is null) { + Assert.True( + otherMethod.GetCustomAttribute() is null, + "non-[Function] methods must not have [Authorize]"); + } + } + } +}