Skip to content

Commit

Permalink
feat: invalidate all sessions of a user after password changed (halo-…
Browse files Browse the repository at this point in the history
…dev#5757)

* feat: invalidate all sessions of a user after password changed

* fix: unit test case

* refactor: use spring session 3.3 to adapt

* refactor: compatible with session timeout configuration

* refactor: indexed session repository

* Reload page after changed the password

Signed-off-by: Ryan Wang <i@ryanc.cc>

* chore: update session repository

---------

Signed-off-by: Ryan Wang <i@ryanc.cc>
Co-authored-by: Ryan Wang <i@ryanc.cc>
  • Loading branch information
guqing and ruibaby authored Apr 23, 2024
1 parent cc7f2de commit 06e0b63
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 4 deletions.
1 change: 1 addition & 0 deletions api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies {
api 'org.springframework.boot:spring-boot-starter-webflux'
api 'org.springframework.boot:spring-boot-starter-validation'
api 'org.springframework.boot:spring-boot-starter-data-r2dbc'
api 'org.springframework.session:spring-session-core'

// Spring Security
api 'org.springframework.boot:spring-boot-starter-security'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers.pathMatchers;

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.session.SessionProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
Expand All @@ -23,6 +26,8 @@
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
import org.springframework.session.MapSession;
import org.springframework.session.config.annotation.web.server.EnableSpringWebSession;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.ServerResponse;
import run.halo.app.core.extension.service.RoleService;
Expand All @@ -41,13 +46,16 @@
import run.halo.app.security.authentication.pat.PatServerWebExchangeMatcher;
import run.halo.app.security.authentication.twofactor.TwoFactorAuthorizationManager;
import run.halo.app.security.authorization.RequestInfoAuthorizationManager;
import run.halo.app.security.session.InMemoryReactiveIndexedSessionRepository;
import run.halo.app.security.session.ReactiveIndexedSessionRepository;

/**
* Security configuration for WebFlux.
*
* @author johnniang
*/
@Configuration
@EnableSpringWebSession
@EnableWebFluxSecurity
@RequiredArgsConstructor
public class WebServerSecurityConfig {
Expand Down Expand Up @@ -131,6 +139,17 @@ ServerSecurityContextRepository securityContextRepository() {
return new WebSessionServerSecurityContextRepository();
}

@Bean
public ReactiveIndexedSessionRepository<MapSession> reactiveSessionRepository(
SessionProperties sessionProperties,
ServerProperties serverProperties) {
var repository = new InMemoryReactiveIndexedSessionRepository(new ConcurrentHashMap<>());
var timeout = sessionProperties.determineTimeout(
() -> serverProperties.getReactive().getSession().getTimeout());
repository.setDefaultMaxInactiveInterval(timeout);
return repository;
}

@Bean
DefaultUserDetailService userDetailsService(UserService userService,
RoleService roleService) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.util.Set;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.BooleanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
Expand All @@ -19,6 +20,7 @@
import run.halo.app.core.extension.Role;
import run.halo.app.core.extension.RoleBinding;
import run.halo.app.core.extension.User;
import run.halo.app.event.user.PasswordChangedEvent;
import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.extension.exception.ExtensionNotFoundException;
import run.halo.app.infra.SystemConfigurableEnvironmentFetcher;
Expand All @@ -38,6 +40,7 @@ public class UserServiceImpl implements UserService {

private final SystemConfigurableEnvironmentFetcher environmentFetcher;

private final ApplicationEventPublisher eventPublisher;

@Override
public Mono<User> getUser(String username) {
Expand All @@ -59,7 +62,8 @@ public Mono<User> updatePassword(String username, String newPassword) {
.flatMap(user -> {
user.getSpec().setPassword(newPassword);
return client.update(user);
});
})
.doOnNext(user -> publishPasswordChangedEvent(username));
}

@Override
Expand All @@ -76,7 +80,8 @@ public Mono<User> updateWithRawPassword(String username, String rawPassword) {
.flatMap(user -> {
user.getSpec().setPassword(passwordEncoder.encode(rawPassword));
return client.update(user);
});
})
.doOnNext(user -> publishPasswordChangedEvent(username));
}

@Override
Expand Down Expand Up @@ -179,7 +184,6 @@ public Mono<User> createUser(User user, Set<String> roleNames) {

@Override
public Mono<Boolean> confirmPassword(String username, String rawPassword) {

return getUser(username)
.filter(user -> {
if (!StringUtils.hasText(user.getSpec().getPassword())) {
Expand All @@ -193,4 +197,8 @@ public Mono<Boolean> confirmPassword(String username, String rawPassword) {
})
.hasElement();
}

void publishPasswordChangedEvent(String username) {
eventPublisher.publishEvent(new PasswordChangedEvent(this, username));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package run.halo.app.event.user;

import lombok.Getter;
import org.springframework.context.ApplicationEvent;

@Getter
public class PasswordChangedEvent extends ApplicationEvent {
private final String username;

public PasswordChangedEvent(Object source, String username) {
super(source);
this.username = username;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package run.halo.app.security.session;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.time.Duration;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.session.DelegatingIndexResolver;
import org.springframework.session.IndexResolver;
import org.springframework.session.MapSession;
import org.springframework.session.PrincipalNameIndexResolver;
import org.springframework.session.ReactiveMapSessionRepository;
import org.springframework.session.Session;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class InMemoryReactiveIndexedSessionRepository extends ReactiveMapSessionRepository
implements ReactiveIndexedSessionRepository<MapSession>, DisposableBean {

final IndexResolver<MapSession> indexResolver =
new DelegatingIndexResolver<>(new PrincipalNameIndexResolver<>(PRINCIPAL_NAME_INDEX_NAME));

private final ConcurrentMap<String, Set<IndexKey>> sessionIdIndexMap =
new ConcurrentHashMap<>();
private final ConcurrentMap<IndexKey, Set<String>> indexSessionIdMap =
new ConcurrentHashMap<>();

/**
* Prevent other requests from being parsed and acquiring the session during its deletion,
* which could result in an unintended renewal. Currently, it acts as a buffer, and having a
* slightly prolonged expiration period is sufficient.
*/
private final Cache<String, Boolean> invalidateSessionIds = CacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofMinutes(10))
.maximumSize(10_000)
.build();

public InMemoryReactiveIndexedSessionRepository(Map<String, Session> sessions) {
super(sessions);
}

@Override
public Mono<Void> save(MapSession session) {
if (invalidateSessionIds.getIfPresent(session.getId()) != null) {
return this.deleteById(session.getId());
}
return super.save(session)
.then(updateIndex(session));
}

@Override
public Mono<Void> deleteById(String id) {
return removeIndex(id)
.then(Mono.defer(() -> {
invalidateSessionIds.put(id, true);
return super.deleteById(id);
}));
}

@Override
public Mono<Map<String, MapSession>> findByIndexNameAndIndexValue(String indexName,
String indexValue) {
var indexKey = new IndexKey(indexName, indexValue);
return Flux.fromStream((() -> indexSessionIdMap.getOrDefault(indexKey, Set.of()).stream()))
.flatMap(this::findById)
.collectMap(Session::getId);
}

@Override
public Mono<Map<String, MapSession>> findByPrincipalName(String principalName) {
return this.findByIndexNameAndIndexValue(PRINCIPAL_NAME_INDEX_NAME, principalName);
}

@Override
public void destroy() {
sessionIdIndexMap.clear();
indexSessionIdMap.clear();
invalidateSessionIds.invalidateAll();
}

Mono<Void> removeIndex(String sessionId) {
return getIndexes(sessionId)
.doOnNext(indexKey -> indexSessionIdMap.computeIfPresent(indexKey,
(key, sessionIdSet) -> {
sessionIdSet.remove(sessionId);
return sessionIdSet.isEmpty() ? null : sessionIdSet;
})
)
.then(Mono.defer(() -> {
sessionIdIndexMap.remove(sessionId);
return Mono.empty();
}))
.then();
}

Mono<Void> updateIndex(MapSession session) {
return removeIndex(session.getId())
.then(Mono.defer(() -> {
indexResolver.resolveIndexesFor(session)
.forEach((name, value) -> {
IndexKey indexKey = new IndexKey(name, value);
indexSessionIdMap.computeIfAbsent(indexKey,
unusedSet -> ConcurrentHashMap.newKeySet())
.add(session.getId());
// Update sessionIdIndexMap
sessionIdIndexMap.computeIfAbsent(session.getId(),
unusedSet -> ConcurrentHashMap.newKeySet())
.add(indexKey);
});
return Mono.empty();
}))
.then();
}

Flux<IndexKey> getIndexes(String sessionId) {
return Flux.fromIterable(sessionIdIndexMap.getOrDefault(sessionId, Set.of()));
}

/**
* For testing purpose.
*/
ConcurrentMap<String, Set<IndexKey>> getSessionIdIndexMap() {
return sessionIdIndexMap;
}

/**
* For testing purpose.
*/
ConcurrentMap<IndexKey, Set<String>> getIndexSessionIdMap() {
return indexSessionIdMap;
}

record IndexKey(String attributeName, String attributeValue) {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package run.halo.app.security.session;

import org.springframework.session.ReactiveFindByIndexNameSessionRepository;
import org.springframework.session.ReactiveSessionRepository;
import org.springframework.session.Session;

public interface ReactiveIndexedSessionRepository<S extends Session>
extends ReactiveSessionRepository<S>, ReactiveFindByIndexNameSessionRepository<S> {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package run.halo.app.security.session;

import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.springframework.context.event.EventListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.session.ReactiveFindByIndexNameSessionRepository;
import org.springframework.session.ReactiveSessionRepository;
import org.springframework.session.Session;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
import run.halo.app.event.user.PasswordChangedEvent;

@Component
@RequiredArgsConstructor
public class SessionInvalidationListener {

private final ReactiveFindByIndexNameSessionRepository<? extends Session>
indexedSessionRepository;
private final ReactiveSessionRepository<? extends Session> sessionRepository;

@Async
@EventListener
public void onPasswordChanged(PasswordChangedEvent event) {
String username = event.getUsername();
// Invalidate session
invalidateUserSessions(username);
}

private void invalidateUserSessions(String username) {
indexedSessionRepository.findByPrincipalName(username)
.map(Map::keySet)
.flatMapMany(Flux::fromIterable)
.flatMap(sessionRepository::deleteById)
.then()
.block();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.crypto.password.PasswordEncoder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import run.halo.app.core.extension.Role;
import run.halo.app.core.extension.RoleBinding;
import run.halo.app.core.extension.User;
import run.halo.app.event.user.PasswordChangedEvent;
import run.halo.app.extension.Metadata;
import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.extension.exception.ExtensionNotFoundException;
Expand All @@ -57,6 +59,9 @@ class UserServiceImplTest {
@Mock
PasswordEncoder passwordEncoder;

@Mock
ApplicationEventPublisher eventPublisher;

@InjectMocks
UserServiceImpl userService;

Expand Down Expand Up @@ -99,6 +104,8 @@ void shouldUpdatePasswordIfUserFoundInExtension() {
var user = (User) extension;
return "new-fake-password".equals(user.getSpec().getPassword());
}));

verify(eventPublisher).publishEvent(any(PasswordChangedEvent.class));
}

@Test
Expand Down Expand Up @@ -240,6 +247,7 @@ void shouldUpdatePasswordWithDifferentPassword() {
var user = (User) extension;
return "encoded-new-password".equals(user.getSpec().getPassword());
}));
verify(eventPublisher).publishEvent(any(PasswordChangedEvent.class));
}

@Test
Expand All @@ -262,6 +270,7 @@ void shouldUpdatePasswordIfNoPasswordBefore() {
return "encoded-new-password".equals(user.getSpec().getPassword());
}));
verify(client).get(User.class, "fake-user");
verify(eventPublisher).publishEvent(any(PasswordChangedEvent.class));
}

@Test
Expand All @@ -281,6 +290,7 @@ void shouldDoNothingIfPasswordNotChanged() {
verify(passwordEncoder, never()).encode(any());
verify(client, never()).update(any());
verify(client).get(User.class, "fake-user");
verify(eventPublisher, times(0)).publishEvent(any(PasswordChangedEvent.class));
}

@Test
Expand Down
Loading

0 comments on commit 06e0b63

Please sign in to comment.