diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index 703b34fabd..d3a230a97a 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -3,6 +3,7 @@ import com.interface21.jdbc.core.JdbcTemplate; import com.interface21.jdbc.core.RowMapper; import com.techcourse.domain.User; +import java.sql.Connection; import java.util.List; import javax.sql.DataSource; import org.slf4j.Logger; @@ -34,6 +35,12 @@ public void insert(User user) { jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail()); } + public void updateUsingExplicitConnection(User user, Connection connection) { + String sql = "UPDATE users SET account = ?, password = ?, email = ? where id = ?"; + logSql(sql); + jdbcTemplate.update(connection, sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); + } + public void update(final User user) { String sql = "UPDATE users SET account = ?, password = ?, email = ? WHERE id = ?"; logSql(sql); diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index 61e0cb5e02..bc57c3d2a8 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -3,6 +3,7 @@ import com.interface21.jdbc.core.JdbcTemplate; import com.interface21.jdbc.core.RowMapper; import com.techcourse.domain.UserHistory; +import java.sql.Connection; import javax.sql.DataSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,7 +32,7 @@ public UserHistoryDao(final DataSource dataSource) { public void log(UserHistory userHistory) { String sql = "INSERT INTO user_history (user_id, account, password, email, created_at, created_by) VALUES (?, ?, ?, ?, ?, ?)"; - log.debug("query : {}", sql); + logSql(sql); jdbcTemplate.update( sql, @@ -40,9 +41,23 @@ public void log(UserHistory userHistory) { ); } + public void logUsingExplicitConnection(UserHistory userHistory, Connection connection) { + String sql = "INSERT INTO user_history (user_id, account, password, email, created_at, created_by) VALUES (?, ?, ?, ?, ?, ?)"; + logSql(sql); + jdbcTemplate.update( + connection, sql, + userHistory.getUserId(), userHistory.getAccount(), userHistory.getPassword(), + userHistory.getEmail(), userHistory.getCreatedAt(), userHistory.getCreateBy() + ); + } + public UserHistory findById(Long id) { String sql = "SELECT id, user_id, account, password, email, created_at, created_by FROM user_history WHERE id = ?"; - log.debug("query : {}", sql); + logSql(sql); return jdbcTemplate.queryForObject(sql, ROW_MAPPER, id); } + + private void logSql(String sql) { + log.debug("query : {}", sql); + } } diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index fcf2159dc8..07c5243c39 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,32 +1,40 @@ package com.techcourse.service; +import com.interface21.jdbc.transaction.TransactionManager; import com.techcourse.dao.UserDao; import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; import com.techcourse.domain.UserHistory; +import java.sql.Connection; public class UserService { + private final TransactionManager txManager; private final UserDao userDao; private final UserHistoryDao userHistoryDao; - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { + public UserService(TransactionManager txManager, UserDao userDao, UserHistoryDao userHistoryDao) { + this.txManager = txManager; this.userDao = userDao; this.userHistoryDao = userHistoryDao; } - public User findById(final long id) { + public User findById(long id) { return userDao.findById(id); } - public void insert(final User user) { + public void insert(User user) { userDao.insert(user); } - public void changePassword(final long id, final String newPassword, final String createBy) { + public void changePassword(long id, String newPassword, String createBy) { + txManager.executeTransactionOf(conn -> changePasswordTx(conn, id, newPassword, createBy)); + } + + private void changePasswordTx(Connection connection, long id, String newPassword, String createBy) { final var user = findById(id); user.changePassword(newPassword); - userDao.update(user); - userHistoryDao.log(new UserHistory(user, createBy)); + userDao.updateUsingExplicitConnection(user, connection); + userHistoryDao.logUsingExplicitConnection(new UserHistory(user, createBy), connection); } } diff --git a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java index a3823c1619..9df9ae2c7b 100644 --- a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java +++ b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java @@ -4,6 +4,7 @@ import com.techcourse.domain.UserHistory; import com.interface21.dao.DataAccessException; import com.interface21.jdbc.core.JdbcTemplate; +import java.sql.Connection; public class MockUserHistoryDao extends UserHistoryDao { @@ -12,7 +13,7 @@ public MockUserHistoryDao(final JdbcTemplate jdbcTemplate) { } @Override - public void log(final UserHistory userHistory) { + public void logUsingExplicitConnection(UserHistory userHistory, Connection connection) { throw new DataAccessException(); } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/UserServiceTest.java index af18e2c4cb..ac7f38f316 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/UserServiceTest.java @@ -1,39 +1,42 @@ package com.techcourse.service; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.interface21.dao.DataAccessException; +import com.interface21.jdbc.core.JdbcTemplate; +import com.interface21.jdbc.transaction.TransactionManager; import com.techcourse.config.DataSourceConfig; import com.techcourse.dao.UserDao; import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; -import com.interface21.dao.DataAccessException; -import com.interface21.jdbc.core.JdbcTemplate; +import javax.sql.DataSource; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; - -@Disabled class UserServiceTest { + private TransactionManager txManager; private JdbcTemplate jdbcTemplate; private UserDao userDao; @BeforeEach void setUp() { - this.jdbcTemplate = new JdbcTemplate(DataSourceConfig.getInstance()); - this.userDao = new UserDao(jdbcTemplate); + DataSource dataSource = DataSourceConfig.getInstance(); + txManager = new TransactionManager(dataSource); + jdbcTemplate = new JdbcTemplate(dataSource); + userDao = new UserDao(jdbcTemplate); - DatabasePopulatorUtils.execute(DataSourceConfig.getInstance()); - final var user = new User("gugu", "password", "hkkang@woowahan.com"); + DatabasePopulatorUtils.execute(dataSource); + User user = new User("gugu", "password", "hkkang@woowahan.com"); userDao.insert(user); } @Test void testChangePassword() { final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final var userService = new UserService(txManager, userDao, userHistoryDao); final var newPassword = "qqqqq"; final var createBy = "gugu"; @@ -48,7 +51,7 @@ void testChangePassword() { void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 final var userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final var userService = new UserService(txManager, userDao, userHistoryDao); final var newPassword = "newPassword"; final var createBy = "gugu"; diff --git a/jdbc/src/main/java/com/interface21/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/com/interface21/jdbc/core/JdbcTemplate.java index 19be494fef..80105977ae 100644 --- a/jdbc/src/main/java/com/interface21/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/com/interface21/jdbc/core/JdbcTemplate.java @@ -21,11 +21,23 @@ public JdbcTemplate(final DataSource dataSource) { this.dataSource = dataSource; } + public int update(QueryConnectionHolder queryConnectionHolder, PreparedStatementSetter preparedStatementSetter) { + try { + PreparedStatement preparedStatement = queryConnectionHolder.getAsPreparedStatement(); + preparedStatementSetter.setValues(preparedStatement); + return preparedStatement.executeUpdate(); + } catch (SQLException e) { + throw new DataAccessException("Update 실패", e); + } + } + + public int update(Connection connection, String sql, Object... args) { + return update(new QueryConnectionHolder(connection, sql), new PreparedStatementArgumentsSetter(args)); + } + public int update(String sql, PreparedStatementSetter psSetter) { - try (Connection connection = dataSource.getConnection(); - PreparedStatement ps = connection.prepareStatement(sql)) { - psSetter.setValues(ps); - return ps.executeUpdate(); + try (Connection connection = dataSource.getConnection()) { + return update(new QueryConnectionHolder(connection, sql), psSetter); } catch (SQLException e) { throw new DataAccessException("Update 실패", e); } @@ -35,11 +47,6 @@ public int update(String sql, Object... args) { return update(sql, new PreparedStatementArgumentsSetter(args)); } - public T queryForObject(String sql, RowMapper rowMapper, Object... args) { - List query = query(sql, rowMapper, args); - return query.isEmpty() ? null : query.getLast(); - } - public List query(String sql, PreparedStatementSetter psSetter, RowMapper rowMapper) { try (Connection connection = dataSource.getConnection(); PreparedStatement ps = connection.prepareStatement(sql)) { @@ -54,6 +61,11 @@ public List query(String sql, RowMapper rowMapper, Object... args) { return query(sql, new PreparedStatementArgumentsSetter(args), rowMapper); } + public T queryForObject(String sql, RowMapper rowMapper, Object... args) { + List query = query(sql, rowMapper, args); + return query.isEmpty() ? null : query.getLast(); + } + private List retrieveRow(RowMapper rowMapper, PreparedStatement ps) throws SQLException { List results = new ArrayList<>(); try (ResultSet rs = ps.executeQuery()) { diff --git a/jdbc/src/main/java/com/interface21/jdbc/core/QueryConnectionHolder.java b/jdbc/src/main/java/com/interface21/jdbc/core/QueryConnectionHolder.java new file mode 100644 index 0000000000..bb083b394f --- /dev/null +++ b/jdbc/src/main/java/com/interface21/jdbc/core/QueryConnectionHolder.java @@ -0,0 +1,26 @@ +package com.interface21.jdbc.core; + +import com.interface21.dao.DataAccessException; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.Objects; + +public class QueryConnectionHolder { + + private final String query; + private final Connection connection; + + public QueryConnectionHolder(Connection connection, String query) { + this.connection = Objects.requireNonNull(connection); + this.query = Objects.requireNonNull(query); + } + + public PreparedStatement getAsPreparedStatement() { + try { + return connection.prepareStatement(query); + } catch (SQLException e) { + throw new DataAccessException("PreparedStatement 생성 실패", e); + } + } +} diff --git a/jdbc/src/main/java/com/interface21/jdbc/transaction/TransactionManager.java b/jdbc/src/main/java/com/interface21/jdbc/transaction/TransactionManager.java new file mode 100644 index 0000000000..e506a92d87 --- /dev/null +++ b/jdbc/src/main/java/com/interface21/jdbc/transaction/TransactionManager.java @@ -0,0 +1,53 @@ +package com.interface21.jdbc.transaction; + +import com.interface21.dao.DataAccessException; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.logging.Logger; +import javax.sql.DataSource; + +public class TransactionManager { + + private static final Logger log = Logger.getLogger(TransactionManager.class.getName()); + + private final DataSource dataSource; + + public TransactionManager(DataSource dataSource) { + this.dataSource = dataSource; + } + + public void executeTransactionOf(TransactionalFunction callback) { + Connection connection = null; + boolean shouldThrow = false; + try { + connection = dataSource.getConnection(); + connection.setAutoCommit(false); + callback.execute(connection); + connection.commit(); + } catch (Exception e) { + gracefulShutdown(connection, Connection::rollback); + shouldThrow = true; + } finally { + gracefulShutdown(connection, Connection::close); + } + if (shouldThrow) { + throw new DataAccessException("트랜잭션 실행 중 문제가 발생했습니다. 트랜잭션은 롤백됩니다."); + } + } + + private void gracefulShutdown(Connection connection, ThrowingConsumer connectionConsumer) { + try { + connectionConsumer.accept(connection); + } catch (NullPointerException e) { + log.warning("Connection을 찾을 수 없습니다."); + } catch (SQLException e) { + throw new DataAccessException(e.getMessage(), e); + } + } + + @FunctionalInterface + private interface ThrowingConsumer { + + void accept(T connection) throws SQLException; + } +} diff --git a/jdbc/src/main/java/com/interface21/jdbc/transaction/TransactionalFunction.java b/jdbc/src/main/java/com/interface21/jdbc/transaction/TransactionalFunction.java new file mode 100644 index 0000000000..0e12155824 --- /dev/null +++ b/jdbc/src/main/java/com/interface21/jdbc/transaction/TransactionalFunction.java @@ -0,0 +1,10 @@ +package com.interface21.jdbc.transaction; + +import java.sql.Connection; +import java.sql.SQLException; + +@FunctionalInterface +public interface TransactionalFunction { + + void execute(Connection connection) throws SQLException; +} diff --git a/jdbc/src/test/java/com/interface21/jdbc/transaction/TransactionManagerTest.java b/jdbc/src/test/java/com/interface21/jdbc/transaction/TransactionManagerTest.java new file mode 100644 index 0000000000..d257c055f3 --- /dev/null +++ b/jdbc/src/test/java/com/interface21/jdbc/transaction/TransactionManagerTest.java @@ -0,0 +1,57 @@ +package com.interface21.jdbc.transaction; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import com.interface21.dao.DataAccessException; +import java.sql.Connection; +import java.sql.SQLException; +import javax.sql.DataSource; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +class TransactionManagerTest { + + private DataSource dataSource; + private Connection connection; + + @BeforeEach + void setUp() throws SQLException { + dataSource = mock(DataSource.class); + connection = mock(Connection.class); + given(dataSource.getConnection()).willReturn(connection); + } + + @Test + @DisplayName("트랜잭션 실행 중 예외가 발생하면 롤백된다.") + void rollbackOnException() throws SQLException { + // given + TransactionManager txManager = new TransactionManager(dataSource); + TransactionalFunction txFunction = conn -> { + throw new SQLException(); + }; + + // when & then + assertThatThrownBy(() -> txManager.executeTransactionOf(txFunction)) + .isInstanceOf(DataAccessException.class); + verify(connection).rollback(); + verify(connection, never()).commit(); + verify(connection).close(); + } + + @Test + @DisplayName("예외 없는 트랜잭션은 정상적으로 커밋된다.") + void commitOnNoException() throws SQLException { + TransactionManager txManager = new TransactionManager(dataSource); + TransactionalFunction txFunction = conn -> {}; + txManager.executeTransactionOf(txFunction); + + verify(connection).commit(); + verify(connection, never()).rollback(); + verify(connection).close(); + } +}