Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Access Control Part 2 - Implement JwtService #10

Merged
merged 11 commits into from
Nov 25, 2024
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.leungcheng.spring_simple_backend.domain;
package com.leungcheng.spring_simple_backend.auth;

import com.leungcheng.spring_simple_backend.domain.UserRepository;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.leungcheng.spring_simple_backend.auth;

import com.leungcheng.spring_simple_backend.domain.JwtService;
import com.leungcheng.spring_simple_backend.domain.UserRepository;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
Expand Down Expand Up @@ -28,11 +27,15 @@ protected void doFilterInternal(
.ifPresent(
accessToken -> {
if (SecurityContextHolder.getContext().getAuthentication() == null) {
UserInfoAuthenticationToken authToken =
new UserInfoAuthenticationToken(jwtService.parseAccessToken(accessToken));
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authToken);
SecurityContextHolder.setContext(context);
try {
UserInfoAuthenticationToken authToken =
new UserInfoAuthenticationToken(jwtService.parseAccessToken(accessToken));
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authToken);
SecurityContextHolder.setContext(context);
} catch (JwtService.InvalidTokenException e) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
}
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.leungcheng.spring_simple_backend.auth;

import com.leungcheng.spring_simple_backend.domain.User;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.SignatureException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
public class JwtService {
public record UserInfo(String userId) {}

public static class InvalidTokenException extends RuntimeException {
public InvalidTokenException(String message) {
super(message);
}
}

public JwtService(
@Value("${jwt.hs256Key}") String hs256Key,
@Value("${jwt.expiredDuration}") Duration expiredDuration) {
byte[] keyBytes = hs256Key.getBytes(StandardCharsets.UTF_8);
secretKey = new SecretKeySpec(keyBytes, "HmacSHA256");
this.expiredDuration = expiredDuration;
}

private final SecretKey secretKey;

private final Duration expiredDuration;

private Date getExpirationDate() {
return Date.from(Instant.now().plus(this.expiredDuration));
}

public String generateAccessToken(User user) {
return Jwts.builder()
.subject(user.getId())
.expiration(getExpirationDate())
.signWith(this.secretKey, Jwts.SIG.HS256)
.compact();
}

/**
* @throws InvalidTokenException if token is invalid due to expiration, invalid signature, or
* other reasons
*/
public UserInfo parseAccessToken(String token) {
try {
String userId =
Jwts.parser()
.verifyWith(this.secretKey)
.build()
.parseSignedClaims(token)
.getPayload()
.getSubject();
return new UserInfo(userId);
} catch (Exception e) {
if (e instanceof ExpiredJwtException) {
throw new InvalidTokenException("Expired token");
}
if (e instanceof SignatureException) {
throw new InvalidTokenException("Invalid signature");
}
throw new InvalidTokenException("Invalid token");
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.leungcheng.spring_simple_backend.auth;

import com.leungcheng.spring_simple_backend.domain.CustomUserDetailService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.leungcheng.spring_simple_backend.auth;

import com.leungcheng.spring_simple_backend.domain.JwtService.UserInfo;
import com.leungcheng.spring_simple_backend.auth.JwtService.UserInfo;
import org.springframework.security.authentication.AbstractAuthenticationToken;

public class UserInfoAuthenticationToken extends AbstractAuthenticationToken {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.leungcheng.spring_simple_backend.controller;

import com.leungcheng.spring_simple_backend.domain.JwtService;
import com.leungcheng.spring_simple_backend.auth.JwtService;
import com.leungcheng.spring_simple_backend.domain.User;
import com.leungcheng.spring_simple_backend.domain.UserRepository;
import jakarta.validation.Valid;
Expand Down

This file was deleted.

2 changes: 2 additions & 0 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
spring.application.name=spring-simple-backend
jwt.hs256Key=${JWT_HS256_KEY:your-default-key-1234567890abcdef}
jwt.expiredDuration=1h
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package com.leungcheng.spring_simple_backend;

import static org.hamcrest.Matchers.not;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;

import com.jayway.jsonpath.JsonPath;
import com.leungcheng.spring_simple_backend.domain.JwtService;
import com.leungcheng.spring_simple_backend.domain.ProductRepository;
import com.leungcheng.spring_simple_backend.domain.User;
import com.leungcheng.spring_simple_backend.domain.UserRepository;
Expand All @@ -17,7 +14,6 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultActions;
Expand All @@ -31,8 +27,6 @@ class SpringSimpleBackendApplicationTests {
@Autowired private ProductRepository productRepository;
@Autowired private UserRepository userRepository;

@MockBean private JwtService jwtService;

private Optional<String> accessToken = Optional.empty();

@BeforeEach
Expand All @@ -55,11 +49,6 @@ private String useNewUserAccessToken() throws Exception {
signup(userCredentials).andExpect(status().isCreated());
User user = userRepository.findByUsername(userCredentials.username).orElseThrow();

when(jwtService.generateAccessToken(argThat(argument -> argument.getId().equals(user.getId()))))
.thenReturn("dummy-token");
when(jwtService.parseAccessToken("dummy-token"))
.thenReturn(new JwtService.UserInfo(user.getId()));

MvcResult result = login(userCredentials).andExpect(status().isOk()).andReturn();
String token = JsonPath.read(result.getResponse().getContentAsString(), "$.accessToken");
setAccessToken(token);
Expand Down Expand Up @@ -153,6 +142,12 @@ public void shouldRejectNonAuthApiCallWithoutToken() throws Exception {
createProduct(CreateProductParams.sample()).andExpect(status().isForbidden());
}

@Test
public void shouldRejectIfApiCallWithInvalidToken() throws Exception {
setAccessToken("invalid-token");
createProduct(CreateProductParams.sample()).andExpect(status().isForbidden());
}

@Test
public void shouldRejectIfAuthHeaderIsNotSetCorrectly() throws Exception {
useNewUserAccessToken();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.leungcheng.spring_simple_backend.auth;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import com.leungcheng.spring_simple_backend.domain.User;
import io.jsonwebtoken.Jwts;
import java.time.Duration;
import org.junit.jupiter.api.Test;

public class JwtServiceTest {
private static User.Builder userBuilder() {
return new User.Builder().username("default-user").password("password");
}

private static class JwtServiceBuilder {
private String hs256Key = Jwts.SIG.HS256.key().build().toString();
private Duration expiredDuration = Duration.ofHours(1);

private JwtServiceBuilder expiredDuration(Duration expiredDuration) {
this.expiredDuration = expiredDuration;
return this;
}

private JwtServiceBuilder newHs256Key() {
this.hs256Key = Jwts.SIG.HS256.key().build().toString();
return this;
}

private JwtService build() {
return new JwtService(hs256Key, expiredDuration);
}
}

@Test
void shouldGenerateAndParseAccessToken() {
JwtService jwtService = new JwtServiceBuilder().build();
User user = userBuilder().build();

String token = jwtService.generateAccessToken(user);
JwtService.UserInfo userInfo = jwtService.parseAccessToken(token);

assertEquals(user.getId(), userInfo.userId());
}

@Test
void shouldThrowExceptionIfTokenExpired() {
JwtService jwtService = new JwtServiceBuilder().expiredDuration(Duration.ofSeconds(-1)).build();
User user = userBuilder().build();

String token = jwtService.generateAccessToken(user);

JwtService.InvalidTokenException exception =
assertThrows(
JwtService.InvalidTokenException.class,
() -> {
jwtService.parseAccessToken(token);
});
assertEquals("Expired token", exception.getMessage());
}

@Test
void shouldThrowExceptionIfTokenIsSignedByDifferentKey() {
JwtServiceBuilder builder = new JwtServiceBuilder();
JwtService jwtService = builder.build();
User user = userBuilder().build();

String token = jwtService.generateAccessToken(user);

JwtService.InvalidTokenException exception =
assertThrows(
JwtService.InvalidTokenException.class,
() -> {
builder.newHs256Key().build().parseAccessToken(token);
});
assertEquals("Invalid signature", exception.getMessage());
}

@Test
void shouldThrowExceptionIfTokenIsInvalid() {
JwtService jwtService = new JwtServiceBuilder().build();
User user = userBuilder().build();

JwtService.InvalidTokenException exception =
assertThrows(
JwtService.InvalidTokenException.class,
() -> {
jwtService.parseAccessToken("invalid");
});
assertEquals("Invalid token", exception.getMessage());
}
}