Skip to content

Commit

Permalink
Development: Fix LTI on multi node systems (#9085)
Browse files Browse the repository at this point in the history
  • Loading branch information
Strohgelaender authored Aug 4, 2024
1 parent e9f6878 commit 8cc18bb
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,18 @@ public class SecurityConfiguration {

private final ProfileService profileService;

private final Optional<CustomLti13Configurer> customLti13Configurer;

@Value("#{'${spring.prometheus.monitoringIp:127.0.0.1}'.split(',')}")
private List<String> monitoringIpAddresses;

public SecurityConfiguration(TokenProvider tokenProvider, PasswordService passwordService, CorsFilter corsFilter, ProfileService profileService) {
public SecurityConfiguration(TokenProvider tokenProvider, PasswordService passwordService, CorsFilter corsFilter, ProfileService profileService,
Optional<CustomLti13Configurer> customLti13Configurer) {
this.tokenProvider = tokenProvider;
this.passwordService = passwordService;
this.corsFilter = corsFilter;
this.profileService = profileService;
this.customLti13Configurer = customLti13Configurer;
}

/**
Expand Down Expand Up @@ -224,7 +228,8 @@ public SecurityFilterChain filterChain(HttpSecurity http, SecurityProblemSupport

// Conditionally adds configuration for LTI if it is active.
if (profileService.isLtiActive()) {
http.with(new CustomLti13Configurer(), configurer -> configurer.configure(http));
// Activates the LTI endpoints and filters.
http.with(customLti13Configurer.orElseThrow(), configurer -> configurer.configure(http));
}

// Builds and returns the SecurityFilterChain.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
package de.tum.in.www1.artemis.config.lti;

import java.time.Duration;

import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Profile;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.web.authentication.logout.LogoutFilter;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
import org.springframework.stereotype.Component;

import de.tum.in.www1.artemis.service.OnlineCourseConfigurationService;
import de.tum.in.www1.artemis.service.connectors.lti.Lti13Service;
import de.tum.in.www1.artemis.web.filter.Lti13LaunchFilter;
import uk.ac.ox.ctl.lti13.Lti13Configurer;
import uk.ac.ox.ctl.lti13.security.oauth2.client.lti.authentication.OidcLaunchFlowAuthenticationProvider;
import uk.ac.ox.ctl.lti13.security.oauth2.client.lti.web.HttpSessionOAuth2AuthorizationRequestRepository;
import uk.ac.ox.ctl.lti13.security.oauth2.client.lti.web.OAuth2LoginAuthenticationFilter;
import uk.ac.ox.ctl.lti13.security.oauth2.client.lti.web.OptimisticAuthorizationRequestRepository;
import uk.ac.ox.ctl.lti13.security.oauth2.client.lti.web.StateAuthorizationRequestRepository;

/**
* Configures and registers Security Filters to handle LTI 1.3 Resource Link Launches
*/
@Profile("lti")
@Component
public class CustomLti13Configurer extends Lti13Configurer {

/** Path for login. **/
Expand Down Expand Up @@ -52,10 +51,13 @@ public class CustomLti13Configurer extends Lti13Configurer {
/** Value for LTI 1.3 deep linking request message. */
public static final String LTI13_DEEPLINK_MESSAGE_REQUEST = "LtiDeepLinkingRequest";

public CustomLti13Configurer() {
private final DistributedStateAuthorizationRequestRepository stateRepository;

public CustomLti13Configurer(DistributedStateAuthorizationRequestRepository stateRepository) {
super.ltiPath("/" + LTI13_BASE_PATH);
super.loginInitiationPath(LOGIN_INITIATION_PATH);
super.loginPath(LOGIN_PATH);
this.stateRepository = stateRepository;
}

@Override
Expand Down Expand Up @@ -83,10 +85,16 @@ protected ClientRegistrationRepository clientRegistrationRepository(HttpSecurity
return http.getSharedObject(ApplicationContext.class).getBean(OnlineCourseConfigurationService.class);
}

/**
* Configures and returns an {@link StateBasedOptimisticAuthorizationRequestRepository} for handling OAuth2 authorization requests.
* This method sets up a multinode-distributed state repository for managing authorization requests using Hazelcast.
* This is necessary to support LTI on multinode systems where the different requests might get processed by different nodes.
*
* @return An instance of {@link StateBasedOptimisticAuthorizationRequestRepository} that combines session-based and distributed state management.
*/
@Override
protected OptimisticAuthorizationRequestRepository configureRequestRepository() {
HttpSessionOAuth2AuthorizationRequestRepository sessionRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
StateAuthorizationRequestRepository stateRepository = new StateAuthorizationRequestRepository(Duration.ofMinutes(1));
stateRepository.setLimitIpAddress(limitIpAddresses);
return new StateBasedOptimisticAuthorizationRequestRepository(sessionRepository, stateRepository);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package de.tum.in.www1.artemis.config.lti;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Profile;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

import com.hazelcast.core.HazelcastInstance;

/**
* A specialized {@link AuthorizationRequestRepository} that uses Hazelcast to store OAuth2 authorization requests.
* This allows for sharing state across multiple nodes.
* <p>
* This is based on a copy of {@link uk.ac.ox.ctl.lti13.security.oauth2.client.lti.web.StateAuthorizationRequestRepository}.
*/
@Component
@Profile("lti")
class DistributedStateAuthorizationRequestRepository implements AuthorizationRequestRepository<OAuth2AuthorizationRequest> {

private static final Logger log = LoggerFactory.getLogger(DistributedStateAuthorizationRequestRepository.class);

/**
* Executor for delayed tasks, here used to remove authorization requests after a timeout.
*/
private final Executor delayedExecutor = CompletableFuture.delayedExecutor(2L, TimeUnit.MINUTES);

private final HazelcastInstance hazelcastInstance;

private Map<String, OAuth2AuthorizationRequest> store;

/**
* Should we limit the login to a single IP address.
* This may cause problems when users are on mobile devices and subsequent requests don't use the same IP address.
*/
private boolean limitIpAddress = true;

DistributedStateAuthorizationRequestRepository(HazelcastInstance hazelcastInstance) {
this.hazelcastInstance = hazelcastInstance;
}

@PostConstruct
void init() {
this.store = hazelcastInstance.getMap("ltiStateAuthorizationRequestStore");
}

public void setLimitIpAddress(boolean limitIpAddress) {
this.limitIpAddress = limitIpAddress;
}

@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
log.info("Loading authorization request from distributed store");
Objects.requireNonNull(request, "request cannot be null");
String stateParameter = request.getParameter("state");
if (stateParameter == null) {
return null;
}
OAuth2AuthorizationRequest oAuth2AuthorizationRequest = this.store.get(stateParameter);
if (oAuth2AuthorizationRequest == null) {
return null;
}

String initialIp = oAuth2AuthorizationRequest.getAttribute("remote_ip");
if (initialIp != null) {
String requestIp = request.getRemoteAddr();
if (!initialIp.equals(requestIp)) {
log.info("IP mismatch detected. Initial IP: {}, Request IP: {}.", initialIp, requestIp);
if (this.limitIpAddress) {
return null;
}
}
}

return oAuth2AuthorizationRequest;
}

@Override
public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, HttpServletRequest request, HttpServletResponse response) {
log.info("Saving authorization request to distributed store");
Objects.requireNonNull(request, "request cannot be null");
Objects.requireNonNull(response, "response cannot be null");
if (authorizationRequest == null) {
this.removeAuthorizationRequest(request, response);
}
else {
String state = authorizationRequest.getState();
Assert.hasText(state, "authorizationRequest.state cannot be empty");
this.store.put(state, authorizationRequest);
// Remove request after timeout
delayedExecutor.execute(() -> this.store.remove(state));
}
}

@Override
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) {
log.info("Removing authorization request from distributed store");
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
if (authorizationRequest != null) {
String stateParameter = request.getParameter("state");
this.store.remove(stateParameter);
}

return authorizationRequest;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public JWTCookieService(TokenProvider tokenProvider, Environment environment) {
* Builds the cookie containing the jwt for a login
*
* @param rememberMe boolean used to determine the duration of the jwt.
* @return the login ResponseCookie contaning the JWT
* @return the login ResponseCookie containing the JWT
*/
public ResponseCookie buildLoginCookie(boolean rememberMe) {
String jwt = tokenProvider.createToken(SecurityContextHolder.getContext().getAuthentication(), rememberMe);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class Lti13TokenRetriever {

private static final Logger log = LoggerFactory.getLogger(Lti13TokenRetriever.class);

private static final int JWT_LIFETIME = 60;
public static final int JWT_LIFETIME_SECONDS = 60;

public Lti13TokenRetriever(OAuth2JWKSService keyPairService, RestTemplate restTemplate) {
this.oAuth2JWKSService = keyPairService;
Expand Down Expand Up @@ -122,7 +122,8 @@ public String createDeepLinkingJWT(String clientRegistrationId, Map<String, Obje
claimSetBuilder.claim(entry.getKey(), entry.getValue());
}

JWTClaimsSet claimsSet = claimSetBuilder.issueTime(Date.from(Instant.now())).expirationTime(Date.from(Instant.now().plusSeconds(JWT_LIFETIME))).build();
var now = Instant.now();
JWTClaimsSet claimsSet = claimSetBuilder.issueTime(Date.from(now)).expirationTime(Date.from(now.plusSeconds(JWT_LIFETIME_SECONDS))).build();

JWSHeader jwt = new JWSHeader.Builder(JWSAlgorithm.RS256).type(JOSEObjectType.JWT).keyID(jwk.getKeyID()).build();
SignedJWT signedJWT = new SignedJWT(jwt, claimsSet);
Expand All @@ -148,9 +149,10 @@ private SignedJWT createJWT(ClientRegistration clientRegistration) {
KeyPair keyPair = jwk.toRSAKey().toKeyPair();
RSASSASigner signer = new RSASSASigner(keyPair.getPrivate());

var now = Instant.now();
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().issuer(clientRegistration.getClientId()).subject(clientRegistration.getClientId())
.audience(clientRegistration.getProviderDetails().getTokenUri()).issueTime(Date.from(Instant.now())).jwtID(UUID.randomUUID().toString())
.expirationTime(Date.from(Instant.now().plusSeconds(JWT_LIFETIME))).build();
.audience(clientRegistration.getProviderDetails().getTokenUri()).issueTime(Date.from(now)).jwtID(UUID.randomUUID().toString())
.expirationTime(Date.from(now.plusSeconds(JWT_LIFETIME_SECONDS))).build();

JWSHeader jwt = new JWSHeader.Builder(JWSAlgorithm.RS256).type(JOSEObjectType.JWT).keyID(jwk.getKeyID()).build();
SignedJWT signedJWT = new SignedJWT(jwt, claimsSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ public Lti13Service(UserRepository userRepository, ExerciseRepository exerciseRe
* @param clientRegistrationId the clientRegistrationId of the source LMS
*/
public void performLaunch(OidcIdToken ltiIdToken, String clientRegistrationId) {

String targetLinkUrl = ltiIdToken.getClaim(Claims.TARGET_LINK_URI);
Optional<Exercise> targetExercise = getExerciseFromTargetLink(targetLinkUrl);
if (targetExercise.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,32 +82,40 @@ public LtiService(UserCreationService userCreationService, UserRepository userRe
* @throws InternalAuthenticationServiceException if no email is provided, or if no user can be authenticated, this exception will be thrown
*/
public void authenticateLtiUser(String email, String username, String firstName, String lastName, boolean requireExistingUser) throws InternalAuthenticationServiceException {

log.info("Authenticating LTI user with email: {}, username: {}, firstName: {}, lastName: {}, requireExistingUser: {}", email, username, firstName, lastName,
requireExistingUser);
if (!StringUtils.hasLength(email)) {
log.warn("No email address sent by launch request. Please make sure the user has an accessible email address.");
throw new InternalAuthenticationServiceException("No email address sent by launch request. Please make sure the user has an accessible email address.");
}

if (SecurityUtils.isAuthenticated()) {
log.info("User is already signed in. Checking if email matches the one provided in the launch.");
User user = userRepository.getUser();
if (email.equalsIgnoreCase(user.getEmail())) { // 1. Case: User is already signed in and email matches the one provided in the launch
log.info("User is already signed in and email matches the one provided in the launch. No further action required.");
return;
}
else {
log.info("User is already signed in but email does not match the one provided in the launch. Signing out user.");
SecurityContextHolder.getContext().setAuthentication(null); // User is signed in but email does not match, meaning launch is for a different user
}
}

// 2. Case: Lookup user with the LTI email address and make sure it's not in use
if (artemisAuthenticationProvider.getUsernameForEmail(email).isPresent() || userRepository.findOneByEmailIgnoreCase(email).isPresent()) {
log.info("User with email {} already exists. Email is already in use.", email);
throw new LtiEmailAlreadyInUseException();
}

// 3. Case: Create new user if an existing user is not required
if (!requireExistingUser) {
log.info("Creating new user from launch request: {}, username: {}, firstName: {}, lastName: {}", email, username, firstName, lastName);
SecurityContextHolder.getContext().setAuthentication(createNewUserFromLaunchRequest(email, username, firstName, lastName));
return;
}

log.info("Could not find existing user or create new LTI user.");
throw new InternalAuthenticationServiceException("Could not find existing user or create new LTI user."); // If user couldn't be authenticated, throw an error
}

Expand Down Expand Up @@ -171,15 +179,20 @@ private void addUserToExerciseGroup(User user, Course course) {
* @param response the response to add the JWT cookie to
*/
public void buildLtiResponse(UriComponentsBuilder uriComponentsBuilder, HttpServletResponse response) {
// TODO SK: why do we logout the user here if it was already activated?

User user = userRepository.getUser();

if (!user.getActivated()) {
log.info("User is not activated. Adding JWT cookie for activation.");
log.info("Add JWT cookie so the user will be logged in");
ResponseCookie responseCookie = jwtCookieService.buildLoginCookie(true);
response.addHeader(HttpHeaders.SET_COOKIE, responseCookie.toString());

uriComponentsBuilder.queryParam("initialize", "");
}
else {
log.info("User is activated. Adding JWT cookie for logout.");
prepareLogoutCookie(response);
uriComponentsBuilder.queryParam("ltiSuccessLoginRequired", user.getLogin());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
import uk.ac.ox.ctl.lti13.security.oauth2.client.lti.web.OAuth2LoginAuthenticationFilter;

/**
* Processes an LTI 1.3 Authorization Request response.
* Filter for processing an LTI 1.3 Authorization Request response.
* It listens for LTI login attempts {@see CustomLti13Configurer#LTI13_LOGIN_PATH} and processes them.
* Step 3. of OpenID Connect Third Party Initiated Login is handled solely by spring-security-lti13
* OAuth2LoginAuthenticationFilter.
*/
Expand All @@ -58,8 +59,10 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
filterChain.doFilter(request, response);
return;
}
log.info("LTI 1.3 Launch request received for url {}", this.requestMatcher.getPattern());

try {
// Login using the distributed authorization request repository
OidcAuthenticationToken authToken = finishOidcFlow(request, response);
OidcIdToken ltiIdToken = ((OidcUser) authToken.getPrincipal()).getIdToken();
String targetLink = ltiIdToken.getClaim(Claims.TARGET_LINK_URI).toString();
Expand All @@ -80,6 +83,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
catch (LtiEmailAlreadyInUseException ex) {
// LtiEmailAlreadyInUseException is thrown in case of user who has email address in use is not authenticated after targetLink is set
// We need targetLink to redirect user on the client-side after successful authentication
log.error("LTI 1.3 launch failed due to email already in use: {}", ex.getMessage(), ex);
handleLtiEmailAlreadyInUseException(response, ltiIdToken);
}

Expand All @@ -96,7 +100,7 @@ private void handleLtiEmailAlreadyInUseException(HttpServletResponse response, O
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}

private OidcAuthenticationToken finishOidcFlow(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
private OidcAuthenticationToken finishOidcFlow(HttpServletRequest request, HttpServletResponse response) {
OidcAuthenticationToken ltiAuthToken;
try {
// call spring-security-lti13 authentication filter to finish the OpenID Connect Third Party Initiated Login
Expand All @@ -117,6 +121,7 @@ private void writeResponse(String targetLinkUri, OidcIdToken ltiIdToken, String

UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(targetLinkUri);
if (SecurityUtils.isAuthenticated()) {
log.info("User is authenticated, building LTI response");
lti13Service.buildLtiResponse(uriBuilder, response);
}
LtiAuthenticationResponse jsonResponse = new LtiAuthenticationResponse(uriBuilder.build().toUriString(), ltiIdToken.getTokenValue(), clientRegistrationId);
Expand Down
Loading

0 comments on commit 8cc18bb

Please sign in to comment.