diff --git a/airbyte-config/persistence/src/main/java/io/airbyte/config/persistence/ConfigRepository.java b/airbyte-config/persistence/src/main/java/io/airbyte/config/persistence/ConfigRepository.java index 5f60c4ebe611..c29b84667b56 100644 --- a/airbyte-config/persistence/src/main/java/io/airbyte/config/persistence/ConfigRepository.java +++ b/airbyte-config/persistence/src/main/java/io/airbyte/config/persistence/ConfigRepository.java @@ -44,6 +44,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.stream.Stream; @@ -230,6 +231,17 @@ public SourceOAuthParameter getSourceOAuthParams(final UUID SourceOAuthParameter return persistence.getConfig(ConfigSchema.SOURCE_OAUTH_PARAM, SourceOAuthParameterId.toString(), SourceOAuthParameter.class); } + public Optional getSourceOAuthParamByDefinitionIdOptional(final UUID workspaceId, final UUID sourceDefinitionId) + throws JsonValidationException, IOException { + for (final SourceOAuthParameter oAuthParameter : listSourceOAuthParam()) { + if (sourceDefinitionId.equals(oAuthParameter.getSourceDefinitionId()) && + Objects.equals(workspaceId, oAuthParameter.getWorkspaceId())) { + return Optional.of(oAuthParameter); + } + } + return Optional.empty(); + } + public void writeSourceOAuthParam(final SourceOAuthParameter SourceOAuthParameter) throws JsonValidationException, IOException { persistence.writeConfig(ConfigSchema.SOURCE_OAUTH_PARAM, SourceOAuthParameter.getOauthParameterId().toString(), SourceOAuthParameter); } @@ -243,6 +255,18 @@ public DestinationOAuthParameter getDestinationOAuthParams(final UUID destinatio return persistence.getConfig(ConfigSchema.DESTINATION_OAUTH_PARAM, destinationOAuthParameterId.toString(), DestinationOAuthParameter.class); } + public Optional getDestinationOAuthParamByDefinitionIdOptional(final UUID workspaceId, + final UUID destinationDefinitionId) + throws JsonValidationException, IOException { + for (final DestinationOAuthParameter oAuthParameter : listDestinationOAuthParam()) { + if (destinationDefinitionId.equals(oAuthParameter.getDestinationDefinitionId()) && + Objects.equals(workspaceId, oAuthParameter.getWorkspaceId())) { + return Optional.of(oAuthParameter); + } + } + return Optional.empty(); + } + public void writeDestinationOAuthParam(final DestinationOAuthParameter destinationOAuthParameter) throws JsonValidationException, IOException { persistence.writeConfig(ConfigSchema.DESTINATION_OAUTH_PARAM, destinationOAuthParameter.getOauthParameterId().toString(), destinationOAuthParameter); diff --git a/airbyte-server/src/main/java/io/airbyte/server/handlers/OAuthHandler.java b/airbyte-server/src/main/java/io/airbyte/server/handlers/OAuthHandler.java index c8e3bb155e18..28d97d04659c 100644 --- a/airbyte-server/src/main/java/io/airbyte/server/handlers/OAuthHandler.java +++ b/airbyte-server/src/main/java/io/airbyte/server/handlers/OAuthHandler.java @@ -95,23 +95,25 @@ public Map completeDestinationOAuth(CompleteDestinationOAuthRequ oauthDestinationRequestBody.getRedirectUrl()); } + public void setSourceInstancewideOauthParams(SetInstancewideSourceOauthParamsRequestBody requestBody) throws JsonValidationException, IOException { + final SourceOAuthParameter param = configRepository + .getSourceOAuthParamByDefinitionIdOptional(null, requestBody.getSourceDefinitionId()) + .orElseGet(() -> new SourceOAuthParameter().withOauthParameterId(UUID.randomUUID())) + .withConfiguration(Jsons.jsonNode(requestBody.getParams())) + .withSourceDefinitionId(requestBody.getSourceDefinitionId()); + configRepository.writeSourceOAuthParam(param); + } + public void setDestinationInstancewideOauthParams(SetInstancewideDestinationOauthParamsRequestBody requestBody) throws JsonValidationException, IOException { - DestinationOAuthParameter param = new DestinationOAuthParameter() - .withOauthParameterId(UUID.randomUUID()) + final DestinationOAuthParameter param = configRepository + .getDestinationOAuthParamByDefinitionIdOptional(null, requestBody.getDestinationDefinitionId()) + .orElseGet(() -> new DestinationOAuthParameter().withOauthParameterId(UUID.randomUUID())) .withConfiguration(Jsons.jsonNode(requestBody.getParams())) .withDestinationDefinitionId(requestBody.getDestinationDefinitionId()); configRepository.writeDestinationOAuthParam(param); } - public void setSourceInstancewideOauthParams(SetInstancewideSourceOauthParamsRequestBody requestBody) throws JsonValidationException, IOException { - SourceOAuthParameter param = new SourceOAuthParameter() - .withOauthParameterId(UUID.randomUUID()) - .withConfiguration(Jsons.jsonNode(requestBody.getParams())) - .withSourceDefinitionId(requestBody.getSourceDefinitionId()); - configRepository.writeSourceOAuthParam(param); - } - private OAuthFlowImplementation getSourceOAuthFlowImplementation(UUID sourceDefinitionId) throws JsonValidationException, ConfigNotFoundException, IOException { final StandardSourceDefinition standardSourceDefinition = configRepository diff --git a/airbyte-server/src/test/java/io/airbyte/server/handlers/OAuthHandlerTest.java b/airbyte-server/src/test/java/io/airbyte/server/handlers/OAuthHandlerTest.java index 34f883c3ff74..9cb62ce33685 100644 --- a/airbyte-server/src/test/java/io/airbyte/server/handlers/OAuthHandlerTest.java +++ b/airbyte-server/src/test/java/io/airbyte/server/handlers/OAuthHandlerTest.java @@ -25,6 +25,7 @@ package io.airbyte.server.handlers; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; import io.airbyte.api.model.SetInstancewideDestinationOauthParamsRequestBody; import io.airbyte.api.model.SetInstancewideSourceOauthParamsRequestBody; @@ -35,7 +36,9 @@ import io.airbyte.validation.json.JsonValidationException; import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -53,6 +56,58 @@ public void init() { handler = new OAuthHandler(configRepository); } + @Test + void setSourceInstancewideOauthParams() throws JsonValidationException, IOException { + UUID sourceDefId = UUID.randomUUID(); + Map params = new HashMap<>(); + params.put("client_id", "123"); + params.put("client_secret", "hunter2"); + + SetInstancewideSourceOauthParamsRequestBody actualRequest = new SetInstancewideSourceOauthParamsRequestBody() + .sourceDefinitionId(sourceDefId) + .params(params); + + handler.setSourceInstancewideOauthParams(actualRequest); + + ArgumentCaptor argument = ArgumentCaptor.forClass(SourceOAuthParameter.class); + Mockito.verify(configRepository).writeSourceOAuthParam(argument.capture()); + assertEquals(Jsons.jsonNode(params), argument.getValue().getConfiguration()); + assertEquals(sourceDefId, argument.getValue().getSourceDefinitionId()); + } + + @Test + void resetSourceInstancewideOauthParams() throws JsonValidationException, IOException { + UUID sourceDefId = UUID.randomUUID(); + Map firstParams = new HashMap<>(); + firstParams.put("client_id", "123"); + firstParams.put("client_secret", "hunter2"); + SetInstancewideSourceOauthParamsRequestBody firstRequest = new SetInstancewideSourceOauthParamsRequestBody() + .sourceDefinitionId(sourceDefId) + .params(firstParams); + handler.setSourceInstancewideOauthParams(firstRequest); + + final UUID oauthParameterId = UUID.randomUUID(); + when(configRepository.getSourceOAuthParamByDefinitionIdOptional(null, sourceDefId)) + .thenReturn(Optional.of(new SourceOAuthParameter().withOauthParameterId(oauthParameterId))); + + Map secondParams = new HashMap<>(); + secondParams.put("client_id", "456"); + secondParams.put("client_secret", "hunter3"); + SetInstancewideSourceOauthParamsRequestBody secondRequest = new SetInstancewideSourceOauthParamsRequestBody() + .sourceDefinitionId(sourceDefId) + .params(secondParams); + handler.setSourceInstancewideOauthParams(secondRequest); + + ArgumentCaptor argument = ArgumentCaptor.forClass(SourceOAuthParameter.class); + Mockito.verify(configRepository, Mockito.times(2)).writeSourceOAuthParam(argument.capture()); + List capturedValues = argument.getAllValues(); + assertEquals(Jsons.jsonNode(firstParams), capturedValues.get(0).getConfiguration()); + assertEquals(Jsons.jsonNode(secondParams), capturedValues.get(1).getConfiguration()); + assertEquals(sourceDefId, capturedValues.get(0).getSourceDefinitionId()); + assertEquals(sourceDefId, capturedValues.get(1).getSourceDefinitionId()); + assertEquals(oauthParameterId, capturedValues.get(1).getOauthParameterId()); + } + @Test void setDestinationInstancewideOauthParams() throws JsonValidationException, IOException { UUID destinationDefId = UUID.randomUUID(); @@ -73,22 +128,36 @@ void setDestinationInstancewideOauthParams() throws JsonValidationException, IOE } @Test - void setSourceInstancewideOauthParams() throws JsonValidationException, IOException { - UUID sourceDefId = UUID.randomUUID(); - Map params = new HashMap<>(); - params.put("client_id", "123"); - params.put("client_secret", "hunter2"); + void resetDestinationInstancewideOauthParams() throws JsonValidationException, IOException { + UUID destinationDefId = UUID.randomUUID(); + Map firstParams = new HashMap<>(); + firstParams.put("client_id", "123"); + firstParams.put("client_secret", "hunter2"); + SetInstancewideDestinationOauthParamsRequestBody firstRequest = new SetInstancewideDestinationOauthParamsRequestBody() + .destinationDefinitionId(destinationDefId) + .params(firstParams); + handler.setDestinationInstancewideOauthParams(firstRequest); - SetInstancewideSourceOauthParamsRequestBody actualRequest = new SetInstancewideSourceOauthParamsRequestBody() - .sourceDefinitionId(sourceDefId) - .params(params); + final UUID oauthParameterId = UUID.randomUUID(); + when(configRepository.getDestinationOAuthParamByDefinitionIdOptional(null, destinationDefId)) + .thenReturn(Optional.of(new DestinationOAuthParameter().withOauthParameterId(oauthParameterId))); - handler.setSourceInstancewideOauthParams(actualRequest); + Map secondParams = new HashMap<>(); + secondParams.put("client_id", "456"); + secondParams.put("client_secret", "hunter3"); + SetInstancewideDestinationOauthParamsRequestBody secondRequest = new SetInstancewideDestinationOauthParamsRequestBody() + .destinationDefinitionId(destinationDefId) + .params(secondParams); + handler.setDestinationInstancewideOauthParams(secondRequest); - ArgumentCaptor argument = ArgumentCaptor.forClass(SourceOAuthParameter.class); - Mockito.verify(configRepository).writeSourceOAuthParam(argument.capture()); - assertEquals(Jsons.jsonNode(params), argument.getValue().getConfiguration()); - assertEquals(sourceDefId, argument.getValue().getSourceDefinitionId()); + ArgumentCaptor argument = ArgumentCaptor.forClass(DestinationOAuthParameter.class); + Mockito.verify(configRepository, Mockito.times(2)).writeDestinationOAuthParam(argument.capture()); + List capturedValues = argument.getAllValues(); + assertEquals(Jsons.jsonNode(firstParams), capturedValues.get(0).getConfiguration()); + assertEquals(Jsons.jsonNode(secondParams), capturedValues.get(1).getConfiguration()); + assertEquals(destinationDefId, capturedValues.get(0).getDestinationDefinitionId()); + assertEquals(destinationDefId, capturedValues.get(1).getDestinationDefinitionId()); + assertEquals(oauthParameterId, capturedValues.get(1).getOauthParameterId()); } }