diff --git a/src/LogNormal/LogNormalMath.sol b/src/LogNormal/LogNormalMath.sol index ba58d4b9..50bc82ac 100644 --- a/src/LogNormal/LogNormalMath.sol +++ b/src/LogNormal/LogNormalMath.sol @@ -168,6 +168,30 @@ function computePriceGivenY( return params.mean.mulWadUp(uint256(exp)); } +function computeDeltaLXIn( + uint256 amountIn, + uint256 rx, + uint256 ry, + uint256 L, + LogNormalParams memory params +) pure returns (uint256 deltaL) { + uint256 fees = params.swapFee.mulWadUp(amountIn); + uint256 px = computePriceGivenX(rx, L, params); + deltaL = px.mulWadUp(L).mulWadUp(fees).divWadDown(px.mulWadDown(rx) + ry); +} + +function computeDeltaLYIn( + uint256 amountIn, + uint256 rx, + uint256 ry, + uint256 L, + LogNormalParams memory params +) pure returns (uint256 deltaL) { + uint256 fees = params.swapFee.mulWadUp(amountIn); + uint256 px = computePriceGivenX(rx, L, params); + deltaL = L.mulWadUp(fees).divWadDown(px.mulWadDown(rx) + ry); +} + /// @dev This is a pure anonymous function defined at the file level, which allows /// it to be passed as an argument to another function. BisectionLib.sol takes this /// function as an argument to find the root of the trading function given the reserveYWad. diff --git a/src/LogNormal/LogNormalSolver.sol b/src/LogNormal/LogNormalSolver.sol index d26f037f..18e9cfa0 100644 --- a/src/LogNormal/LogNormalSolver.sol +++ b/src/LogNormal/LogNormalSolver.sol @@ -26,7 +26,9 @@ import { computeYGivenL, computeNextRy, computePriceGivenX, - computePriceGivenY + computePriceGivenY, + computeDeltaLXIn, + computeDeltaLYIn } from "src/LogNormal/LogNormalMath.sol"; contract LogNormalSolver { @@ -254,7 +256,13 @@ contract LogNormalSolver { ); if (swapXIn) { - state.deltaLiquidity = amountIn.mulWadUp(poolParams.swapFee); + state.deltaLiquidity = computeDeltaLXIn( + amountIn, + preReserves[0], + preReserves[1], + preTotalLiquidity, + poolParams + ); endReserves.rx = preReserves[0] + amountIn; endReserves.L = startComputedL + state.deltaLiquidity; @@ -271,8 +279,13 @@ contract LogNormalSolver { ); state.amountOut = preReserves[1] - endReserves.ry; } else { - state.deltaLiquidity = amountIn.mulWadUp(poolParams.swapFee) - .divWadUp(poolParams.mean); + state.deltaLiquidity = computeDeltaLYIn( + amountIn, + preReserves[0], + preReserves[1], + preTotalLiquidity, + poolParams + ); endReserves.ry = preReserves[1] + amountIn; endReserves.L = startComputedL + state.deltaLiquidity; diff --git a/test/LogNormal/unit/SetUp.sol b/test/LogNormal/unit/SetUp.sol index 5528842d..8c4428b1 100644 --- a/test/LogNormal/unit/SetUp.sol +++ b/test/LogNormal/unit/SetUp.sol @@ -19,11 +19,24 @@ contract LogNormalSetUp is SetUp { controller: address(this) }); + LogNormalParams defaultParamsDeep = LogNormalParams({ + mean: ONE, + width: 0.25 ether, + swapFee: TEST_SWAP_FEE, + controller: address(this) + }); + uint256 defaultReserveX = ONE; + uint256 defaultReserveXDeep = ONE * 10_000_000; + uint256 defaultPrice = ONE; bytes defaultInitialPoolData = computeInitialPoolData(defaultReserveX, defaultPrice, defaultParams); + bytes defaultInitialPoolDataDeep = computeInitialPoolData( + defaultReserveXDeep, defaultPrice, defaultParamsDeep + ); + function setUp() public override { SetUp.setUp(); logNormal = new LogNormal(address(dfmm)); @@ -52,6 +65,28 @@ contract LogNormalSetUp is SetUp { _; } + modifier deep() { + vm.warp(0); + + address[] memory tokens = new address[](2); + tokens[0] = address(tokenX); + tokens[1] = address(tokenY); + + InitParams memory defaultInitParamsDeep = InitParams({ + name: "", + symbol: "", + strategy: address(logNormal), + tokens: tokens, + data: defaultInitialPoolDataDeep, + feeCollector: address(0), + controllerFee: 0 + }); + + (POOL_ID,,) = dfmm.init(defaultInitParamsDeep); + + _; + } + modifier initRealistic() { vm.warp(0); diff --git a/test/LogNormal/unit/Swap.t.sol b/test/LogNormal/unit/Swap.t.sol index a4b54fdf..2c43abb3 100644 --- a/test/LogNormal/unit/Swap.t.sol +++ b/test/LogNormal/unit/Swap.t.sol @@ -99,4 +99,33 @@ contract LogNormalSwapTest is LogNormalSetUp { vm.expectRevert(); dfmm.swap(POOL_ID, payload); } + function test_LogNormal_swap_ChargesCorrectFeesYIn() public deep { + uint256 amountIn = 1 ether; + bool swapXForY = false; + + (bool valid,,, bytes memory payload) = + solver.simulateSwap(POOL_ID, swapXForY, amountIn); + + (,, uint256 inputAmount, uint256 outputAmount) = + dfmm.swap(POOL_ID, payload); + + console2.log(inputAmount); + console2.log(outputAmount); + + } + + function test_LogNormal_swap_ChargesCorrectFeesXIn() public deep { + uint256 amountIn = 1 ether; + bool swapXForY = true; + + (bool valid,,, bytes memory payload) = + solver.simulateSwap(POOL_ID, swapXForY, amountIn); + + (,, uint256 inputAmount, uint256 outputAmount) = + dfmm.swap(POOL_ID, payload); + + console2.log(inputAmount); + console2.log(outputAmount); + + } } diff --git a/test/utils/SetUp.sol b/test/utils/SetUp.sol index c41ffc36..094ff6be 100644 --- a/test/utils/SetUp.sol +++ b/test/utils/SetUp.sol @@ -17,8 +17,8 @@ contract SetUp is Test { function setUp() public virtual { tokenX = new MockERC20("Test Token X", "TSTX", 18); tokenY = new MockERC20("Test Token Y", "TSTY", 18); - tokenX.mint(address(this), 100_000e18); - tokenY.mint(address(this), 100_000e18); + tokenX.mint(address(this), 10_000_000_000_000e18); + tokenY.mint(address(this), 10_000_000_000_000e18); weth = new WETH(); dfmm = new DFMM(address(weth));