diff --git a/src/ApiService/ApiService/Auth/AuthenticationItems.cs b/src/ApiService/ApiService/Auth/AuthenticationItems.cs new file mode 100644 index 0000000000..b546a40889 --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthenticationItems.cs @@ -0,0 +1,16 @@ +using Microsoft.Azure.Functions.Worker; + +namespace Microsoft.OneFuzz.Service.Auth; + +public static class AuthenticationItems { + private const string Key = "ONEFUZZ_USER_INFO"; + + public static void SetUserAuthInfo(this FunctionContext context, UserAuthInfo info) + => context.Items[Key] = info; + + public static UserAuthInfo GetUserAuthInfo(this FunctionContext context) + => (UserAuthInfo)context.Items[Key]; + + public static UserAuthInfo? TryGetUserAuthInfo(this FunctionContext context) + => context.Items.TryGetValue(Key, out var result) ? (UserAuthInfo)result : null; +} diff --git a/src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs b/src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs new file mode 100644 index 0000000000..2cec43994b --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthenticationMiddleware.cs @@ -0,0 +1,111 @@ +using System.IdentityModel.Tokens.Jwt; +using System.Net; +using System.Net.Http.Headers; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.Azure.Functions.Worker.Middleware; + +namespace Microsoft.OneFuzz.Service.Auth; + +public sealed class AuthenticationMiddleware : IFunctionsWorkerMiddleware { + private readonly IConfigOperations _config; + private readonly ILogTracer _log; + + public AuthenticationMiddleware(IConfigOperations config, ILogTracer log) { + _config = config; + _log = log; + } + + public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) { + var requestData = await context.GetHttpRequestDataAsync(); + if (requestData is not null) { + var authToken = GetAuthToken(requestData); + if (authToken is not null) { + // note that no validation of the token is performed here + // this is done globally by Azure Functions; see the configuration in + // 'function.bicep' + var token = new JwtSecurityToken(authToken); + var allowedTenants = await AllowedTenants(); + if (!allowedTenants.Contains(token.Issuer)) { + await BadIssuer(requestData, context, token, allowedTenants); + return; + } + + context.SetUserAuthInfo(UserInfoFromAuthToken(token)); + } + } + + await next(context); + } + + private static UserAuthInfo UserInfoFromAuthToken(JwtSecurityToken token) + => token.Payload.Claims.Aggregate( + seed: new UserAuthInfo(new UserInfo(null, null, null), new List()), + (acc, claim) => { + switch (claim.Type) { + case "oid": + return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } }; + case "appid": + return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } }; + case "upn": + return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } }; + case "roles": + acc.Roles.Add(claim.Value); + return acc; + default: + return acc; + } + }); + + private async Async.ValueTask BadIssuer( + HttpRequestData request, + FunctionContext context, + JwtSecurityToken token, + IEnumerable allowedTenants) { + + var tenantsStr = string.Join("; ", allowedTenants); + _log.Error($"issuer not from allowed tenant. issuer: {token.Issuer:Tag:Issuer} - tenants: {tenantsStr:Tag:Tenants}"); + + var response = HttpResponseData.CreateResponse(request); + var status = HttpStatusCode.BadRequest; + await response.WriteAsJsonAsync( + new ProblemDetails( + status, + new Error( + ErrorCode.INVALID_REQUEST, + new List { + "unauthorized AAD issuer. If multi-tenant auth is failing, make sure to include all tenant_ids in the `allowed_aad_tenants` list in the instance_config. To see the current instance_config, run `onefuzz instance_config get`. " + } + )), + "application/problem+json", + status); + + context.GetInvocationResult().Value = response; + } + + private async Async.Task> AllowedTenants() { + var config = await _config.Fetch(); + return config.AllowedAadTenants.Select(t => $"https://sts.windows.net/{t}/"); + } + + private static string? GetAuthToken(HttpRequestData requestData) + => GetBearerToken(requestData) ?? GetAadIdToken(requestData); + + private static string? GetAadIdToken(HttpRequestData requestData) { + if (!requestData.Headers.TryGetValues("x-ms-token-aad-id-token", out var values)) { + return null; + } + + return values.First(); + } + + private static string? GetBearerToken(HttpRequestData requestData) { + if (!requestData.Headers.TryGetValues("Authorization", out var values) + || !AuthenticationHeaderValue.TryParse(values.First(), out var headerValue) + || !string.Equals(headerValue.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase)) { + return null; + } + + return headerValue.Parameter; + } +} diff --git a/src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs b/src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs new file mode 100644 index 0000000000..34ca16ab28 --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthorizationMiddleware.cs @@ -0,0 +1,103 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using System.Net; +using System.Reflection; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.Azure.Functions.Worker.Middleware; + +namespace Microsoft.OneFuzz.Service.Auth; + +public sealed class AuthorizationMiddleware : IFunctionsWorkerMiddleware { + private readonly IEndpointAuthorization _auth; + private readonly ILogTracer _log; + + public AuthorizationMiddleware(IEndpointAuthorization auth, ILogTracer log) { + _auth = auth; + _log = log; + } + + public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) { + var attribute = GetAuthorizeAttribute(context); + if (attribute is not null) { + var req = await context.GetHttpRequestDataAsync() ?? throw new NotSupportedException("no HTTP request data found"); + var user = context.TryGetUserAuthInfo(); + if (user is null) { + await Reject(req, context, "no authentication"); + return; + } + + var (isAgent, _) = await _auth.IsAgent(user); + if (isAgent) { + if (attribute.Allow != Allow.Agent) { + await Reject(req, context, "endpoint not allowed for agents"); + return; + } + } else { + if (attribute.Allow == Allow.Agent) { + await Reject(req, context, "endpoint not allowed for users"); + return; + } + + Debug.Assert(attribute.Allow is Allow.User or Allow.Admin); + + // check access control first + var access = await _auth.CheckAccess(req); + if (!access.IsOk) { + await Reject(req, context, "access control rejected request"); + return; + } + + // check admin next + if (attribute.Allow == Allow.Admin) { + var adminAccess = await _auth.CheckRequireAdmins(user); + if (!adminAccess.IsOk) { + await Reject(req, context, "must be admin to use this endpoint"); + return; + } + } + } + } + + await next(context); + } + + private static async Async.ValueTask Reject(HttpRequestData request, FunctionContext context, string reason) { + var response = HttpResponseData.CreateResponse(request); + var status = HttpStatusCode.Unauthorized; + await response.WriteAsJsonAsync( + new ProblemDetails( + status, + Error.Create(ErrorCode.UNAUTHORIZED, reason)), + "application/problem+json", + status); + + context.GetInvocationResult().Value = response; + } + + // use ImmutableDictionary to prevent needing to lock and without the overhead + // of ConcurrentDictionary + private static ImmutableDictionary _authorizeCache = + ImmutableDictionary.Create(); + + private static AuthorizeAttribute? GetAuthorizeAttribute(FunctionContext context) { + // fully-qualified name of the method + var entryPoint = context.FunctionDefinition.EntryPoint; + if (_authorizeCache.TryGetValue(entryPoint, out var cached)) { + return cached; + } + + var lastDot = entryPoint.LastIndexOf('.'); + var (typeName, methodName) = (entryPoint[..lastDot], entryPoint[(lastDot + 1)..]); + var assemblyPath = context.FunctionDefinition.PathToAssembly; + var assembly = Assembly.LoadFrom(assemblyPath); // should already be loaded + var type = assembly.GetType(typeName)!; + var method = type.GetMethod(methodName)!; + var result = + method.GetCustomAttribute() + ?? type.GetCustomAttribute(); + + _authorizeCache = _authorizeCache.SetItem(entryPoint, result); + return result; + } +} diff --git a/src/ApiService/ApiService/Auth/AuthorizeAttribute.cs b/src/ApiService/ApiService/Auth/AuthorizeAttribute.cs new file mode 100644 index 0000000000..c195e33e0f --- /dev/null +++ b/src/ApiService/ApiService/Auth/AuthorizeAttribute.cs @@ -0,0 +1,17 @@ +namespace Microsoft.OneFuzz.Service.Auth; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] +public sealed class AuthorizeAttribute : Attribute { + public AuthorizeAttribute(Allow allow) { + Allow = allow; + } + + public Allow Allow { get; set; } +} + +public enum Allow { + Agent, + User, + Admin, + +} diff --git a/src/ApiService/ApiService/Functions/AgentCanSchedule.cs b/src/ApiService/ApiService/Functions/AgentCanSchedule.cs index 8e83a015b8..b4d54c9477 100644 --- a/src/ApiService/ApiService/Functions/AgentCanSchedule.cs +++ b/src/ApiService/ApiService/Functions/AgentCanSchedule.cs @@ -1,27 +1,23 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class AgentCanSchedule { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - - public AgentCanSchedule(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentCanSchedule")] - public Async.Task Run( + [Authorize(Allow.Agent)] + public async Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/can_schedule")] - HttpRequestData req) - => _auth.CallIfAgent(req, Post); - - private async Async.Task Post(HttpRequestData req) { + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { _log.Warning($"Cannot schedule due to {request.ErrorV}"); diff --git a/src/ApiService/ApiService/Functions/AgentCommands.cs b/src/ApiService/ApiService/Functions/AgentCommands.cs index 5c1d7721a4..2ba68e9e64 100644 --- a/src/ApiService/ApiService/Functions/AgentCommands.cs +++ b/src/ApiService/ApiService/Functions/AgentCommands.cs @@ -1,28 +1,28 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class AgentCommands { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentCommands(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentCommands(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentCommands")] + [Authorize(Allow.Agent)] public Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "DELETE", Route="agents/commands")] HttpRequestData req) - => _auth.CallIfAgent(req, r => r.Method switch { + => req.Method switch { "GET" => Get(req), "DELETE" => Delete(req), _ => throw new NotSupportedException($"HTTP Method {req.Method} is not supported for this method") - }); + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); diff --git a/src/ApiService/ApiService/Functions/AgentEvents.cs b/src/ApiService/ApiService/Functions/AgentEvents.cs index 46f329856d..63559dd102 100644 --- a/src/ApiService/ApiService/Functions/AgentEvents.cs +++ b/src/ApiService/ApiService/Functions/AgentEvents.cs @@ -1,28 +1,25 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; namespace Microsoft.OneFuzz.Service.Functions; public class AgentEvents { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentEvents(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentEvents(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentEvents")] - public Async.Task Run( + [Authorize(Allow.Agent)] + public async Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/events")] - HttpRequestData req) - => _auth.CallIfAgent(req, Post); - - private async Async.Task Post(HttpRequestData req) { + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event"); diff --git a/src/ApiService/ApiService/Functions/AgentRegistration.cs b/src/ApiService/ApiService/Functions/AgentRegistration.cs index 6b715c2245..31134c78cb 100644 --- a/src/ApiService/ApiService/Functions/AgentRegistration.cs +++ b/src/ApiService/ApiService/Functions/AgentRegistration.cs @@ -1,33 +1,31 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class AgentRegistration { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentRegistration(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public AgentRegistration(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("AgentRegistration")] + [Authorize(Allow.Agent)] public Async.Task Run( [HttpTrigger( AuthorizationLevel.Anonymous, "GET", "POST", Route="agents/registration")] HttpRequestData req) - => _auth.CallIfAgent( - req, - r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - var m => throw new InvalidOperationException($"method {m} not supported"), - }); + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + var m => throw new InvalidOperationException($"method {m} not supported"), + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseUri(req); diff --git a/src/ApiService/ApiService/Functions/Config.cs b/src/ApiService/ApiService/Functions/Config.cs index 2704ab6b00..6f4037584b 100644 --- a/src/ApiService/ApiService/Functions/Config.cs +++ b/src/ApiService/ApiService/Functions/Config.cs @@ -14,11 +14,10 @@ public Config(ILogTracer log, IOnefuzzContext context) { } [Function("Config")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { - return Get(req); - } - public async Async.Task Get(HttpRequestData req) { + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET")] + HttpRequestData req) { + _log.Info($"getting endpoint config parameters"); var endpointParams = new ConfigResponse( diff --git a/src/ApiService/ApiService/Functions/Containers.cs b/src/ApiService/ApiService/Functions/Containers.cs index c3bedcfd66..8eed38e23b 100644 --- a/src/ApiService/ApiService/Functions/Containers.cs +++ b/src/ApiService/ApiService/Functions/Containers.cs @@ -1,28 +1,28 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class ContainersFunction { private readonly ILogTracer _logger; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public ContainersFunction(ILogTracer logger, IEndpointAuthorization auth, IOnefuzzContext context) { + public ContainersFunction(ILogTracer logger, IOnefuzzContext context) { _logger = logger; - _auth = auth; _context = context; } [Function("Containers")] + [Authorize(Allow.User)] public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + "DELETE" => Delete(req), _ => throw new NotSupportedException(), - }); + }; private async Async.Task Get(HttpRequestData req) { diff --git a/src/ApiService/ApiService/Functions/Download.cs b/src/ApiService/ApiService/Functions/Download.cs index dac128b08b..27202afad1 100644 --- a/src/ApiService/ApiService/Functions/Download.cs +++ b/src/ApiService/ApiService/Functions/Download.cs @@ -2,23 +2,20 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Download { - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Download(IEndpointAuthorization auth, IOnefuzzContext context) { - _auth = auth; + public Download(IOnefuzzContext context) { _context = context; } [Function("Download")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, Get); - - private async Async.Task Get(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { var query = HttpUtility.ParseQueryString(req.Url.Query); var queryContainer = query["container"]; diff --git a/src/ApiService/ApiService/Functions/Events.cs b/src/ApiService/ApiService/Functions/Events.cs index 83cb585a19..67c25febec 100644 --- a/src/ApiService/ApiService/Functions/Events.cs +++ b/src/ApiService/ApiService/Functions/Events.cs @@ -1,27 +1,21 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class EventsFunction { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public EventsFunction(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _auth = auth; + public EventsFunction(ILogTracer log, IOnefuzzContext context) { _context = context; _log = log; } [Function("Events")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - _ => throw new NotSupportedException(), - }); - - private async Async.Task Get(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "events get"); diff --git a/src/ApiService/ApiService/Functions/Info.cs b/src/ApiService/ApiService/Functions/Info.cs index 1c5aef1033..6dd32a12a7 100644 --- a/src/ApiService/ApiService/Functions/Info.cs +++ b/src/ApiService/ApiService/Functions/Info.cs @@ -4,17 +4,16 @@ using System.Threading; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Info { private readonly IOnefuzzContext _context; - private readonly IEndpointAuthorization _auth; private readonly Lazy> _response; - public Info(IEndpointAuthorization auth, IOnefuzzContext context) { + public Info(IOnefuzzContext context) { _context = context; - _auth = auth; // TODO: this isn’t actually shared between calls at the moment, // this needs to be placed into a class that can be registered into the @@ -60,10 +59,8 @@ private static string ReadResource(Assembly asm, string resourceName) { return sr.ReadToEnd().Trim(); } - private async Async.Task GetResponse(HttpRequestData req) - => await RequestHandling.Ok(req, await _response.Value); - [Function("Info")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, GetResponse); + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) + => await RequestHandling.Ok(req, await _response.Value); } diff --git a/src/ApiService/ApiService/Functions/InstanceConfig.cs b/src/ApiService/ApiService/Functions/InstanceConfig.cs index 59ae6d316c..4718c8e3f5 100644 --- a/src/ApiService/ApiService/Functions/InstanceConfig.cs +++ b/src/ApiService/ApiService/Functions/InstanceConfig.cs @@ -2,30 +2,26 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class InstanceConfig { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public InstanceConfig(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public InstanceConfig(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } + public const string Route = "instance_config"; + [Function("InstanceConfig")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", Route = "instance_config")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - public async Async.Task Get(HttpRequestData req) { + [Authorize(Allow.User)] + public async Task Get( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) { _log.Info($"getting instance_config"); var config = await _context.ConfigOperations.Fetch(); @@ -34,7 +30,11 @@ public async Async.Task Get(HttpRequestData req) { return response; } - public async Async.Task Post(HttpRequestData req) { + [Function("InstanceConfig_Admin")] + [Authorize(Allow.Admin)] + public async Task Post( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route=Route)] + HttpRequestData req) { _log.Info($"attempting instance_config update"); var request = await RequestHandling.ParseRequest(req); @@ -44,12 +44,8 @@ public async Async.Task Post(HttpRequestData req) { request.ErrorV, context: "instance_config update"); } - var (config, answer) = await ( - _context.ConfigOperations.Fetch(), - _auth.CheckRequireAdmins(req)); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "instance_config update"); - } + + var config = await _context.ConfigOperations.Fetch(); var updateNsg = false; if (request.OkV.config.ProxyNsgConfig is NetworkSecurityGroupConfig requestConfig && config.ProxyNsgConfig is NetworkSecurityGroupConfig currentConfig) { @@ -58,7 +54,9 @@ public async Async.Task Post(HttpRequestData req) { updateNsg = true; } } + await _context.ConfigOperations.Save(request.OkV.config, false, false); + if (updateNsg) { await foreach (var nsg in _context.NsgOperations.ListNsgs()) { _log.Info($"Checking if nsg: {nsg.Data.Location!:Tag:Location} ({nsg.Data.Name:Tag:NsgName}) owned by OneFuzz"); diff --git a/src/ApiService/ApiService/Functions/Jobs.cs b/src/ApiService/ApiService/Functions/Jobs.cs index c7ee52bb5f..9326f773f1 100644 --- a/src/ApiService/ApiService/Functions/Jobs.cs +++ b/src/ApiService/ApiService/Functions/Jobs.cs @@ -1,39 +1,39 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Jobs { private readonly IOnefuzzContext _context; - private readonly IEndpointAuthorization _auth; private readonly ILogTracer _logTracer; - public Jobs(IEndpointAuthorization auth, IOnefuzzContext context, ILogTracer logTracer) { + public Jobs(IOnefuzzContext context, ILogTracer logTracer) { _context = context; - _auth = auth; _logTracer = logTracer; } [Function("Jobs")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "DELETE" => Delete(r), - "POST" => Post(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] + HttpRequestData req, + FunctionContext context) + => req.Method switch { + "GET" => Get(req), + "DELETE" => Delete(req), + "POST" => Post(req, context), var m => throw new NotSupportedException($"Unsupported HTTP method {m}"), - }); + }; - private async Task Post(HttpRequestData req) { + private async Task Post(HttpRequestData req, FunctionContext context) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "jobs create"); } - var userInfo = await _context.UserCredentials.ParseJwtToken(req); - if (!userInfo.IsOk) { - return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "jobs create"); - } + var userInfo = context.GetUserAuthInfo(); var create = request.OkV; var cfg = new JobConfig( @@ -47,7 +47,7 @@ private async Task Post(HttpRequestData req) { JobId: Guid.NewGuid(), State: JobState.Init, Config: cfg) { - UserInfo = userInfo.OkV.UserInfo, + UserInfo = userInfo.UserInfo, }; // create the job logs container diff --git a/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs b/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs index 197c90ecae..0e69ab5103 100644 --- a/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs +++ b/src/ApiService/ApiService/Functions/Migrations/JinjaToScriban.cs @@ -1,28 +1,26 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class JinjaToScriban { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public JinjaToScriban(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public JinjaToScriban(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("JinjaToScriban")] - public Async.Task Run( + [Authorize(Allow.Admin)] + public async Async.Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="migrations/jinja_to_scriban")] - HttpRequestData req) - => _auth.CallIfUser(req, Post); + HttpRequestData req) { - private async Async.Task Post(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -31,11 +29,6 @@ private async Async.Task Post(HttpRequestData req) { "JinjaToScriban"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "JinjaToScriban"); - } - _log.Info($"Finding notifications to migrate"); var notifications = _context.NotificationOperations.SearchAll() diff --git a/src/ApiService/ApiService/Functions/Negotiate.cs b/src/ApiService/ApiService/Functions/Negotiate.cs index 21e5f05c9a..16f9496ec1 100644 --- a/src/ApiService/ApiService/Functions/Negotiate.cs +++ b/src/ApiService/ApiService/Functions/Negotiate.cs @@ -2,32 +2,24 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Negotiate { - private readonly IEndpointAuthorization _auth; - public Negotiate(IEndpointAuthorization auth) { - _auth = auth; - } - [Function("Negotiate")] - public Task Run( + [Authorize(Allow.User)] + public static async Task Run( [HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req, - [SignalRConnectionInfoInput(HubName = "dashboard")] string info) - => _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r, info), - var m => throw new InvalidOperationException($"Unsupported HTTP method {m}"), - }); + [SignalRConnectionInfoInput(HubName = "dashboard")] string info) { - // This endpoint handles the signalr negotation - // As we do not differentiate from clients at this time, we pass the Functions runtime - // provided connection straight to the client - // - // For more info: - // https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals + // This endpoint handles the signalr negotation + // As we do not differentiate from clients at this time, we pass the Functions runtime + // provided connection straight to the client + // + // For more info: + // https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals - private static async Task Post(HttpRequestData req, string info) { var resp = req.CreateResponse(HttpStatusCode.OK); resp.Headers.Add("Content-Type", "application/json"); await resp.WriteStringAsync(info); diff --git a/src/ApiService/ApiService/Functions/Node.cs b/src/ApiService/ApiService/Functions/Node.cs index 70231b90d2..b7f033ce22 100644 --- a/src/ApiService/ApiService/Functions/Node.cs +++ b/src/ApiService/ApiService/Functions/Node.cs @@ -1,30 +1,42 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Node { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Node(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Node(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } + public const string Route = "node"; + [Function("Node")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "PATCH" => Patch(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; + + [Function("Node_Admin")] + [Authorize(Allow.Admin)] + public Task Admin( + [HttpTrigger(AuthorizationLevel.Anonymous, "PATCH", "POST", "DELETE", Route=Route)] + HttpRequestData req) + => req.Method switch { + "PATCH" => Patch(req), + "POST" => Post(req), + "DELETE" => Delete(req), + _ => throw new InvalidOperationException("Unsupported HTTP method"), + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -82,11 +94,6 @@ private async Async.Task Patch(HttpRequestData req) { "NodeReimage"); } - var authCheck = await _auth.CheckRequireAdmins(req); - if (!authCheck.IsOk) { - return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeReimage"); - } - var patch = request.OkV; var node = await _context.NodeOperations.GetByMachineId(patch.MachineId); if (node is null) { @@ -116,11 +123,6 @@ private async Async.Task Post(HttpRequestData req) { "NodeUpdate"); } - var authCheck = await _auth.CheckRequireAdmins(req); - if (!authCheck.IsOk) { - return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeUpdate"); - } - var post = request.OkV; var node = await _context.NodeOperations.GetByMachineId(post.MachineId); if (node is null) { @@ -150,11 +152,6 @@ private async Async.Task Delete(HttpRequestData req) { context: "NodeDelete"); } - var authCheck = await _auth.CheckRequireAdmins(req); - if (!authCheck.IsOk) { - return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeDelete"); - } - var delete = request.OkV; var node = await _context.NodeOperations.GetByMachineId(delete.MachineId); if (node is null) { diff --git a/src/ApiService/ApiService/Functions/NodeAddSshKey.cs b/src/ApiService/ApiService/Functions/NodeAddSshKey.cs index 1bce7c9fb2..524acaf8fd 100644 --- a/src/ApiService/ApiService/Functions/NodeAddSshKey.cs +++ b/src/ApiService/ApiService/Functions/NodeAddSshKey.cs @@ -1,22 +1,20 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class NodeAddSshKey { - - private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public NodeAddSshKey(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _log = log; - _auth = auth; + public NodeAddSshKey(IOnefuzzContext context) { _context = context; } - private async Async.Task Post(HttpRequestData req) { + [Function("NodeAddSshKey")] + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "node/add_ssh_key")] HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -42,16 +40,5 @@ private async Async.Task Post(HttpRequestData req) { var response = req.CreateResponse(HttpStatusCode.OK); await response.WriteAsJsonAsync(new BoolResult(true)); return response; - - } - - [Function("NodeAddSshKey")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "node/add_ssh_key")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - } diff --git a/src/ApiService/ApiService/Functions/Notifications.cs b/src/ApiService/ApiService/Functions/Notifications.cs index 5013521a26..323ef65c61 100644 --- a/src/ApiService/ApiService/Functions/Notifications.cs +++ b/src/ApiService/ApiService/Functions/Notifications.cs @@ -1,17 +1,16 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Notifications { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Notifications(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Notifications(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } @@ -82,12 +81,12 @@ private async Async.Task Delete(HttpRequestData req) { [Function("Notifications")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; } diff --git a/src/ApiService/ApiService/Functions/NotificationsTest.cs b/src/ApiService/ApiService/Functions/NotificationsTest.cs index 16a1a5c982..11beedcf70 100644 --- a/src/ApiService/ApiService/Functions/NotificationsTest.cs +++ b/src/ApiService/ApiService/Functions/NotificationsTest.cs @@ -1,21 +1,22 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class NotificationsTest { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public NotificationsTest(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public NotificationsTest(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } - private async Async.Task Post(HttpRequestData req) { + [Function("NotificationsTest")] + [Authorize(Allow.User)] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "notifications/test")] HttpRequestData req) { _log.WithTag("HttpRequest", "GET").Info($"Notification test"); var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { @@ -29,13 +30,4 @@ private async Async.Task Post(HttpRequestData req) { await response.WriteAsJsonAsync(new NotificationTestResponse(result.IsOk, result.ErrorV?.ToString())); return response; } - - - [Function("NotificationsTest")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "notifications/test")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } } diff --git a/src/ApiService/ApiService/Functions/Pool.cs b/src/ApiService/ApiService/Functions/Pool.cs index 9d4ee7f5df..c160e3b163 100644 --- a/src/ApiService/ApiService/Functions/Pool.cs +++ b/src/ApiService/ApiService/Functions/Pool.cs @@ -2,29 +2,40 @@ using Azure.Storage.Sas; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Pool { - private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Pool(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _log = log; - _auth = auth; + public Pool(IOnefuzzContext context) { _context = context; } + public const string Route = "pool"; + [Function("Pool")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", "PATCH")] HttpRequestData req) - => _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), - "PATCH" => Patch(r), - var m => throw new InvalidOperationException("Unsupported HTTP method {m}"), - }); + [Authorize(Allow.User)] + public Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + _ => throw new InvalidOperationException("Unsupported HTTP method {m}"), + }; + + [Function("Pool_Admin")] + [Authorize(Allow.Admin)] + public Task Admin( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", "DELETE", "PATCH", Route=Route)] + HttpRequestData req) + => req.Method switch { + "POST" => Post(req), + "DELETE" => Delete(req), + "PATCH" => Patch(req), + _ => throw new InvalidOperationException("Unsupported HTTP method {m}"), + }; private async Task Delete(HttpRequestData r) { var request = await RequestHandling.ParseRequest(r); @@ -32,11 +43,6 @@ private async Task Delete(HttpRequestData r) { return await _context.RequestHandling.NotOk(r, request.ErrorV, "PoolDelete"); } - var answer = await _auth.CheckRequireAdmins(r); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(r, answer.ErrorV, "PoolDelete"); - } - var poolResult = await _context.PoolOperations.GetByName(request.OkV.Name); if (!poolResult.IsOk) { return await _context.RequestHandling.NotOk(r, poolResult.ErrorV, "pool stop"); @@ -53,11 +59,6 @@ private async Task Post(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "PoolCreate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "PoolCreate"); - } - var create = request.OkV; var pool = await _context.PoolOperations.GetByName(create.Name); if (pool.IsOk) { @@ -77,11 +78,6 @@ private async Task Patch(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "PoolUpdate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "PoolUpdate"); - } - var update = request.OkV; var pool = await _context.PoolOperations.GetByName(update.Name); if (!pool.IsOk) { diff --git a/src/ApiService/ApiService/Functions/Proxy.cs b/src/ApiService/ApiService/Functions/Proxy.cs index ff56dd5fa9..85e8b4221e 100644 --- a/src/ApiService/ApiService/Functions/Proxy.cs +++ b/src/ApiService/ApiService/Functions/Proxy.cs @@ -1,32 +1,32 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; using VmProxy = Microsoft.OneFuzz.Service.Proxy; namespace Microsoft.OneFuzz.Service.Functions; public class Proxy { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Proxy(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Proxy(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("Proxy")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "PATCH" => Patch(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + "PATCH" => Patch(req), + "POST" => Post(req), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - + }; private ProxyGetResult GetResult(ProxyForward proxyForward, VmProxy? proxy) { var forward = _context.ProxyForwardOperations.ToForward(proxyForward); diff --git a/src/ApiService/ApiService/Functions/ReproVmss.cs b/src/ApiService/ApiService/Functions/ReproVmss.cs index f32be74614..6061c44098 100644 --- a/src/ApiService/ApiService/Functions/ReproVmss.cs +++ b/src/ApiService/ApiService/Functions/ReproVmss.cs @@ -3,29 +3,31 @@ using System.Text; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class ReproVmss { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public ReproVmss(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public ReproVmss(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("ReproVms")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", Route = "repro_vms")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", Route = "repro_vms")] + HttpRequestData req, + FunctionContext context) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req, context), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -56,7 +58,7 @@ private async Async.Task Get(HttpRequestData req) { } - private async Async.Task Post(HttpRequestData req) { + private async Async.Task Post(HttpRequestData req, FunctionContext context) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -65,13 +67,7 @@ private async Async.Task Post(HttpRequestData req) { "repro_vm create"); } - var userInfo = await _context.UserCredentials.ParseJwtToken(req); - if (!userInfo.IsOk) { - return await _context.RequestHandling.NotOk( - req, - userInfo.ErrorV, - "repro_vm create"); - } + var userInfo = context.GetUserAuthInfo(); var create = request.OkV; var cfg = new ReproConfig( @@ -79,7 +75,7 @@ private async Async.Task Post(HttpRequestData req) { Path: create.Path, Duration: create.Duration); - var vm = await _context.ReproOperations.Create(cfg, userInfo.OkV.UserInfo); + var vm = await _context.ReproOperations.Create(cfg, userInfo.UserInfo); if (!vm.IsOk) { return await _context.RequestHandling.NotOk( req, @@ -98,7 +94,7 @@ private async Async.Task Post(HttpRequestData req) { // we’d like to track the usage of this feature; // anonymize the user ID so we can distinguish multiple requests { - var data = userInfo.OkV.UserInfo.ToString(); // rely on record ToString + var data = userInfo.UserInfo.ToString(); // rely on record ToString var hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(data))); _log.Event($"created repro VM, user distinguisher: {hash:Tag:UserHash}"); } diff --git a/src/ApiService/ApiService/Functions/Scaleset.cs b/src/ApiService/ApiService/Functions/Scaleset.cs index 403ee54dfa..c01383fa83 100644 --- a/src/ApiService/ApiService/Functions/Scaleset.cs +++ b/src/ApiService/ApiService/Functions/Scaleset.cs @@ -1,30 +1,42 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Scaleset { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Scaleset(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Scaleset(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } + public const string Route = "scaleset"; + [Function("Scaleset")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "PATCH" => Patch(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", Route=Route)] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; + + [Function("Scaleset_Admin")] + [Authorize(Allow.Admin)] + public Async.Task Admin( + [HttpTrigger(AuthorizationLevel.Anonymous, "PATCH", "POST", "DELETE", Route=Route)] + HttpRequestData req) + => req.Method switch { + "PATCH" => Patch(req), + "POST" => Post(req), + "DELETE" => Delete(req), + _ => throw new InvalidOperationException("Unsupported HTTP method"), + }; private async Task Delete(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -32,11 +44,6 @@ private async Task Delete(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ScalesetDelete"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "ScalesetDelete"); - } - var scalesetResult = await _context.ScalesetOperations.GetById(request.OkV.ScalesetId); if (!scalesetResult.IsOk) { return await _context.RequestHandling.NotOk(req, scalesetResult.ErrorV, "ScalesetDelete"); @@ -54,11 +61,6 @@ private async Task Post(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ScalesetCreate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "ScalesetCreate"); - } - var create = request.OkV; // verify the pool exists var poolResult = await _context.PoolOperations.GetByName(create.PoolName); @@ -121,7 +123,7 @@ private async Task Post(HttpRequestData req) { ScalesetId: Service.Scaleset.GenerateNewScalesetId(create.PoolName), State: ScalesetState.Init, NeedsConfigUpdate: false, - Auth: new SecretValue(await Auth.BuildAuth(_log)), + Auth: new SecretValue(await AuthHelpers.BuildAuth(_log)), PoolName: create.PoolName, VmSku: create.VmSku, Image: image, @@ -172,11 +174,6 @@ private async Task Patch(HttpRequestData req) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ScalesetUpdate"); } - var answer = await _auth.CheckRequireAdmins(req); - if (!answer.IsOk) { - return await _context.RequestHandling.NotOk(req, answer.ErrorV, "ScalesetUpdate"); - } - var scalesetResult = await _context.ScalesetOperations.GetById(request.OkV.ScalesetId); if (!scalesetResult.IsOk) { return await _context.RequestHandling.NotOk(req, scalesetResult.ErrorV, "ScalesetUpdate"); diff --git a/src/ApiService/ApiService/Functions/Tasks.cs b/src/ApiService/ApiService/Functions/Tasks.cs index 5c8edbc786..c2b759f527 100644 --- a/src/ApiService/ApiService/Functions/Tasks.cs +++ b/src/ApiService/ApiService/Functions/Tasks.cs @@ -2,29 +2,29 @@ using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Tasks { - private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Tasks(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { - _log = log; - _auth = auth; + public Tasks(IOnefuzzContext context) { _context = context; } [Function("Tasks")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), + [Authorize(Allow.User)] + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] + HttpRequestData req, + FunctionContext context) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req, context), + "DELETE" => Delete(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); @@ -73,7 +73,7 @@ private async Async.Task Get(HttpRequestData req) { } - private async Async.Task Post(HttpRequestData req) { + private async Async.Task Post(HttpRequestData req, FunctionContext context) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( @@ -82,10 +82,7 @@ private async Async.Task Post(HttpRequestData req) { "task create"); } - var userInfo = await _context.UserCredentials.ParseJwtToken(req); - if (!userInfo.IsOk) { - return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "task create"); - } + var userInfo = context.GetUserAuthInfo(); var create = request.OkV; var cfg = new TaskConfig( @@ -141,7 +138,7 @@ private async Async.Task Post(HttpRequestData req) { } } - var task = await _context.TaskOperations.Create(cfg, cfg.JobId, userInfo.OkV.UserInfo); + var task = await _context.TaskOperations.Create(cfg, cfg.JobId, userInfo.UserInfo); if (!task.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/Functions/Tool.cs b/src/ApiService/ApiService/Functions/Tool.cs index 1bf4a3f910..b4b7936843 100644 --- a/src/ApiService/ApiService/Functions/Tool.cs +++ b/src/ApiService/ApiService/Functions/Tool.cs @@ -1,19 +1,22 @@ using System.Net; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; public class Tools { private readonly IOnefuzzContext _context; - private readonly IEndpointAuthorization _auth; - public Tools(IEndpointAuthorization auth, IOnefuzzContext context) { + public Tools(IOnefuzzContext context) { _context = context; - _auth = auth; } - public async Async.Task GetResponse(HttpRequestData req) { + [Function("Tools")] + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) { + //Note: streaming response are not currently supported by in isolated functions // https://github.com/Azure/azure-functions-dotnet-worker/issues/958 var response = req.CreateResponse(HttpStatusCode.OK); @@ -23,9 +26,4 @@ public async Async.Task GetResponse(HttpRequestData req) { } return response; } - - - [Function("Tools")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) - => _auth.CallIfUser(req, GetResponse); } diff --git a/src/ApiService/ApiService/Functions/ValidateScriban.cs b/src/ApiService/ApiService/Functions/ValidateScriban.cs index 4e8b003354..90bd90b723 100644 --- a/src/ApiService/ApiService/Functions/ValidateScriban.cs +++ b/src/ApiService/ApiService/Functions/ValidateScriban.cs @@ -1,5 +1,6 @@ using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; namespace Microsoft.OneFuzz.Service.Functions; @@ -11,7 +12,11 @@ public ValidateScriban(ILogTracer log, IOnefuzzContext context) { _context = context; } - private async Async.Task Post(HttpRequestData req) { + [Function("ValidateScriban")] + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST")] + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "ValidateTemplate"); @@ -27,14 +32,5 @@ private async Async.Task Post(HttpRequestData req) { $"Template failed to render due to: `{e.Message}`" ); } - - } - - [Function("ValidateScriban")] - public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req) { - return req.Method switch { - "POST" => Post(req), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }; } } diff --git a/src/ApiService/ApiService/Functions/WebhookLogs.cs b/src/ApiService/ApiService/Functions/WebhookLogs.cs index 140d16784e..1899fa1623 100644 --- a/src/ApiService/ApiService/Functions/WebhookLogs.cs +++ b/src/ApiService/ApiService/Functions/WebhookLogs.cs @@ -3,28 +3,22 @@ namespace Microsoft.OneFuzz.Service.Functions; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; public class WebhookLogs { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public WebhookLogs(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public WebhookLogs(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("WebhookLogs")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/logs")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - - private async Async.Task Post(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/logs")] + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/Functions/WebhookPing.cs b/src/ApiService/ApiService/Functions/WebhookPing.cs index 623c0bed5c..e3e4574d85 100644 --- a/src/ApiService/ApiService/Functions/WebhookPing.cs +++ b/src/ApiService/ApiService/Functions/WebhookPing.cs @@ -3,28 +3,22 @@ namespace Microsoft.OneFuzz.Service.Functions; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.Auth; public class WebhookPing { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public WebhookPing(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public WebhookPing(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("WebhookPing")] - public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/ping")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "POST" => Post(r), - _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - - private async Async.Task Post(HttpRequestData req) { + [Authorize(Allow.User)] + public async Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route = "webhooks/ping")] + HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/Functions/Webhooks.cs b/src/ApiService/ApiService/Functions/Webhooks.cs index 1ac1d30cb7..538e49217b 100644 --- a/src/ApiService/ApiService/Functions/Webhooks.cs +++ b/src/ApiService/ApiService/Functions/Webhooks.cs @@ -3,31 +3,29 @@ namespace Microsoft.OneFuzz.Service.Functions; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; - +using Microsoft.OneFuzz.Service.Auth; public class Webhooks { private readonly ILogTracer _log; - private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public Webhooks(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + public Webhooks(ILogTracer log, IOnefuzzContext context) { _log = log; - _auth = auth; _context = context; } [Function("Webhooks")] + [Authorize(Allow.User)] public Async.Task Run( - [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", "PATCH")] HttpRequestData req) { - return _auth.CallIfUser(req, r => r.Method switch { - "GET" => Get(r), - "POST" => Post(r), - "DELETE" => Delete(r), - "PATCH" => Patch(r), + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE", "PATCH")] + HttpRequestData req) + => req.Method switch { + "GET" => Get(req), + "POST" => Post(req), + "DELETE" => Delete(req), + "PATCH" => Patch(req), _ => throw new InvalidOperationException("Unsupported HTTP method"), - }); - } - + }; private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); diff --git a/src/ApiService/ApiService/HttpClient.cs b/src/ApiService/ApiService/HttpClient.cs index 89de11da38..74d8007eb7 100644 --- a/src/ApiService/ApiService/HttpClient.cs +++ b/src/ApiService/ApiService/HttpClient.cs @@ -9,7 +9,7 @@ namespace Microsoft.OneFuzz.Service; public class Request { private readonly HttpClient _httpClient; - Func>? _auth; + private readonly Func>? _auth; public Request(HttpClient httpClient, Func>? auth = null) { _auth = auth; diff --git a/src/ApiService/ApiService/Program.cs b/src/ApiService/ApiService/Program.cs index a9a71d48ed..f3fe03dc66 100644 --- a/src/ApiService/ApiService/Program.cs +++ b/src/ApiService/ApiService/Program.cs @@ -1,11 +1,8 @@ -// to avoid collision with Task in model.cs -global using System; -global -using System.Collections.Generic; -global -using System.Linq; -global -using Async = System.Threading.Tasks; +global using System; +global using System.Collections.Generic; +global using System.Linq; +// to avoid collision with Task in model.cs +global using Async = System.Threading.Tasks; using System.Text.Json; using ApiService.OneFuzzLib.Orm; using Azure.Core.Serialization; @@ -99,7 +96,6 @@ public static async Async.Task Main() { .AddScoped() .AddScoped() .AddScoped() - .AddScoped() .AddScoped() .AddScoped() .AddScoped() @@ -140,6 +136,8 @@ public static async Async.Task Main() { .ConfigureFunctionsWorkerDefaults(builder => { builder.UseAzureAppConfiguration(); builder.UseMiddleware(); + builder.UseMiddleware(); + builder.UseMiddleware(); builder.AddApplicationInsights(options => { options.ConnectionString = $"InstrumentationKey={configuration.ApplicationInsightsInstrumentationKey}"; }); diff --git a/src/ApiService/ApiService/UserCredentials.cs b/src/ApiService/ApiService/UserCredentials.cs deleted file mode 100644 index 8072471951..0000000000 --- a/src/ApiService/ApiService/UserCredentials.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System.IdentityModel.Tokens.Jwt; -using System.Net.Http.Headers; -using System.Threading.Tasks; -using Microsoft.Azure.Functions.Worker.Http; -using Microsoft.IdentityModel.Tokens; - - -namespace Microsoft.OneFuzz.Service; - -public interface IUserCredentials { - public string? GetBearerToken(HttpRequestData req); - public string? GetAuthToken(HttpRequestData req); - public Task> ParseJwtToken(HttpRequestData req); -} - -public record UserAuthInfo(UserInfo UserInfo, List Roles); - -public class UserCredentials : IUserCredentials { - ILogTracer _log; - IConfigOperations _instanceConfig; - private JwtSecurityTokenHandler _tokenHandler; - - public UserCredentials(ILogTracer log, IConfigOperations instanceConfig) { - _log = log; - _instanceConfig = instanceConfig; - _tokenHandler = new JwtSecurityTokenHandler(); - } - - public string? GetBearerToken(HttpRequestData req) { - if (!req.Headers.TryGetValues("Authorization", out var authHeader) || authHeader.IsNullOrEmpty()) { - return null; - } else { - var auth = AuthenticationHeaderValue.Parse(authHeader.First()); - return auth.Scheme.ToLower() switch { - "bearer" => auth.Parameter, - _ => null, - }; - } - } - - public string? GetAuthToken(HttpRequestData req) { - var token = GetBearerToken(req); - if (token is not null) { - return token; - } else { - if (!req.Headers.TryGetValues("x-ms-token-aad-id-token", out var tokenHeader) || tokenHeader.IsNullOrEmpty()) { - return null; - } else { - return tokenHeader.First(); - } - } - } - - async Task> GetAllowedTenants() { - var r = await _instanceConfig.Fetch(); - var allowedAddTenantsQuery = - from t in r.AllowedAadTenants - select $"https://sts.windows.net/{t}/"; - - return OneFuzzResult.Ok(allowedAddTenantsQuery.ToArray()); - } - - public virtual async Task> ParseJwtToken(HttpRequestData req) { - - - var authToken = GetAuthToken(req); - if (authToken is null) { - return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" }); - } else { - var token = new System.IdentityModel.Tokens.Jwt.JwtSecurityToken(authToken); - var allowedTenants = await GetAllowedTenants(); - if (allowedTenants.IsOk) { - if (allowedTenants.OkV is not null && allowedTenants.OkV.Contains(token.Issuer)) { - var userAuthInfo = new UserAuthInfo(new UserInfo(null, null, null), new List()); - var userInfo = - token.Payload.Claims.Aggregate(userAuthInfo, (acc, claim) => { - switch (claim.Type) { - case "oid": - return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } }; - case "appid": - return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } }; - case "upn": - return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } }; - case "roles": - acc.Roles.Add(claim.Value); - return acc; - default: - return acc; - } - }); - return OneFuzzResult.Ok(userInfo); - } else { - var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!); - _log.Error($"issuer not from allowed tenant. issuer: {token.Issuer:Tag:Issuer} - tenants: {tenantsStr:Tag:Tenants}"); - return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, new[] { "unauthorized AAD issuer. If multi-tenant auth is failing, make sure to include all tenant_ids in the `allowed_aad_tenants` list in the instance_config. To see the current instance_config, run `onefuzz instance_config get`. " }); - } - } else { - _log.Error($"Failed to get allowed tenants due to {allowedTenants.ErrorV:Tag:Error}"); - return OneFuzzResult.Error(allowedTenants.ErrorV); - } - } - } -} diff --git a/src/ApiService/ApiService/onefuzzlib/Auth.cs b/src/ApiService/ApiService/onefuzzlib/Auth.cs index 2869f78c25..c2a6395254 100644 --- a/src/ApiService/ApiService/onefuzzlib/Auth.cs +++ b/src/ApiService/ApiService/onefuzzlib/Auth.cs @@ -2,7 +2,7 @@ using System.Diagnostics; using System.IO; -public class Auth { +public static class AuthHelpers { private static ProcessStartInfo SshKeyGenProcConfig(string tempFile) { diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 011e522310..cc88d1d924 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -1,36 +1,23 @@ -using System.Net; -using System.Net.Http; +using System.Net.Http; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.Graph; namespace Microsoft.OneFuzz.Service; -public interface IEndpointAuthorization { - - Async.Task CallIfAgent( - HttpRequestData req, - Func> method) - => CallIf(req, method, allowAgent: true); - - Async.Task CallIfUser( - HttpRequestData req, - Func> method) - => CallIf(req, method, allowUser: true); - - Async.Task CallIf( - HttpRequestData req, - Func> method, - bool allowUser = false, - bool allowAgent = false); +public record UserAuthInfo(UserInfo UserInfo, List Roles); - Async.Task CheckRequireAdmins(HttpRequestData req); +public interface IEndpointAuthorization { + Async.Task CheckRequireAdmins(UserAuthInfo authInfo); + Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo); + Async.Task CheckAccess(HttpRequestData req); } + public class EndpointAuthorization : IEndpointAuthorization { private readonly IOnefuzzContext _context; private readonly ILogTracer _log; private readonly GraphServiceClient _graphClient; - private static readonly HashSet AgentRoles = new HashSet { "UnmanagedNode", "ManagedNode" }; + private static readonly IReadOnlySet _agentRoles = new HashSet() { "UnmanagedNode", "ManagedNode" }; public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) { _context = context; @@ -38,58 +25,7 @@ public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServi _graphClient = graphClient; } - public virtual async Async.Task CallIf(HttpRequestData req, Func> method, bool allowUser = false, bool allowAgent = false) { - var tokenResult = await _context.UserCredentials.ParseJwtToken(req); - - if (!tokenResult.IsOk) { - return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized); - } - - var token = tokenResult.OkV.UserInfo; - - var (isAgent, reason) = await IsAgent(tokenResult.OkV); - - if (!isAgent) { - if (!allowUser) { - return await Reject(req, token, "endpoint not allowed for users"); - } - - var access = await CheckAccess(req); - if (!access.IsOk) { - return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized); - } - } - - - if (isAgent && !allowAgent) { - return await Reject(req, token, reason); - } - - return await method(req); - } - - - public async Async.Task Reject(HttpRequestData req, UserInfo token, String? reason = null) { - var body = await req.ReadAsStringAsync(); - _log.Error($"reject token. reason:{reason} url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}"); - - return await _context.RequestHandling.NotOk( - req, - Error.Create( - ErrorCode.UNAUTHORIZED, - reason ?? "Unrecognized agent" - ), - "token verification", - HttpStatusCode.Unauthorized - ); - } - - public async Async.Task CheckRequireAdmins(HttpRequestData req) { - var tokenResult = await _context.UserCredentials.ParseJwtToken(req); - if (!tokenResult.IsOk) { - return tokenResult.ErrorV; - } - + public async Async.Task CheckRequireAdmins(UserAuthInfo authInfo) { var config = await _context.ConfigOperations.Fetch(); if (config is null) { return Error.Create( @@ -97,7 +33,7 @@ public async Async.Task CheckRequireAdmins(HttpRequestData re "no instance configuration found "); } - return CheckRequireAdminsImpl(config, tokenResult.OkV.UserInfo); + return CheckRequireAdminsImpl(config, authInfo.UserInfo); } private static OneFuzzResultVoid CheckRequireAdminsImpl(InstanceConfig config, UserInfo userInfo) { @@ -177,9 +113,8 @@ private GroupMembershipChecker CreateGroupMembershipChecker(InstanceConfig confi return null; } - public async Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo) { - if (!AgentRoles.Overlaps(authInfo.Roles)) { + if (!_agentRoles.Overlaps(authInfo.Roles)) { return (false, "no agent role"); } diff --git a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs index 95a599452d..d877bfddbb 100644 --- a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs +++ b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs @@ -37,7 +37,6 @@ public interface IOnefuzzContext { IStorage Storage { get; } ITaskOperations TaskOperations { get; } ITaskEventOperations TaskEventOperations { get; } - IUserCredentials UserCredentials { get; } IVmOperations VmOperations { get; } IVmssOperations VmssOperations { get; } IWebhookMessageLogOperations WebhookMessageLogOperations { get; } @@ -77,7 +76,6 @@ public OnefuzzContext(IServiceProvider serviceProvider) { public IContainers Containers => _serviceProvider.GetRequiredService(); public IReports Reports => _serviceProvider.GetRequiredService(); public INotificationOperations NotificationOperations => _serviceProvider.GetRequiredService(); - public IUserCredentials UserCredentials => _serviceProvider.GetRequiredService(); public IReproOperations ReproOperations => _serviceProvider.GetRequiredService(); public IPoolOperations PoolOperations => _serviceProvider.GetRequiredService(); public IIpOperations IpOperations => _serviceProvider.GetRequiredService(); diff --git a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs index e5cba14e6c..8f192235ef 100644 --- a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs @@ -59,7 +59,17 @@ public async Async.Task GetOrCreate(Region region) { } _logTracer.Info($"creating proxy: region:{region:Tag:Region}"); - var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, new SecretValue(await Auth.BuildAuth(_logTracer)), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false); + var newProxy = new Proxy( + region, + Guid.NewGuid(), + DateTimeOffset.UtcNow, + VmState.Init, + new SecretValue(await AuthHelpers.BuildAuth(_logTracer)), + null, + null, + _context.ServiceConfiguration.OneFuzzVersion, + null, + false); var r = await Replace(newProxy); if (!r.IsOk) { diff --git a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs index 42fbe1b4c7..bd03824321 100644 --- a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs @@ -335,7 +335,7 @@ public async Task> Create(ReproConfig config, UserInfo user return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, "unable to find task"); } - var auth = await _context.SecretsOperations.StoreSecret(new SecretValue(await Auth.BuildAuth(_logTracer))); + var auth = await _context.SecretsOperations.StoreSecret(new SecretValue(await AuthHelpers.BuildAuth(_logTracer))); var vm = new Repro( VmId: Guid.NewGuid(), diff --git a/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs b/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs index f7a3c8aa1f..e6b663e87d 100644 --- a/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs +++ b/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs @@ -1,10 +1,6 @@ -using System.Net; -using IntegrationTests.Fakes; -using Microsoft.OneFuzz.Service; -using Microsoft.OneFuzz.Service.Functions; +using Microsoft.OneFuzz.Service; using Xunit; using Xunit.Abstractions; -using Async = System.Threading.Tasks; namespace IntegrationTests; @@ -23,31 +19,4 @@ public abstract class AgentCanScheduleTestsBase : FunctionTestBase { public AgentCanScheduleTestsBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentCanSchedule(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentCanSchedule(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task AgentAuthorization_IsAccepted() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentCanSchedule(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized - } } diff --git a/src/ApiService/IntegrationTests/AgentCommandsTests.cs b/src/ApiService/IntegrationTests/AgentCommandsTests.cs index b18b32e979..47b951fc15 100644 --- a/src/ApiService/IntegrationTests/AgentCommandsTests.cs +++ b/src/ApiService/IntegrationTests/AgentCommandsTests.cs @@ -26,33 +26,6 @@ public AgentCommandsTestsBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task AgentAuthorization_IsAccepted() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized - } - [Fact] public async Async.Task AgentCommand_GetsCommand() { var machineId = Guid.NewGuid(); @@ -69,8 +42,7 @@ await Context.InsertAll(new[] { }); var commandRequest = new NodeCommandGet(machineId); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentCommands(Logger, auth, Context); + var func = new AgentCommands(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/AgentEventsTests.cs b/src/ApiService/IntegrationTests/AgentEventsTests.cs index 00bc0a1740..a662b4890a 100644 --- a/src/ApiService/IntegrationTests/AgentEventsTests.cs +++ b/src/ApiService/IntegrationTests/AgentEventsTests.cs @@ -35,28 +35,9 @@ public AgentEventsTestsBase(ITestOutputHelper output, IStorage storage) readonly Guid _poolId = Guid.NewGuid(); readonly string _poolVersion = $"version-{Guid.NewGuid()}"; - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task WorkerEventMustHaveDoneOrRunningSet() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: Guid.NewGuid(), @@ -75,8 +56,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, @@ -103,8 +83,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, @@ -130,8 +109,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Scheduled, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, @@ -156,8 +134,7 @@ public async Async.Task WorkerRunning_ForMissingTask_ReturnsError() { await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); @@ -173,8 +150,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 0)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); @@ -191,8 +167,7 @@ await Context.InsertAll( new Task(_jobId, _taskId, TaskState.Running, Os.Linux, new TaskConfig(_jobId, null, new TaskDetails(TaskType.Coverage, 0)))); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); @@ -232,8 +207,7 @@ await Async.Task.WhenAll( public async Async.Task NodeStateUpdate_ForMissingNode_IgnoresEvent() { // nothing present in storage - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Init)); @@ -248,8 +222,7 @@ public async Async.Task NodeStateUpdate_CanTransitionFromInitToReady() { await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion, State: NodeState.Init)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Ready)); @@ -266,8 +239,7 @@ public async Async.Task NodeStateUpdate_BecomingFree_StopsNode_IfMarkedForReimag await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion, ReimageRequested: true)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Free)); @@ -295,8 +267,7 @@ public async Async.Task NodeStateUpdate_BecomingFree_StopsNode_IfMarkedForDeleti await Context.InsertAll( new Node(_poolName, _machineId, _poolId, _poolVersion, DeleteRequested: true)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentEvents(Logger, auth, Context); + var func = new AgentEvents(Logger, Context); var data = new NodeStateEnvelope( MachineId: _machineId, Event: new NodeStateUpdate(NodeState.Free)); diff --git a/src/ApiService/IntegrationTests/AgentRegistrationTests.cs b/src/ApiService/IntegrationTests/AgentRegistrationTests.cs index 618616bbdc..202a66628b 100644 --- a/src/ApiService/IntegrationTests/AgentRegistrationTests.cs +++ b/src/ApiService/IntegrationTests/AgentRegistrationTests.cs @@ -33,37 +33,9 @@ public AgentRegistrationTestsBase(ITestOutputHelper output, IStorage storage) private readonly ScalesetId _scalesetId = ScalesetId.Parse($"scaleset-{Guid.NewGuid()}"); private readonly PoolName _poolName = PoolName.Parse($"pool-{Guid.NewGuid()}"); - [Fact] - public async Async.Task Authorization_IsRequired() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task UserAuthorization_IsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task AgentAuthorization_IsAccepted() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("POST")); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to missing parameters, not Unauthorized - } - [Fact] public async Async.Task Get_UrlParameterRequired() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); var result = await func.Run(req); @@ -76,8 +48,7 @@ public async Async.Task Get_UrlParameterRequired() { [Fact] public async Async.Task Get_MissingNode() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); req.SetUrlParameter("machine_id", _machineId); @@ -95,8 +66,7 @@ public async Async.Task Get_MissingPool() { await Context.InsertAll( new Node(_poolName, _machineId, _poolId, "1.0.0")); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); req.SetUrlParameter("machine_id", _machineId); @@ -115,8 +85,7 @@ await Context.InsertAll( new Node(_poolName, _machineId, _poolId, "1.0.0"), new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("GET"); req.SetUrlParameter("machine_id", _machineId); @@ -135,8 +104,7 @@ public async Async.Task Post_SetsDefaultVersion_IfNotSupplied() { await Context.InsertAll( new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); req.SetUrlParameter("machine_id", _machineId); @@ -157,8 +125,7 @@ public async Async.Task Post_SetsCorrectVersion() { await Context.InsertAll( new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); req.SetUrlParameter("machine_id", _machineId); @@ -181,8 +148,7 @@ await Context.InsertAll( new Node(PoolName.Parse("another-pool"), _machineId, _poolId, "1.0.0"), new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); req.SetUrlParameter("machine_id", _machineId); @@ -205,8 +171,7 @@ public async Async.Task Post_ChecksRequiredParameters(string parameterToSkip) { await Context.InsertAll( new Pool(_poolName, _poolId, Os.Linux, false, Architecture.x86_64, PoolState.Init, null)); - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new AgentRegistration(Logger, auth, Context); + var func = new AgentRegistration(Logger, Context); var req = TestHttpRequestData.Empty("POST"); if (parameterToSkip != "machine_id") { diff --git a/src/ApiService/IntegrationTests/AuthTests.cs b/src/ApiService/IntegrationTests/AuthTests.cs index 5441c49b29..3632658fbc 100644 --- a/src/ApiService/IntegrationTests/AuthTests.cs +++ b/src/ApiService/IntegrationTests/AuthTests.cs @@ -14,7 +14,7 @@ public AuthTests(ITestOutputHelper output) { [Fact] public async System.Threading.Tasks.Task TestAuth() { - var auth = await Microsoft.OneFuzz.Service.Auth.BuildAuth(Logger); + var auth = await AuthHelpers.BuildAuth(Logger); auth.Should().NotBeNull(); auth.PrivateKey.StartsWith("-----BEGIN OPENSSH PRIVATE KEY-----").Should().BeTrue(); diff --git a/src/ApiService/IntegrationTests/ContainersTests.cs b/src/ApiService/IntegrationTests/ContainersTests.cs index 8b122a52a2..70339edf16 100644 --- a/src/ApiService/IntegrationTests/ContainersTests.cs +++ b/src/ApiService/IntegrationTests/ContainersTests.cs @@ -31,22 +31,6 @@ public abstract class ContainersTestBase : FunctionTestBase { public ContainersTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Theory] - [InlineData("GET")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task WithoutAuthorization_IsRejected(string method) { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - } - - [Fact] public async Async.Task CanDelete() { var containerName = Container.Parse("test"); @@ -55,8 +39,7 @@ public async Async.Task CanDelete() { var msg = TestHttpRequestData.FromJson("DELETE", new ContainerDelete(containerName)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -71,8 +54,7 @@ public async Async.Task CanPost_New() { var containerName = Container.Parse("test"); var msg = TestHttpRequestData.FromJson("POST", new ContainerCreate(containerName, meta)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -95,8 +77,7 @@ public async Async.Task CanPost_Existing() { var metadata = new Dictionary { { "some", "value" } }; var msg = TestHttpRequestData.FromJson("POST", new ContainerCreate(containerName, metadata)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -119,8 +100,7 @@ public async Async.Task Get_Existing() { var msg = TestHttpRequestData.FromJson("GET", new ContainerGet(containerName)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -134,8 +114,7 @@ public async Async.Task Get_Missing_Fails() { var container = Container.Parse("container"); var msg = TestHttpRequestData.FromJson("GET", new ContainerGet(container)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -149,8 +128,7 @@ public async Async.Task List_Existing() { var msg = TestHttpRequestData.Empty("GET"); // this means list all - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -188,8 +166,7 @@ public async Async.Task BadContainerNameProducesGoodErrorMessage() { // use anonymous type so we can send an invalid name var msg = TestHttpRequestData.FromJson("POST", new { Name = "AbCd" }); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ContainersFunction(Logger, auth, Context); + var func = new ContainersFunction(Logger, Context); var result = await func.Run(msg); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/DownloadTests.cs b/src/ApiService/IntegrationTests/DownloadTests.cs index 0feae4db95..1912071caf 100644 --- a/src/ApiService/IntegrationTests/DownloadTests.cs +++ b/src/ApiService/IntegrationTests/DownloadTests.cs @@ -26,26 +26,13 @@ public abstract class DownloadTestBase : FunctionTestBase { public DownloadTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Fact] - public async Async.Task Download_WithoutAuthorization_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new Download(auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - } - [Fact] public async Async.Task Download_WithoutContainer_IsRejected() { var req = TestHttpRequestData.Empty("GET"); var url = new UriBuilder(req.Url) { Query = "filename=xxx" }.Uri; req.SetUrl(url); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Download(auth, Context); + var func = new Download(Context); var result = await func.Run(req); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -59,8 +46,7 @@ public async Async.Task Download_WithoutFilename_IsRejected() { var url = new UriBuilder(req.Url) { Query = "container=xxx" }.Uri; req.SetUrl(url); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Download(auth, Context); + var func = new Download(Context); var result = await func.Run(req); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -81,8 +67,7 @@ public async Async.Task Download_RedirectsToResult_WithLocationHeader() { var url = new UriBuilder(req.Url) { Query = "container=xxx&filename=yyy" }.Uri; req.SetUrl(url); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Download(auth, Context); + var func = new Download(Context); var result = await func.Run(req); Assert.Equal(HttpStatusCode.Found, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/EndpointAuthTests.cs b/src/ApiService/IntegrationTests/EndpointAuthTests.cs new file mode 100644 index 0000000000..6a1355362e --- /dev/null +++ b/src/ApiService/IntegrationTests/EndpointAuthTests.cs @@ -0,0 +1,93 @@ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.OneFuzz.Service; +using Xunit; +using Xunit.Abstractions; +using Async = System.Threading.Tasks; + +namespace IntegrationTests; + + +[Trait("Category", "Live")] +public class AzureStorageEndpointAuthTest : EndpointAuthTestBase { + public AzureStorageEndpointAuthTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteEndpointAuthTest : EndpointAuthTestBase { + public AzuriteEndpointAuthTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class EndpointAuthTestBase : FunctionTestBase { + public EndpointAuthTestBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { + } + + private readonly Guid _applicationId = Guid.NewGuid(); + private readonly Guid _userObjectId = Guid.NewGuid(); + + private Task CheckUserAdmin() { + var userAuthInfo = new UserAuthInfo( + new UserInfo(ApplicationId: _applicationId, ObjectId: _userObjectId, "upn"), + new List()); + + var auth = new EndpointAuthorization(Context, Logger, null!); + + return auth.CheckRequireAdmins(userAuthInfo); + } + + [Fact] + public async Async.Task IfRequireAdminPrivilegesIsEnabled_UserIsNotPermitted() { + // config specifies that a different user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + RequireAdminPrivileges = true, + }); + + var result = await CheckUserAdmin(); + Assert.False(result.IsOk, "should not be admin"); + } + + [Fact] + public async Async.Task IfRequireAdminPrivilegesIsDisabled_UserIsPermitted() { + // disable requiring admin privileges + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + RequireAdminPrivileges = false, + }); + + var result = await CheckUserAdmin(); + Assert.True(result.IsOk, "should be admin"); + } + + [Fact] + public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser() { + var otherUserObjectId = Guid.NewGuid(); + + // config specifies that a different user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + Admins = new[] { otherUserObjectId }, + RequireAdminPrivileges = true, + }); + + var result = await CheckUserAdmin(); + Assert.False(result.IsOk, "should not be admin"); + } + + [Fact] + public async Async.Task UserCanBeAdmin() { + // config specifies that user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + Admins = new[] { _userObjectId }, + RequireAdminPrivileges = true, + }); + + var result = await CheckUserAdmin(); + Assert.True(result.IsOk, "should be admin"); + } +} diff --git a/src/ApiService/IntegrationTests/EventsTests.cs b/src/ApiService/IntegrationTests/EventsTests.cs index 7f9d4b11f0..b5b4e03a87 100644 --- a/src/ApiService/IntegrationTests/EventsTests.cs +++ b/src/ApiService/IntegrationTests/EventsTests.cs @@ -46,8 +46,7 @@ public async Async.Task BlobIsCreatedAndIsAccessible() { ping.Should().NotBeNull(); var msg = TestHttpRequestData.FromJson("GET", new EventsGet(ping.PingId)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new EventsFunction(Logger, auth, Context); + var func = new EventsFunction(Logger, Context); var result = await func.Run(msg); result.StatusCode.Should().Be(HttpStatusCode.OK); diff --git a/src/ApiService/IntegrationTests/Fakes/TestContext.cs b/src/ApiService/IntegrationTests/Fakes/TestContext.cs index f8fa1ca5fc..41d2219924 100644 --- a/src/ApiService/IntegrationTests/Fakes/TestContext.cs +++ b/src/ApiService/IntegrationTests/Fakes/TestContext.cs @@ -41,7 +41,6 @@ public TestContext(IHttpClientFactory httpClientFactory, ILogTracer logTracer, I ScalesetOperations = new ScalesetOperations(logTracer, cache, this); ReproOperations = new ReproOperations(logTracer, this); Reports = new Reports(logTracer, Containers); - UserCredentials = new UserCredentials(logTracer, ConfigOperations); NotificationOperations = new NotificationOperations(logTracer, this); FeatureManagerSnapshot = new TestFeatureManagerSnapshot(); @@ -79,7 +78,6 @@ public Async.Task InsertAll(params EntityBase[] objs) public ICreds Creds { get; } public IContainers Containers { get; set; } public IQueue Queue { get; } - public IUserCredentials UserCredentials { get; set; } public IRequestHandling RequestHandling { get; } diff --git a/src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs b/src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs deleted file mode 100644 index 0d20c42bf6..0000000000 --- a/src/ApiService/IntegrationTests/Fakes/TestEndpointAuthorization.cs +++ /dev/null @@ -1,47 +0,0 @@ - -using System; -using System.Net; -using System.Threading.Tasks; -using Microsoft.Azure.Functions.Worker.Http; -using Microsoft.OneFuzz.Service; - -namespace IntegrationTests.Fakes; - -public enum RequestType { - NoAuthorization, - User, - Agent, -} - -sealed class TestEndpointAuthorization : EndpointAuthorization { - private readonly RequestType _type; - private readonly IOnefuzzContext _context; - - public TestEndpointAuthorization(RequestType type, ILogTracer log, IOnefuzzContext context) - : base(context, log, null! /* not needed for test */) { - _type = type; - _context = context; - } - - public override Task CallIf( - HttpRequestData req, - Func> method, - bool allowUser = false, - bool allowAgent = false) { - - if ((_type == RequestType.User && allowUser) || - (_type == RequestType.Agent && allowAgent)) { - return method(req); - } - - return _context.RequestHandling.NotOk( - req, - Error.Create( - ErrorCode.UNAUTHORIZED, - "Unrecognized agent" - ), - "token verification", - HttpStatusCode.Unauthorized - ); - } -} diff --git a/src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs b/src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs new file mode 100644 index 0000000000..f27f863d29 --- /dev/null +++ b/src/ApiService/IntegrationTests/Fakes/TestFunctionContext.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using Microsoft.Azure.Functions.Worker; +using Microsoft.OneFuzz.Service; +using Microsoft.OneFuzz.Service.Auth; + +namespace IntegrationTests.Fakes; + +public class TestFunctionContext : FunctionContext { + public override IDictionary Items { get; set; } = new Dictionary(); + + public void SetUserAuthInfo(UserInfo userInfo) + => this.SetUserAuthInfo(new UserAuthInfo(userInfo, new List())); + + // everything else unsupported + + public override string InvocationId => throw new NotSupportedException(); + + public override string FunctionId => throw new NotSupportedException(); + + public override TraceContext TraceContext => throw new NotSupportedException(); + + public override BindingContext BindingContext => throw new NotSupportedException(); + + public override RetryContext RetryContext => throw new NotSupportedException(); + + public override IServiceProvider InstanceServices { get => throw new NotSupportedException(); set => throw new NotImplementedException(); } + + public override FunctionDefinition FunctionDefinition => throw new NotSupportedException(); + + public override IInvocationFeatures Features => throw new NotSupportedException(); +} diff --git a/src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs b/src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs deleted file mode 100644 index 4ee641880f..0000000000 --- a/src/ApiService/IntegrationTests/Fakes/TestUserCredentials.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System.Collections.Generic; -using System.Threading.Tasks; -using Microsoft.Azure.Functions.Worker.Http; -using Microsoft.OneFuzz.Service; - -using Async = System.Threading.Tasks; - -namespace IntegrationTests.Fakes; - -sealed class TestUserCredentials : UserCredentials { - - private readonly OneFuzzResult _tokenResult; - - public TestUserCredentials(ILogTracer log, IConfigOperations instanceConfig, OneFuzzResult tokenResult) - : base(log, instanceConfig) { - _tokenResult = tokenResult.IsOk ? OneFuzzResult.Ok(new UserAuthInfo(tokenResult.OkV, new List())) : OneFuzzResult.Error(tokenResult.ErrorV); - } - - public override Task> ParseJwtToken(HttpRequestData req) => Async.Task.FromResult(_tokenResult); -} diff --git a/src/ApiService/IntegrationTests/InfoTests.cs b/src/ApiService/IntegrationTests/InfoTests.cs index 1bd7ef339e..b2ccafb225 100644 --- a/src/ApiService/IntegrationTests/InfoTests.cs +++ b/src/ApiService/IntegrationTests/InfoTests.cs @@ -26,27 +26,8 @@ public InfoTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } [Fact] - public async Async.Task TestInfo_WithoutAuthorization_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new Info(auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task TestInfo_WithAgentCredentials_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); - var func = new Info(auth, Context); - - var result = await func.Run(TestHttpRequestData.Empty("GET")); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - - [Fact] - public async Async.Task TestInfo_WithUserCredentials_Succeeds() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Info(auth, Context); + public async Async.Task TestInfo_Succeeds() { + var func = new Info(Context); var result = await func.Run(TestHttpRequestData.Empty("GET")); Assert.Equal(HttpStatusCode.OK, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs b/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs index 43e90f547b..09387877d7 100644 --- a/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs +++ b/src/ApiService/IntegrationTests/JinjaToScribanMigrationTests.cs @@ -24,8 +24,6 @@ protected JinjaToScribanMigrationTestBase(ITestOutputHelper output, IStorage sto [Fact] public async Async.Task Dry_Run_Does_Not_Make_Changes() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); var r = await Context.NotificationOperations.Create( @@ -39,8 +37,7 @@ public async Async.Task Dry_Run_Does_Not_Make_Changes() { var notificationBefore = r.OkV!; var adoTemplateBefore = (notificationBefore.Config as AdoTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(DryRun: true); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -62,8 +59,6 @@ public async Async.Task Dry_Run_Does_Not_Make_Changes() { [Fact] public async Async.Task Migration_Happens_When_Not_Dry_run() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); var r = await Context.NotificationOperations.Create( @@ -77,8 +72,7 @@ public async Async.Task Migration_Happens_When_Not_Dry_run() { var notificationBefore = r.OkV!; var adoTemplateBefore = (notificationBefore.Config as AdoTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -99,8 +93,6 @@ public async Async.Task Migration_Happens_When_Not_Dry_run() { [Fact] public async Async.Task OptionalFieldsAreSupported() { - await ConfigureAuth(); - var adoTemplate = new AdoTemplate( new Uri("http://example.com"), new SecretData(new SecretValue("some secret")), @@ -126,8 +118,6 @@ public async Async.Task OptionalFieldsAreSupported() { [Fact] public async Async.Task All_ADO_Fields_Are_Migrated() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var adoTemplate = new AdoTemplate( new Uri("http://example.com"), @@ -161,8 +151,7 @@ public async Async.Task All_ADO_Fields_Are_Migrated() { var notificationBefore = r.OkV!; var adoTemplateBefore = (notificationBefore.Config as AdoTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -196,8 +185,6 @@ public async Async.Task All_ADO_Fields_Are_Migrated() { [Fact] public async Async.Task All_Github_Fields_Are_Migrated() { - await ConfigureAuth(); - var githubTemplate = MigratableGithubTemplate(); var notificationContainer = Container.Parse("abc123"); @@ -213,8 +200,7 @@ public async Async.Task All_Github_Fields_Are_Migrated() { var notificationBefore = r.OkV!; var githubTemplateBefore = (notificationBefore.Config as GithubIssuesTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -252,8 +238,6 @@ public async Async.Task All_Github_Fields_Are_Migrated() { [Fact] public async Async.Task Teams_Template_Not_Migrated() { - await ConfigureAuth(); - var teamsTemplate = GetTeamsTemplate(); var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); @@ -268,8 +252,7 @@ public async Async.Task Teams_Template_Not_Migrated() { var notificationBefore = r.OkV!; var teamsTemplateBefore = (notificationBefore.Config as TeamsTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -288,8 +271,6 @@ public async Async.Task Teams_Template_Not_Migrated() { // Multiple notification configs can be migrated [Fact] public async Async.Task Can_Migrate_Multiple_Notification_Configs() { - await ConfigureAuth(); - var notificationContainer = Container.Parse("abc123"); var _ = await Context.Containers.CreateContainer(notificationContainer, StorageType.Corpus, null); @@ -323,8 +304,7 @@ public async Async.Task Can_Migrate_Multiple_Notification_Configs() { var githubNotificationBefore = r.OkV!; var githubTemplateBefore = (githubNotificationBefore.Config as GithubIssuesTemplate)!; - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); + var func = new JinjaToScribanMigrationFunction(Logger, Context); var req = new JinjaToScribanMigrationPost(); var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); @@ -351,33 +331,12 @@ public async Async.Task Can_Migrate_Multiple_Notification_Configs() { githubTemplateAfter.Organization.Should().BeEquivalentTo(JinjaTemplateAdapter.AdaptForScriban(githubTemplateBefore.Organization)); } - [Fact] - public async Async.Task Access_WithoutAuthorization_IsRejected() { - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new JinjaToScribanMigrationFunction(Logger, auth, Context); - var req = new JinjaToScribanMigrationPost(); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); - - result.StatusCode.Should().Be(System.Net.HttpStatusCode.BadRequest); - } - [Fact] public async Async.Task Do_Not_Enforce_Key_Exists_In_Strict_Validation() { (await JinjaTemplateAdapter.IsValidScribanNotificationTemplate(Context, Logger, ValidScribanAdoTemplate())) .Should().BeTrue(); } - private async Async.Task ConfigureAuth() { - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } } // needed for admin check - ); - - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - } - private static AdoTemplate MigratableAdoTemplate() { return new AdoTemplate( new Uri("http://example.com"), diff --git a/src/ApiService/IntegrationTests/JobsTests.cs b/src/ApiService/IntegrationTests/JobsTests.cs index 41b0b4b1c5..365df56d88 100644 --- a/src/ApiService/IntegrationTests/JobsTests.cs +++ b/src/ApiService/IntegrationTests/JobsTests.cs @@ -29,27 +29,12 @@ public JobsTestBase(ITestOutputHelper output, IStorage storage) private readonly Guid _jobId = Guid.NewGuid(); private readonly JobConfig _config = new("project", "name", "build", 1000, null); - [Theory] - [InlineData("POST")] - [InlineData("GET")] - [InlineData("DELETE")] - public async Async.Task Access_WithoutAuthorization_IsRejected(string method) { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); - var func = new Jobs(auth, Context, Logger); - - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - } - [Fact] public async Async.Task Delete_NonExistentJob_Fails() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId)), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); @@ -61,10 +46,10 @@ public async Async.Task Delete_ExistingJob_SetsStoppingState() { await Context.InsertAll( new Job(_jobId, JobState.Enabled, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId)), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -80,10 +65,10 @@ public async Async.Task Delete_ExistingStoppedJob_DoesNotSetStoppingState() { await Context.InsertAll( new Job(_jobId, JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("DELETE", new JobGet(_jobId)), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -100,10 +85,10 @@ public async Async.Task Get_CanFindSpecificJob() { await Context.InsertAll( new Job(_jobId, JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); - var result = await func.Run(TestHttpRequestData.FromJson("GET", new JobSearch(JobId: _jobId))); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", new JobSearch(JobId: _jobId)), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -119,11 +104,11 @@ await Context.InsertAll( new Job(Guid.NewGuid(), JobState.Enabled, _config), new Job(Guid.NewGuid(), JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); var req = new JobSearch(State: new List { JobState.Enabled }); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -138,11 +123,11 @@ await Context.InsertAll( new Job(Guid.NewGuid(), JobState.Enabled, _config), new Job(Guid.NewGuid(), JobState.Stopped, _config)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); var req = new JobSearch(State: new List { JobState.Enabled, JobState.Stopping }); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var response = BodyAs(result); @@ -153,14 +138,12 @@ await Context.InsertAll( [Fact] public async Async.Task Post_CreatesJob_AndContainer() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Jobs(auth, Context, Logger); + var func = new Jobs(Context, Logger); // need user credentials to put into the job object - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var result = await func.Run(TestHttpRequestData.FromJson("POST", _config)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); + var result = await func.Run(TestHttpRequestData.FromJson("POST", _config), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var job = Assert.Single(await Context.JobOperations.SearchAll().ToListAsync()); diff --git a/src/ApiService/IntegrationTests/NodeTests.cs b/src/ApiService/IntegrationTests/NodeTests.cs index 8902a01ab3..8900fb9d6f 100644 --- a/src/ApiService/IntegrationTests/NodeTests.cs +++ b/src/ApiService/IntegrationTests/NodeTests.cs @@ -35,10 +35,8 @@ public NodeTestBase(ITestOutputHelper output, IStorage storage) [Fact] public async Async.Task Search_SpecificNode_NotFound_ReturnsNotFound() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new NodeSearch(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -48,10 +46,8 @@ public async Async.Task Search_SpecificNode_Found_ReturnsOk() { await Context.InsertAll( new Node(_poolName, _machineId, null, _version)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new NodeSearch(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -62,10 +58,8 @@ await Context.InsertAll( [Fact] public async Async.Task Search_MultipleNodes_CanFindNone() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new NodeSearch(); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Equal("[]", BodyAsString(result)); @@ -80,8 +74,7 @@ await Context.InsertAll( var req = new NodeSearch(PoolName: _poolName); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -99,8 +92,7 @@ await Context.InsertAll( var req = new NodeSearch(ScalesetId: _scalesetId); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -109,7 +101,6 @@ await Context.InsertAll( Assert.Equal(2, deserialized.Length); } - [Fact] public async Async.Task Search_MultipleNodes_ByState() { await Context.InsertAll( @@ -119,8 +110,7 @@ await Context.InsertAll( var req = new NodeSearch(State: new List { NodeState.Busy }); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -139,8 +129,7 @@ await Context.InsertAll( var req = new NodeSearch(State: new List { NodeState.Free, NodeState.Busy }); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new NodeFunction(Logger, auth, Context); + var func = new NodeFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -148,147 +137,4 @@ await Context.InsertAll( var deserialized = BodyAs(result); Assert.Equal(3, deserialized.Length); } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task RequiresAdmin(string method) { - // config must be found - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - RequireAdminPrivileges = true - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - Assert.Contains("pool modification disabled", err.Detail); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task RequiresAdmin_CanBeDisabled(string method) { - // disable requiring admin privileges - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - RequireAdminPrivileges = false - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - - // we will fail with BadRequest but due to not being able to find the Node, - // not because of UNAUTHORIZED - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - Assert.Equal(ErrorCode.UNABLE_TO_FIND.ToString(), BodyAs(result).Title); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task UserCanBeAdmin(string method) { - var userObjectId = Guid.NewGuid(); - - // config specifies that user is admin - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - Admins = new[] { userObjectId } - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - - // we will fail with BadRequest but due to not being able to find the Node, - // not because of UNAUTHORIZED - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - Assert.Equal(ErrorCode.UNABLE_TO_FIND.ToString(), BodyAs(result).Title); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser(string method) { - var userObjectId = Guid.NewGuid(); - var otherObjectId = Guid.NewGuid(); - - // config specifies that a different user is admin - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - Admins = new[] { otherObjectId }, RequireAdminPrivileges = true - }); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - - var err = BodyAs(result); - Assert.Equal(ErrorCode.UNAUTHORIZED.ToString(), err.Title); - Assert.Contains("not authorized to manage instance", err.Detail); - } - - [Theory] - [InlineData("PATCH")] - [InlineData("POST")] - [InlineData("DELETE")] - public async Async.Task CanPerformOperation(string method) { - // disable requiring admin privileges - await Context.InsertAll( - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { - RequireAdminPrivileges = false - }, - new Node(_poolName, _machineId, null, _version)); - - // must be a user to auth - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - // all of these operations use NodeGet - var req = new NodeGet(MachineId: _machineId); - var func = new NodeFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson(method, req)); - Assert.Equal(HttpStatusCode.OK, result.StatusCode); - } } diff --git a/src/ApiService/IntegrationTests/PoolTests.cs b/src/ApiService/IntegrationTests/PoolTests.cs index f861053ffb..437120291a 100644 --- a/src/ApiService/IntegrationTests/PoolTests.cs +++ b/src/ApiService/IntegrationTests/PoolTests.cs @@ -31,26 +31,10 @@ public PoolTestBase(ITestOutputHelper output, IStorage storage) private readonly PoolName _poolName = PoolName.Parse("pool-" + Guid.NewGuid()); - [Theory] - [InlineData("POST", RequestType.Agent)] - [InlineData("POST", RequestType.NoAuthorization)] - [InlineData("GET", RequestType.Agent)] - [InlineData("GET", RequestType.NoAuthorization)] - [InlineData("DELETE", RequestType.Agent)] - [InlineData("DELETE", RequestType.NoAuthorization)] - public async Async.Task UserAuthorization_IsRequired(string method, RequestType authType) { - var auth = new TestEndpointAuthorization(authType, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task Search_SpecificPool_ById_NotFound_ReturnsBadRequest() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(PoolId: _poolId); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -66,10 +50,8 @@ await Context.InsertAll( // use test class to override instance ID Context.Containers = new TestContainers(Logger, Context.Storage, Context.ServiceConfiguration); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(PoolId: _poolId); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -79,10 +61,8 @@ await Context.InsertAll( [Fact] public async Async.Task Search_SpecificPool_ByName_NotFound_ReturnsBadRequest() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(Name: _poolName); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); } @@ -98,10 +78,8 @@ await Context.InsertAll( // use test class to override instance ID Context.Containers = new TestContainers(Logger, Context.Storage, Context.ServiceConfiguration); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(Name: _poolName); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -111,10 +89,8 @@ await Context.InsertAll( [Fact] public async Async.Task Search_SpecificPool_ByState_NotFound_ReturnsEmptyResult() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new PoolSearch(State: new List { PoolState.Init }); - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -127,9 +103,7 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - var func = new PoolFunction(Logger, auth, Context); + var func = new PoolFunction(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", new PoolSearch())); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -144,15 +118,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolStop(Name: _poolName, Now: false); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("DELETE", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var pool = await Context.PoolOperations.GetByName(_poolName); @@ -166,15 +134,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Halt, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolStop(Name: _poolName, Now: false); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("DELETE", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var pool = await Context.PoolOperations.GetByName(_poolName); @@ -188,15 +150,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolStop(Name: _poolName, Now: true); - var result = await func.Run(TestHttpRequestData.FromJson("DELETE", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("DELETE", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var pool = await Context.PoolOperations.GetByName(_poolName); @@ -209,18 +165,12 @@ public async Async.Task Post_CreatesNewPool() { await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }); // needed for admin check - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - // need to override instance id Context.Containers = new TestContainers(Logger, Context.Storage, Context.ServiceConfiguration); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolCreate(Name: _poolName, Os.Linux, Architecture.x86_64, true); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); // should get a pool back @@ -240,15 +190,9 @@ await Context.InsertAll( new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { _userObjectId } }, // needed for admin check new Pool(_poolName, _poolId, Os.Linux, true, Architecture.x86_64, PoolState.Running, null)); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: _userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new PoolFunction(Logger, auth, Context); - + var func = new PoolFunction(Context); var req = new PoolCreate(Name: _poolName, Os.Linux, Architecture.x86_64, true); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // should get an error back diff --git a/src/ApiService/IntegrationTests/ReproVmssTests.cs b/src/ApiService/IntegrationTests/ReproVmssTests.cs index 29168df2be..83b274f4cc 100644 --- a/src/ApiService/IntegrationTests/ReproVmssTests.cs +++ b/src/ApiService/IntegrationTests/ReproVmssTests.cs @@ -28,27 +28,12 @@ public abstract class ReproVmssTestBase : FunctionTestBase { public ReproVmssTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - - [Theory] - [InlineData("POST", RequestType.Agent)] - [InlineData("POST", RequestType.NoAuthorization)] - [InlineData("GET", RequestType.Agent)] - [InlineData("GET", RequestType.NoAuthorization)] - [InlineData("DELETE", RequestType.Agent)] - [InlineData("DELETE", RequestType.NoAuthorization)] - public async Async.Task UserAuthorization_IsRequired(string method, RequestType authType) { - var auth = new TestEndpointAuthorization(authType, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task GetMissingVmFails() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: Guid.NewGuid()); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); // TODO: should this be 404? Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); @@ -57,10 +42,10 @@ public async Async.Task GetMissingVmFails() { [Fact] public async Async.Task GetAvailableVMsCanReturnEmpty() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: null); // this means "all available" - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Empty(BodyAs(result)); } @@ -77,10 +62,10 @@ await Context.InsertAll( Auth: new SecretValue(new Authentication("", "", "")), Os: Os.Linux)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: null); // this means "all available" - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var repro = Assert.Single(BodyAs(result)); Assert.Equal(vmId, repro.VmId); @@ -101,10 +86,10 @@ await Context.InsertAll( Auth: new SecretAddress(secretUri), Os: Os.Linux)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: vmId); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Equal(vmId, BodyAs(result).VmId); } @@ -128,37 +113,23 @@ await Context.InsertAll( Os: Os.Linux, State: VmState.Stopped)); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproGet(VmId: null); // this means "all available" - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); Assert.Empty(BodyAs(result)); } - [Fact] - public async Async.Task CannotCreateVMWithoutCredentials() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - var func = new ReproVmss(Logger, auth, Context); - var req = new ReproCreate(Container.Parse("abcd"), "/", 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); - Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); - var err = BodyAs(result); - Assert.Equal(new ProblemDetails(400, "INVALID_REQUEST", "unable to find authorization token"), err); - } - [Fact] public async Async.Task CannotCreateVMForMissingReport() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - // setup fake user - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproCreate(Container.Parse("abcd"), "/", 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal(new ProblemDetails(400, "UNABLE_TO_FIND", "unable to find report"), err); @@ -209,15 +180,13 @@ public async Async.Task CannotCreateVMForMissingReport() { public async Async.Task CannotCreateVMForMissingTask() { var (container, filename) = await CreateContainerWithReport(Guid.NewGuid(), Guid.NewGuid()); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - // setup fake user - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproCreate(container, filename, 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal(new ProblemDetails(400, "INVALID_REQUEST", "unable to find task"), err); @@ -241,15 +210,13 @@ await Context.InsertAll( null, new TaskDetails(TaskType.LibfuzzerFuzz, 12345)))); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - // setup fake user - var userInfo = new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(Guid.NewGuid(), Guid.NewGuid(), "upn")); - var func = new ReproVmss(Logger, auth, Context); + var func = new ReproVmss(Logger, Context); var req = new ReproCreate(container, filename, 12345); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); var repro = BodyAs(result); Assert.Equal(taskId, repro.TaskId); diff --git a/src/ApiService/IntegrationTests/ScalesetTests.cs b/src/ApiService/IntegrationTests/ScalesetTests.cs index 6ee51271ae..441ed9dc8f 100644 --- a/src/ApiService/IntegrationTests/ScalesetTests.cs +++ b/src/ApiService/IntegrationTests/ScalesetTests.cs @@ -25,28 +25,10 @@ public abstract class ScalesetTestBase : FunctionTestBase { public ScalesetTestBase(ITestOutputHelper output, IStorage storage) : base(output, storage) { } - [Theory] - [InlineData("POST", RequestType.Agent)] - [InlineData("POST", RequestType.NoAuthorization)] - [InlineData("PATCH", RequestType.Agent)] - [InlineData("PATCH", RequestType.NoAuthorization)] - [InlineData("GET", RequestType.Agent)] - [InlineData("GET", RequestType.NoAuthorization)] - [InlineData("DELETE", RequestType.Agent)] - [InlineData("DELETE", RequestType.NoAuthorization)] - public async Async.Task UserAuthorization_IsRequired(string method, RequestType authType) { - var auth = new TestEndpointAuthorization(authType, Logger, Context); - var func = new ScalesetFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.Empty(method)); - Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); - } - [Fact] public async Async.Task Search_SpecificScaleset_ReturnsErrorIfNoneFound() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new ScalesetSearch(ScalesetId: ScalesetId.Parse(Guid.NewGuid().ToString())); - var func = new ScalesetFunction(Logger, auth, Context); + var func = new ScalesetFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -56,10 +38,8 @@ public async Async.Task Search_SpecificScaleset_ReturnsErrorIfNoneFound() { [Fact] public async Async.Task Search_AllScalesets_ReturnsEmptyIfNoneFound() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new ScalesetSearch(); - var func = new ScalesetFunction(Logger, auth, Context); + var func = new ScalesetFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -81,10 +61,8 @@ await Context.InsertAll( new Node(poolName, Guid.NewGuid(), poolId, "version", ScalesetId: scalesetId) ); - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new ScalesetSearch(ScalesetId: scalesetId); - var func = new ScalesetFunction(Logger, auth, Context); + var func = new ScalesetFunction(Logger, Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -95,17 +73,10 @@ await Context.InsertAll( [Fact] public async Async.Task Create_Scaleset() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userObjectId = Guid.NewGuid(); - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - var poolName = PoolName.Parse("mypool"); await Context.InsertAll( - // user must be admin - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { userObjectId } }, + // config must exist + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!), // pool must exist and be managed new Pool(poolName, Guid.NewGuid(), Os.Linux, Managed: true, Architecture.x86_64, PoolState.Running)); @@ -118,8 +89,8 @@ await Context.InsertAll( SpotInstances: false, Tags: new Dictionary()); - var func = new ScalesetFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var func = new ScalesetFunction(Logger, Context); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -132,17 +103,6 @@ await Context.InsertAll( [Fact] public async Async.Task Create_Scaleset_Under_NonExistent_Pool_Provides_Error() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - - // override the found user credentials - var userObjectId = Guid.NewGuid(); - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); - - await Context.InsertAll( - // user must be admin - new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { Admins = new[] { userObjectId } }); - var poolName = PoolName.Parse("nosuchpool"); // pool not created var req = new ScalesetCreate( @@ -154,8 +114,8 @@ await Context.InsertAll( SpotInstances: false, Tags: new Dictionary()); - var func = new ScalesetFunction(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var func = new ScalesetFunction(Logger, Context); + var result = await func.Admin(TestHttpRequestData.FromJson("POST", req)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); diff --git a/src/ApiService/IntegrationTests/TasksTests.cs b/src/ApiService/IntegrationTests/TasksTests.cs index 847a712153..139c3784ee 100644 --- a/src/ApiService/IntegrationTests/TasksTests.cs +++ b/src/ApiService/IntegrationTests/TasksTests.cs @@ -30,8 +30,7 @@ public TasksTestBase(ITestOutputHelper output, IStorage storage) [Fact] public async Async.Task SpecifyingVmIsNotPermitted() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Tasks(Logger, auth, Context); + var func = new Tasks(Context); var req = new TaskCreate( Guid.NewGuid(), @@ -43,7 +42,8 @@ public async Async.Task SpecifyingVmIsNotPermitted() { var serialized = (JsonObject?)JsonSerializer.SerializeToNode(req, EntityConverter.GetJsonSerializerOptions()); serialized!["vm"] = new JsonObject { { "fake", 1 } }; var testData = new TestHttpRequestData("POST", new BinaryData(JsonSerializer.SerializeToUtf8Bytes(serialized, EntityConverter.GetJsonSerializerOptions()))); - var result = await func.Run(testData); + var ctx = new TestFunctionContext(); + var result = await func.Run(testData, ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal("Unexpected property: \"vm\"", err.Detail); @@ -51,12 +51,11 @@ public async Async.Task SpecifyingVmIsNotPermitted() { [Fact] public async Async.Task PoolIsRequired() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Tasks(Logger, auth, Context); + var func = new Tasks(Context); - // override the found user credentials - need these to check for admin - var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); - Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + // override the found user credentials - need these to store user + var ctx = new TestFunctionContext(); + ctx.SetUserAuthInfo(new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn")); var req = new TaskCreate( Guid.NewGuid(), @@ -64,7 +63,7 @@ public async Async.Task PoolIsRequired() { new TaskDetails(TaskType.DotnetCoverage, 100), null! /* <- here */); - var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + var result = await func.Run(TestHttpRequestData.FromJson("POST", req), ctx); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); var err = BodyAs(result); Assert.Equal("The Pool field is required.", err.Detail); @@ -72,15 +71,14 @@ public async Async.Task PoolIsRequired() { [Fact] public async Async.Task CanSearchWithJobIdAndEmptyListOfStates() { - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var req = new TaskSearch( JobId: Guid.NewGuid(), TaskId: null, State: new List()); - var func = new Tasks(Logger, auth, Context); - var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + var func = new Tasks(Context); + var ctx = new TestFunctionContext(); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req), ctx); Assert.Equal(HttpStatusCode.OK, result.StatusCode); } } diff --git a/src/ApiService/IntegrationTests/ToolsTests.cs b/src/ApiService/IntegrationTests/ToolsTests.cs index 9bb3c3ee0b..9d86db7cfe 100644 --- a/src/ApiService/IntegrationTests/ToolsTests.cs +++ b/src/ApiService/IntegrationTests/ToolsTests.cs @@ -47,8 +47,7 @@ public async Async.Task CanDownload() { var r = await toolsContainerClient.UploadBlobAsync(path.ToString(), BinaryData.FromString(content.ToString())); Assert.False(r.GetRawResponse().IsError); } - var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); - var func = new Tools(auth, Context); + var func = new Tools(Context); var result = await func.Run(TestHttpRequestData.FromJson("GET", "")); Assert.Equal(HttpStatusCode.OK, result.StatusCode); diff --git a/src/ApiService/Tests/AuthTests.cs b/src/ApiService/Tests/AuthTests.cs new file mode 100644 index 0000000000..6b45ddc87b --- /dev/null +++ b/src/ApiService/Tests/AuthTests.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Castle.Core.Internal; +using Microsoft.Azure.Functions.Worker; +using Microsoft.OneFuzz.Service.Auth; +using Xunit; + +namespace Tests; + +public class AuthTests { + + public static IEnumerable AllFunctionEntryPoints() { + var asm = typeof(AuthorizeAttribute).Assembly; + foreach (var type in asm.GetTypes()) { + if (type.Namespace == "ApiService.TestHooks" + || type.Name == "TestHooks") { + // skip test hooks + continue; + } + + foreach (var method in type.GetMethods()) { + if (method.GetCustomAttribute() is not null) { + // it's a function entrypoint + yield return new object[] { type, method }; + } + } + } + } + + + [Theory] + [MemberData(nameof(AllFunctionEntryPoints))] + public void AllFunctionsHaveAuthAttributes(Type type, MethodInfo methodInfo) { + var trigger = methodInfo.GetParameters().First().GetCustomAttribute(); + if (trigger is null) { + return; // not an HTTP function + } + + // built-in auth level should be anonymous - we are implementing our own authorization + Assert.Equal(AuthorizationLevel.Anonymous, trigger.AuthLevel); + + if (type.Name == "Config" && methodInfo.Name == "Run") { + // this method alone is allowed to be anonymous + Assert.Null(methodInfo.GetAttribute()); + return; + } + + // authorize attribute can be on class or method + var authAttribute = methodInfo.GetAttribute() + ?? type.GetAttribute(); + Assert.NotNull(authAttribute); + + // naming convention: check that Agent* functions have Allow.Agent, and none other + var functionAttribute = methodInfo.GetCustomAttribute()!; + if (functionAttribute.Name.StartsWith("Agent")) { + Assert.Equal(Allow.Agent, authAttribute.Allow); + } else { + Assert.NotEqual(Allow.Agent, authAttribute.Allow); + } + + // naming convention: all *_Admin functions should be ALlow.Admin + // (some that aren't _Admin also require it) + if (functionAttribute.Name.EndsWith("_Admin")) { + Assert.Equal(Allow.Admin, authAttribute.Allow); + } + + // make sure other methods that _aren't_ function entry points don't have it, + // because it won't do anything there, and having it present would be misleading + foreach (var otherMethod in type.GetMethods()) { + if (otherMethod.GetCustomAttribute() is null) { + Assert.True( + otherMethod.GetCustomAttribute() is null, + "non-[Function] methods must not have [Authorize]"); + } + } + } +}