diff --git a/src/ConstantSum/ConstantSum.sol b/src/ConstantSum/ConstantSum.sol index 08a1c27d..aaa2d170 100644 --- a/src/ConstantSum/ConstantSum.sol +++ b/src/ConstantSum/ConstantSum.sol @@ -63,6 +63,7 @@ contract ConstantSum is PairStrategy { internalParams[poolId].price = params.price; internalParams[poolId].swapFee = params.swapFee; + internalParams[poolId].controller = params.controller; // Get the trading function and check this is valid invariant = @@ -105,6 +106,7 @@ contract ConstantSum is PairStrategy { params.price = internalParams[poolId].price; params.swapFee = internalParams[poolId].swapFee; + params.controller = internalParams[poolId].controller; return abi.encode(params); } diff --git a/test/ConstantSum/unit/GetPoolParams.t.sol b/test/ConstantSum/unit/GetPoolParams.t.sol index 8381be67..67f74d09 100644 --- a/test/ConstantSum/unit/GetPoolParams.t.sol +++ b/test/ConstantSum/unit/GetPoolParams.t.sol @@ -1,6 +1,43 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.13; -import { ConstantSumSetUp } from "./SetUp.sol"; +import { ConstantSumSetUp, ConstantSumParams, InitParams } from "./SetUp.sol"; -contract ConstantSumGetPoolParamsTest is ConstantSumSetUp { } +contract ConstantSumGetPoolParamsTest is ConstantSumSetUp { + function test_ConstantSum_getPoolParams() public { + ConstantSumParams memory initialParams = ConstantSumParams({ + price: 1 ether, + swapFee: TEST_SWAP_FEE, + controller: address(this) + }); + + uint256 reserveX = 1 ether; + uint256 reserveY = 1 ether; + + bytes memory initData = + solver.getInitialPoolData(reserveX, reserveY, initialParams); + + address[] memory tokens = new address[](2); + tokens[0] = address(tokenX); + tokens[1] = address(tokenY); + + InitParams memory initParams = InitParams({ + name: "", + symbol: "", + strategy: address(constantSum), + tokens: tokens, + data: initData, + feeCollector: address(0), + controllerFee: 0 + }); + + (uint256 poolId,,) = dfmm.init(initParams); + + ConstantSumParams memory poolParams = + abi.decode(constantSum.getPoolParams(poolId), (ConstantSumParams)); + + assertEq(poolParams.price, initialParams.price); + assertEq(poolParams.swapFee, initialParams.swapFee); + assertEq(poolParams.controller, initialParams.controller); + } +}