Skip to content

Commit

Permalink
fix: CLN plugins for v24.02
Browse files Browse the repository at this point in the history
  • Loading branch information
michael1011 committed Mar 5, 2024
1 parent e7ca60a commit eef0544
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 83 deletions.
2 changes: 1 addition & 1 deletion tools/plugins/hold/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Network(StrEnum):


PLUGIN_NAME = "hold"
VERSION = "0.0.4"
VERSION = "0.0.5"

TIMEOUT_CANCEL = 60
TIMEOUT_CANCEL_REGTEST = 5
Expand Down
1 change: 0 additions & 1 deletion tools/plugins/hold/hold.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(self, plugin: Plugin) -> None:
def init(self) -> None:
self.handler.init()
self._encoder.init()
self._route_hints.init()

def invoice(
self,
Expand Down
25 changes: 12 additions & 13 deletions tools/plugins/hold/route_hints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from bolt11.models.routehint import Route, RouteHint
from pyln.client import Plugin
from pyln.client import Plugin, RpcError


class RouteHints:
Expand All @@ -8,25 +8,24 @@ class RouteHints:
def __init__(self, plugin: Plugin) -> None:
self._plugin = plugin

def init(self) -> None:
self._id = self._plugin.rpc.getinfo()["id"]

def get_private_channels(self, node: str) -> list[RouteHint]:
chans = self._plugin.rpc.listchannels(destination=self._id)["channels"]
try:
chans = self._plugin.rpc.listpeerchannels(peer_id=node)["channels"]
except RpcError:
return []

return [
RouteHint(
[
Route(
public_key=chan["source"],
public_key=node,
short_channel_id=chan["short_channel_id"],
base_fee=chan["base_fee_millisatoshi"],
ppm_fee=chan["fee_per_millionth"],
cltv_expiry_delta=chan["delay"],
base_fee=int(chan["updates"]["remote"]["fee_base_msat"]),
ppm_fee=chan["updates"]["remote"]["fee_proportional_millionths"],
cltv_expiry_delta=chan["updates"]["remote"]["cltv_expiry_delta"],
)
]
)
for chan in filter(
lambda chan: not chan["public"] and chan["source"] == node,
chans,
)
for chan in chans
if chan["private"]
]
15 changes: 10 additions & 5 deletions tools/plugins/hold/tests/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
LndPay,
cln_con,
connect_peers,
get_channel_info,
lnd,
start_plugin,
stop_plugin,
Expand Down Expand Up @@ -223,11 +222,17 @@ def test_routing_hints(self, cl: HoldStub) -> None:

hop = hops[0]

channel_info = get_channel_info(lnd_pubkey, hop.short_channel_id)
channel_info = next(
chan
for chan in cln_con(f"listpeerchannels {lnd_pubkey}")["channels"]
if chan["short_channel_id"] == hop.short_channel_id
)

updates = channel_info["updates"]["remote"]

assert hop.cltv_expiry_delta == channel_info["delay"]
assert hop.ppm_fee == channel_info["fee_per_millionth"]
assert hop.base_fee == channel_info["base_fee_millisatoshi"]
assert hop.base_fee == updates["fee_base_msat"]
assert hop.cltv_expiry_delta == updates["cltv_expiry_delta"]
assert hop.ppm_fee == updates["fee_proportional_millionths"]
assert hop.short_channel_id == channel_info["short_channel_id"]

def test_routing_hints_none_found(self, cl: HoldStub) -> None:
Expand Down
15 changes: 10 additions & 5 deletions tools/plugins/hold/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
cln_con,
connect_peers,
format_json,
get_channel_info,
lnd,
start_plugin,
stop_plugin,
Expand Down Expand Up @@ -656,11 +655,17 @@ def test_routinghints(self, cln: CliCaller) -> None:

route = routes["routes"][0]

channel_info = get_channel_info(lnd_pubkey, route["short_channel_id"])
channel_info = next(
chan
for chan in cln_con(f"listpeerchannels {lnd_pubkey}")["channels"]
if chan["short_channel_id"] == route["short_channel_id"]
)

updates = channel_info["updates"]["remote"]

assert route["cltv_expiry_delta"] == channel_info["delay"]
assert route["ppm_fee"] == channel_info["fee_per_millionth"]
assert route["base_fee"] == channel_info["base_fee_millisatoshi"]
assert route["base_fee"] == updates["fee_base_msat"]
assert route["cltv_expiry_delta"] == updates["cltv_expiry_delta"]
assert route["ppm_fee"] == updates["fee_proportional_millionths"]
assert route["short_channel_id"] == channel_info["short_channel_id"]

def test_routinghints_none_found(self, cln: CliCaller) -> None:
Expand Down
28 changes: 16 additions & 12 deletions tools/plugins/hold/tests/test_route_hints.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
from plugins.hold.route_hints import RouteHints
from plugins.hold.tests.utils import LndNode, RpcPlugin, cln_con, get_channel_info, lnd
from plugins.hold.tests.utils import LndNode, RpcPlugin, cln_con, lnd


class TestRouteHints:
# noinspection PyTypeChecker
rh = RouteHints(RpcPlugin())

def test_init(self) -> None:
self.rh.init()

assert self.rh._plugin is not None # noqa: SLF001
assert self.rh._id == cln_con("getinfo")["id"] # noqa: SLF001

def test_get_private_channels(self) -> None:
other_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"]

Expand All @@ -24,12 +18,22 @@ def test_get_private_channels(self) -> None:
route = hint.routes[0]
assert route.public_key == other_pubkey

channel_info = get_channel_info(other_pubkey, route.short_channel_id)
channel_info = next(
chan
for chan in cln_con(f"listpeerchannels {other_pubkey}")["channels"]
if chan["private"]
)
updates = channel_info["updates"]["remote"]

assert route.cltv_expiry_delta == channel_info["delay"]
assert route.ppm_fee == channel_info["fee_per_millionth"]
assert route.base_fee == channel_info["base_fee_millisatoshi"]
assert route.base_fee == updates["fee_base_msat"]
assert route.cltv_expiry_delta == updates["cltv_expiry_delta"]
assert route.ppm_fee == updates["fee_proportional_millionths"]
assert route.short_channel_id == channel_info["short_channel_id"]

def test_get_private_channels_none_found(self) -> None:
assert self.rh.get_private_channels("not found") == []
assert (
self.rh.get_private_channels(
"0394c0450766d4029e980dd2934fbc4ca665222e3149c2a4a7b8a6251544a12033"
)
== []
)
9 changes: 2 additions & 7 deletions tools/plugins/hold/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def getinfo() -> dict:
return cln_con("getinfo")

@staticmethod
def listpeerchannels() -> dict:
return cln_con("listpeerchannels")
def listpeerchannels(peer_id: str | None = None) -> dict:
return cln_con(f"listpeerchannels {peer_id if peer_id is not None else ''}")

@staticmethod
def listchannels(**kwargs: dict[str, str]) -> dict:
Expand Down Expand Up @@ -150,11 +150,6 @@ def cln_con(*args: str) -> dict[str, Any]:
)


def get_channel_info(node: str, short_chan_id: str | int) -> dict[str, Any]:
channel_infos = cln_con("listchannels", "-k", f"short_channel_id={short_chan_id}")["channels"]
return channel_infos[0] if channel_infos[0]["source"] == node else channel_infos[1]


class TestUtils:
@pytest.mark.parametrize(
("data", "result"),
Expand Down
2 changes: 1 addition & 1 deletion tools/plugins/mpay/consts.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
PLUGIN_NAME = "mpay"
VERSION = "0.1.0"
VERSION = "0.1.1"
56 changes: 49 additions & 7 deletions tools/plugins/mpay/data/network_info.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any

from cachetools import TTLCache
from pyln.client import Plugin


@dataclass
class ChannelInfo:
fee_per_millionth: int
base_fee_millisatoshi: int

@staticmethod
def from_listchannels(channel: dict[str, Any]) -> "ChannelInfo":
return ChannelInfo(channel["fee_per_millionth"], channel["base_fee_millisatoshi"])

@staticmethod
def from_peerchannels(channel: dict[str, Any]) -> "ChannelInfo":
return ChannelInfo(channel["fee_proportional_millionths"], channel["fee_base_msat"])


class NetworkInfo:
_pl: Plugin
_alias_cache: TTLCache
Expand All @@ -23,14 +38,41 @@ def get_node_alias(self, pubkey: str) -> str:
self._alias_cache[pubkey] = alias
return alias

def get_channel_info_side(self, short_channel_id: str, side: int) -> dict[str, Any]:
channel = self.get_channel_info(short_channel_id)
return channel[0] if channel[0]["direction"] == side else channel[1]
def get_channel_info_side(self, short_channel_id: str, side: int) -> ChannelInfo:
channel = self._get_channel_info(short_channel_id, side)
if channel is not None:
return channel

# 24.02 removed private channels from listchannels
channel = self._get_peer_channel_info(short_channel_id, side)
if channel is not None:
return channel

msg = f"no channel with id {short_channel_id}"
raise ValueError(msg)

def get_channel_info(self, short_channel_id: str) -> list[dict[str, Any]]:
def _get_channel_info(self, short_channel_id: str, side: int) -> ChannelInfo | None:
channel = self._pl.rpc.listchannels(short_channel_id=short_channel_id)["channels"]
if len(channel) == 0:
msg = f"no channel with id {short_channel_id}"
raise ValueError(msg)
return None

return ChannelInfo.from_listchannels(
channel[0] if channel[0]["direction"] == side else channel[1]
)

def _get_peer_channel_info(self, short_channel_id: str, side: int) -> ChannelInfo | None:
channels = self._pl.rpc.listpeerchannels()["channels"]

channel = None

for chan in channels:
if chan["short_channel_id"] == short_channel_id:
channel = chan
break

if channel is None:
return None

return channel
return ChannelInfo.from_peerchannels(
channel["updates"]["local" if side == channel["direction"] else "remote"]
)
35 changes: 22 additions & 13 deletions tools/plugins/mpay/data/tests/test_network_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from plugins.hold.tests.utils import LndNode, RpcPlugin, cln_con, lnd
from plugins.mpay.data.network_info import NetworkInfo
from plugins.mpay.data.network_info import ChannelInfo, NetworkInfo


class TestNetworkInfo:
Expand All @@ -25,21 +25,30 @@ def test_get_node_alias_not_found(self) -> None:

def test_get_channel_info(self) -> None:
channel_id = cln_con("listchannels")["channels"][0]["short_channel_id"]
assert (
self.ni.get_channel_info(channel_id)
== cln_con(f"listchannels {channel_id}")["channels"]
assert self.ni._get_channel_info(channel_id, 0) == ChannelInfo.from_listchannels( # noqa: SLF001
cln_con(f"listchannels {channel_id}")["channels"][0]
)

def test_get_channel_info_not_found(self) -> None:
channel_id = "811759x3111x1"

with pytest.raises(ValueError, match=f"no channel with id {channel_id}"):
self.ni.get_channel_info(channel_id)

@pytest.mark.parametrize("side", [0, 1])
def test_get_channel_info_side(self, side: int) -> None:
channel_id = cln_con("listchannels")["channels"][0]["short_channel_id"]
assert (
self.ni.get_channel_info_side(channel_id, side)
== cln_con(f"listchannels {channel_id}")["channels"][side]
assert self.ni.get_channel_info_side(channel_id, side) == ChannelInfo.from_listchannels(
cln_con(f"listchannels {channel_id}")["channels"][side]
)

@pytest.mark.parametrize("side", [0, 1])
def test_get_channel_info_side_peer_channel(self, side: int) -> None:
peer_channels = cln_con("listpeerchannels")["channels"]
channel = next(chan for chan in peer_channels if chan["private"])

assert self.ni.get_channel_info_side(
channel["short_channel_id"], side
) == ChannelInfo.from_peerchannels(
channel["updates"]["local" if side == channel["direction"] else "remote"]
)

def test_get_channel_info_not_found(self) -> None:
channel_id = "811759x3111x1"

with pytest.raises(ValueError, match=f"no channel with id {channel_id}"):
self.ni.get_channel_info_side(channel_id, 1)
8 changes: 4 additions & 4 deletions tools/plugins/mpay/pay/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pyln.client import Millisatoshi

from plugins.mpay.data.network_info import NetworkInfo
from plugins.mpay.data.network_info import ChannelInfo, NetworkInfo


@dataclass
Expand All @@ -13,10 +13,10 @@ class Fees:
proportional_millionths: int

@staticmethod
def from_channel_info(info: dict[str, Any]) -> "Fees":
def from_channel_info(info: ChannelInfo) -> "Fees":
return Fees(
base_msat=info["base_fee_millisatoshi"],
proportional_millionths=info["fee_per_millionth"],
base_msat=info.base_fee_millisatoshi,
proportional_millionths=info.fee_per_millionth,
)


Expand Down
14 changes: 0 additions & 14 deletions tools/plugins/mpay/pay/tests/test_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,6 @@


class TestRoute:
@pytest.mark.parametrize(
"info",
[
{"base_fee_millisatoshi": 0, "fee_per_millionth": 1},
{"base_fee_millisatoshi": 123, "fee_per_millionth": 0},
{"base_fee_millisatoshi": 421, "fee_per_millionth": 60},
],
)
def test_fees_from_channel_info(self, info: dict[str, Any]) -> None:
fees = Fees.from_channel_info(info)

assert fees.base_msat == info["base_fee_millisatoshi"]
assert fees.proportional_millionths == info["fee_per_millionth"]

def test_route_fees_mismatch(self) -> None:
with pytest.raises(
ValueError,
Expand Down

0 comments on commit eef0544

Please sign in to comment.