Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Commit

Permalink
Store authentication info in keyvault (#3127)
Browse files Browse the repository at this point in the history
* Store authentication info in keyvault

* fix tests

* fix tests

* fix test

* fix build

* test fix

* more fix

* format

* fix test

* fix test

* build

* cleanup

* build fix

* test fix

* catch exception when secret does not exist

* more cleanup

* fix tests

* cleanup

* address comments

* more null check
  • Loading branch information
chkeita authored Jun 6, 2023
1 parent 9aa2519 commit b44cff5
Show file tree
Hide file tree
Showing 27 changed files with 463 additions and 179 deletions.
26 changes: 21 additions & 5 deletions src/ApiService/ApiService/Functions/ReproVmss.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,17 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
if (vm == null) {
return await _context.RequestHandling.NotOk(req, Error.Create(ErrorCode.INVALID_REQUEST, "no such VM"), $"{request.OkV.VmId}");
}
var auth = await _context.SecretsOperations.GetSecretValue<Authentication>(vm.Auth);

if (auth == null) {
return await _context.RequestHandling.NotOk(req, Error.Create(ErrorCode.INVALID_REQUEST, "no auth info for the VM"), $"{request.OkV.VmId}");
}
var response = req.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(vm);
await response.WriteAsJsonAsync(ReproVmResponse.FromRepro(vm, auth));
return response;
}

var vms = _context.ReproOperations.SearchStates(VmStateHelper.Available).Select(vm => vm with { Auth = null });
var vms = _context.ReproOperations.SearchStates(VmStateHelper.Available);
var response2 = req.CreateResponse(HttpStatusCode.OK);
await response2.WriteAsJsonAsync(vms);
return response2;
Expand Down Expand Up @@ -83,7 +87,15 @@ private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
"repro_vm create");
}

// we’d like to track the usage of this feature;
var auth = await _context.SecretsOperations.GetSecretValue<Authentication>(vm.OkV.Auth);
if (auth is null) {
return await _context.RequestHandling.NotOk(
req,
Error.Create(ErrorCode.INVALID_REQUEST, "unable to find auth"),
"repro_vm create");
}

// we’d like to track the usage of this feature;
// anonymize the user ID so we can distinguish multiple requests
{
var data = userInfo.OkV.UserInfo.ToString(); // rely on record ToString
Expand All @@ -92,7 +104,8 @@ private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
}

var response = req.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(vm.OkV);

await response.WriteAsJsonAsync(ReproVmResponse.FromRepro(vm.OkV, auth));
return response;
}

Expand Down Expand Up @@ -127,8 +140,11 @@ private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
_log.WithHttpStatus(r.ErrorV).Error($"Failed to replace repro {updatedRepro.VmId:Tag:VmId}");
}

if (vm.Auth != null) {
await _context.SecretsOperations.DeleteSecret(vm.Auth);
}
var response = req.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(updatedRepro);
await response.WriteAsJsonAsync(ReproVmResponse.FromRepro(vm, new Authentication("", "", "")));
return response;
}
}
18 changes: 13 additions & 5 deletions src/ApiService/ApiService/Functions/Scaleset.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private async Task<HttpResponseData> Post(HttpRequestData req) {
ScalesetId: Service.Scaleset.GenerateNewScalesetId(create.PoolName),
State: ScalesetState.Init,
NeedsConfigUpdate: false,
Auth: await Auth.BuildAuth(_log),
Auth: new SecretValue<Authentication>(await Auth.BuildAuth(_log)),
PoolName: create.PoolName,
VmSku: create.VmSku,
Image: image,
Expand Down Expand Up @@ -161,7 +161,8 @@ private async Task<HttpResponseData> Post(HttpRequestData req) {
}

// auth not included on create results, only GET with include_auth set
var response = ScalesetResponse.ForScaleset(scaleset, includeAuth: false);

var response = ScalesetResponse.ForScaleset(scaleset, null);
return await RequestHandling.Ok(req, response);
}

Expand Down Expand Up @@ -195,7 +196,7 @@ private async Task<HttpResponseData> Patch(HttpRequestData req) {
scaleset = await _context.ScalesetOperations.SetSize(scaleset, size);
}

var response = ScalesetResponse.ForScaleset(scaleset, includeAuth: false);
var response = ScalesetResponse.ForScaleset(scaleset, null);
return await RequestHandling.Ok(req, response);
}

Expand All @@ -214,15 +215,22 @@ private async Task<HttpResponseData> Get(HttpRequestData req) {

var scaleset = scalesetResult.OkV;

var response = ScalesetResponse.ForScaleset(scaleset, includeAuth: search.IncludeAuth);
Authentication? auth;
auth = scaleset.Auth == null
? null
: search.IncludeAuth
? await _context.SecretsOperations.GetSecretValue<Authentication>(scaleset.Auth)
: null;

var response = ScalesetResponse.ForScaleset(scaleset, auth);
response = response with { Nodes = await _context.ScalesetOperations.GetNodes(scaleset) };
return await RequestHandling.Ok(req, response);
}

var states = search.State ?? Enumerable.Empty<ScalesetState>();
var scalesets = await _context.ScalesetOperations.SearchStates(states).ToListAsync();
// don't return auths during list actions, only 'get'
var result = scalesets.Select(ss => ScalesetResponse.ForScaleset(ss, includeAuth: false));
var result = scalesets.Select(ss => ScalesetResponse.ForScaleset(ss));
return await RequestHandling.Ok(req, result);
}
}
4 changes: 3 additions & 1 deletion src/ApiService/ApiService/Functions/Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
_context.NodeTasksOperations.GetNodeAssignments(taskId).ToListAsync().AsTask(),
_context.TaskEventOperations.GetSummary(taskId).ToListAsync().AsTask());

var auth = task.Auth == null ? null : await _context.SecretsOperations.GetSecretValue(task.Auth);

var result = new TaskSearchResult(
JobId: task.JobId,
TaskId: task.TaskId,
State: task.State,
Os: task.Os,
Config: task.Config,
Error: task.Error,
Auth: task.Auth,
Auth: auth,
Heartbeat: task.Heartbeat,
EndTime: task.EndTime,
UserInfo: task.UserInfo,
Expand Down
45 changes: 37 additions & 8 deletions src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public record Proxy
[RowKey] Guid ProxyId,
DateTimeOffset? CreatedTimestamp,
VmState State,
Authentication Auth,
ISecret<Authentication> Auth,
string? Ip,
Error? Error,
string Version,
Expand Down Expand Up @@ -282,7 +282,7 @@ public record Task(
Os Os,
TaskConfig Config,
Error? Error = null,
Authentication? Auth = null,
ISecret<Authentication>? Auth = null,
DateTimeOffset? Heartbeat = null,
DateTimeOffset? EndTime = null,
UserInfo? UserInfo = null) : StatefulEntityBase<TaskState>(State) {
Expand Down Expand Up @@ -422,7 +422,7 @@ public partial record Scaleset(
bool EphemeralOsDisks,
bool NeedsConfigUpdate,
Dictionary<string, string> Tags,
Authentication? Auth = null,
ISecret<Authentication>? Auth = null,
Error? Error = null,
Guid? ClientId = null,
Guid? ClientObjectId = null
Expand Down Expand Up @@ -718,7 +718,7 @@ public record Repro(
[PartitionKey][RowKey] Guid VmId,
Guid TaskId,
ReproConfig Config,
Authentication? Auth,
ISecret<Authentication> Auth,
Os Os,
VmState State = VmState.Init,
Error? Error = null,
Expand Down Expand Up @@ -788,15 +788,23 @@ public record Vm(
Region Region,
string Sku,
ImageReference Image,
Authentication Auth,
ISecret<Authentication> Auth,
Nsg? Nsg,
IDictionary<string, string>? Tags
) {
public string Name { get; } = Name.Length > 40 ? throw new ArgumentOutOfRangeException("VM name too long") : Name;
};


public interface ISecret {
[JsonIgnore]
bool IsHIddden { get; }
[JsonIgnore]
Uri? Uri { get; }
string? GetValue();
}
[JsonConverter(typeof(ISecretConverterFactory))]
public interface ISecret<T> { }
public interface ISecret<T> : ISecret { }

public class ISecretConverterFactory : JsonConverterFactory {
public override bool CanConvert(Type typeToConvert) {
Expand Down Expand Up @@ -841,9 +849,30 @@ public override void Write(Utf8JsonWriter writer, ISecret<T> value, JsonSerializ



public record SecretValue<T>(T Value) : ISecret<T>;
public record SecretValue<T>(T Value) : ISecret<T> {
[JsonIgnore]
public bool IsHIddden => false;
[JsonIgnore]
public Uri? Uri => null;

public string? GetValue() {
if (Value is string secretString) {
return secretString.Trim();
}

return JsonSerializer.Serialize(Value, EntityConverter.GetJsonSerializerOptions());
}
}

public record SecretAddress<T>(Uri Url) : ISecret<T> {
[JsonIgnore]
public Uri? Uri => Url;
[JsonIgnore]
public bool IsHIddden => true;
public string? GetValue() => null;

public record SecretAddress<T>(Uri Url) : ISecret<T>;

}

public record SecretData<T>(ISecret<T> Secret) {
}
Expand Down
34 changes: 32 additions & 2 deletions src/ApiService/ApiService/OneFuzzTypes/Responses.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ public record ScalesetResponse(
Dictionary<string, string> Tags,
List<ScalesetNodeState>? Nodes
) : BaseResponse() {
public static ScalesetResponse ForScaleset(Scaleset s, bool includeAuth)
public static ScalesetResponse ForScaleset(Scaleset s, Authentication? auth = null)
=> new(
PoolName: s.PoolName,
ScalesetId: s.ScalesetId,
State: s.State,
Auth: includeAuth ? s.Auth : null,
Auth: auth,
VmSku: s.VmSku,
Image: s.Image,
Region: s.Region,
Expand Down Expand Up @@ -220,3 +220,33 @@ public record NotificationTestResponse(
bool Success,
string? Error = null
) : BaseResponse();


public record ReproVmResponse(
Guid VmId,
Guid TaskId,
ReproConfig Config,
Authentication? Auth,
Os Os,
VmState State = VmState.Init,
Error? Error = null,
string? Ip = null,
DateTimeOffset? EndTime = null,
UserInfo? UserInfo = null
) : BaseResponse() {

public static ReproVmResponse FromRepro(Repro repro, Authentication? auth) {
return new ReproVmResponse(
VmId: repro.VmId,
TaskId: repro.TaskId,
Config: repro.Config,
Auth: auth,
Os: repro.Os,
State: repro.State,
Error: repro.Error,
Ip: repro.Ip,
EndTime: repro.EndTime,
UserInfo: repro.UserInfo
);
}
}
16 changes: 4 additions & 12 deletions src/ApiService/ApiService/TestHooks/TestHooks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,6 @@ public async Task<HttpResponseData> Info([HttpTrigger(AuthorizationLevel.Anonymo
}


[Function("GetKeyvaultAddress")]
public async Task<HttpResponseData> GetKeyVaultAddress([HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "testhooks/secrets/keyvaultaddress")] HttpRequestData req) {
_log.Info($"Getting keyvault address");
var addr = _secretOps.GetKeyvaultAddress();
var resp = req.CreateResponse(HttpStatusCode.OK);
await resp.WriteAsJsonAsync(addr);
return resp;
}

[Function("SaveToKeyvault")]
public async Task<HttpResponseData> SaveToKeyvault([HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "testhooks/secrets/keyvault")] HttpRequestData req) {
Expand All @@ -60,10 +52,10 @@ public async Task<HttpResponseData> SaveToKeyvault([HttpTrigger(AuthorizationLev
return req.CreateResponse(HttpStatusCode.BadRequest);
} else {
_log.Info($"Saving secret data in the keyvault");
var r = await _secretOps.SaveToKeyvault(secretData);
var addr = _secretOps.GetKeyvaultAddress();
var r = await _secretOps.StoreSecretData(secretData);

var resp = req.CreateResponse(HttpStatusCode.OK);
await resp.WriteAsJsonAsync(addr);
await resp.WriteAsJsonAsync((r.Secret as SecretAddress<string>)?.Url);
return resp;
}
}
Expand All @@ -79,7 +71,7 @@ from cs in queryComponents
select new KeyValuePair<string, string>(Uri.UnescapeDataString(cs.Substring(0, i)), Uri.UnescapeDataString(cs.Substring(i + 1)));

var qs = new Dictionary<string, string>(q);
var d = await _secretOps.GetSecretStringValue(new SecretData<string>(new SecretValue<string>(qs["SecretName"])));
var d = await _secretOps.GetSecretValue(new SecretValue<string>(qs["SecretName"]));

var resp = req.CreateResponse(HttpStatusCode.OK);
await resp.WriteAsJsonAsync(d);
Expand Down
6 changes: 5 additions & 1 deletion src/ApiService/ApiService/onefuzzlib/Extension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,11 @@ public static VMExtensionWrapper GenevaExtension(AzureLocation region) {
var sep = pool.Os == Os.Windows ? "\r\n" : "\n";

if (pool.Os == Os.Windows && scaleSet.Auth is not null) {
var sshKey = scaleSet.Auth.PublicKey.Trim();
var auth = await _context.SecretsOperations.GetSecretValue<Authentication>(scaleSet.Auth);
if (auth is null) {
throw new Exception($"unable to retrieve auth: {scaleSet.Auth}");
}
var sshKey = auth.PublicKey.Trim();
var sshPath = "$env:ProgramData/ssh/administrators_authorized_keys";
commands.Add($"Set-Content -Path {sshPath} -Value \"{sshKey}\"");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ private async Async.Task<NotificationTemplate> HideSecrets(NotificationTemplate

switch (notificationTemplate) {
case AdoTemplate adoTemplate:
var hiddenAuthToken = await _context.SecretsOperations.SaveToKeyvault(adoTemplate.AuthToken);
var hiddenAuthToken = await _context.SecretsOperations.StoreSecretData(adoTemplate.AuthToken);
return adoTemplate with { AuthToken = hiddenAuthToken };
case GithubIssuesTemplate githubIssuesTemplate:
var hiddenAuth = await _context.SecretsOperations.SaveToKeyvault(githubIssuesTemplate.Auth);
var hiddenAuth = await _context.SecretsOperations.StoreSecretData(githubIssuesTemplate.Auth);
return githubIssuesTemplate with { Auth = hiddenAuth };
case TeamsTemplate teamsTemplate:
var hiddenUrl = await _context.SecretsOperations.SaveToKeyvault(teamsTemplate.Url);
var hiddenUrl = await _context.SecretsOperations.StoreSecretData(teamsTemplate.Url);
return teamsTemplate with { Url = hiddenUrl };
default:
throw new ArgumentOutOfRangeException(nameof(notificationTemplate));
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async Async.Task<Proxy> GetOrCreate(Region region) {
}

_logTracer.Info($"creating proxy: region:{region:Tag:Region}");
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, await Auth.BuildAuth(_logTracer), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false);
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, new SecretValue<Authentication>(await Auth.BuildAuth(_logTracer)), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false);

var r = await Replace(newProxy);
if (!r.IsOk) {
Expand Down
16 changes: 10 additions & 6 deletions src/ApiService/ApiService/onefuzzlib/ReproOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ await _context.Creds.GetBaseRegion(),
);
}

if (repro.Auth == null) {
throw new Exception("missing auth");
}

return new Vm(
repro.VmId.ToString(),
vmConfig.Region,
Expand Down Expand Up @@ -260,12 +256,18 @@ public async Task<OneFuzzResultVoid> BuildReproScript(Repro repro) {
}

var files = new Dictionary<string, string>();
var auth = await _context.SecretsOperations.GetSecretValue(repro.Auth);

if (auth == null) {
return OneFuzzResultVoid.Error(ErrorCode.VM_CREATE_FAILED, "unable to fetch auth secret");
}

switch (task.Os) {
case Os.Windows:
var sshPath = "$env:ProgramData/ssh/administrators_authorized_keys";
var cmds = new List<string>()
{
$"Set-Content -Path {sshPath} -Value \"{repro.Auth.PublicKey}\"",
$"Set-Content -Path {sshPath} -Value \"{auth.PublicKey}\"",
". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
"Set-SetSSHACL",
$"while (1) {{ cdb -server tcp:port=1337 -c \"g\" setup\\{task.Config.Task.TargetExe} {report?.InputBlob?.Name} }}"
Expand Down Expand Up @@ -333,12 +335,14 @@ public async Task<OneFuzzResult<Repro>> Create(ReproConfig config, UserInfo user
return OneFuzzResult<Repro>.Error(ErrorCode.INVALID_REQUEST, "unable to find task");
}

var auth = await _context.SecretsOperations.StoreSecret(new SecretValue<Authentication>(await Auth.BuildAuth(_logTracer)));

var vm = new Repro(
VmId: Guid.NewGuid(),
Config: config,
TaskId: task.TaskId,
Os: task.Os,
Auth: await Auth.BuildAuth(_logTracer),
Auth: new SecretAddress<Authentication>(auth),
EndTime: DateTimeOffset.UtcNow + TimeSpan.FromHours(config.Duration),
UserInfo: userInfo);

Expand Down
Loading

0 comments on commit b44cff5

Please sign in to comment.