diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 9a1415fe3fa78b..b6554e32bb12ae 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -104,9 +104,9 @@ static struct workqueue_struct *l2tp_wq; /* per-net private data for this module */ static unsigned int l2tp_net_id; struct l2tp_net { - struct list_head l2tp_tunnel_list; - /* Lock for write access to l2tp_tunnel_list */ - spinlock_t l2tp_tunnel_list_lock; + /* Lock for write access to l2tp_tunnel_idr */ + spinlock_t l2tp_tunnel_idr_lock; + struct idr l2tp_tunnel_idr; struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2]; /* Lock for write access to l2tp_session_hlist */ spinlock_t l2tp_session_hlist_lock; @@ -208,13 +208,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) struct l2tp_tunnel *tunnel; rcu_read_lock_bh(); - list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - if (tunnel->tunnel_id == tunnel_id && - refcount_inc_not_zero(&tunnel->ref_count)) { - rcu_read_unlock_bh(); - - return tunnel; - } + tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id); + if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; } rcu_read_unlock_bh(); @@ -224,13 +221,14 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_get); struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth) { - const struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_net *pn = l2tp_pernet(net); + unsigned long tunnel_id, tmp; struct l2tp_tunnel *tunnel; int count = 0; rcu_read_lock_bh(); - list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - if (++count > nth && + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel && ++count > nth && refcount_inc_not_zero(&tunnel->ref_count)) { rcu_read_unlock_bh(); return tunnel; @@ -1043,7 +1041,7 @@ static int l2tp_xmit_core(struct l2tp_session *session, struct sk_buff *skb, uns IPCB(skb)->flags &= ~(IPSKB_XFRM_TUNNEL_SIZE | IPSKB_XFRM_TRANSFORMED | IPSKB_REROUTED); nf_reset_ct(skb); - bh_lock_sock(sk); + bh_lock_sock_nested(sk); if (sock_owned_by_user(sk)) { kfree_skb(skb); ret = NET_XMIT_DROP; @@ -1227,6 +1225,15 @@ static void l2tp_udp_encap_destroy(struct sock *sk) l2tp_tunnel_delete(tunnel); } +static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel) +{ + struct l2tp_net *pn = l2tp_pernet(net); + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); +} + /* Workqueue tunnel deletion function */ static void l2tp_tunnel_del_work(struct work_struct *work) { @@ -1234,7 +1241,6 @@ static void l2tp_tunnel_del_work(struct work_struct *work) del_work); struct sock *sk = tunnel->sock; struct socket *sock = sk->sk_socket; - struct l2tp_net *pn; l2tp_tunnel_closeall(tunnel); @@ -1248,12 +1254,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work) } } - /* Remove the tunnel struct from the tunnel list */ - pn = l2tp_pernet(tunnel->l2tp_net); - spin_lock_bh(&pn->l2tp_tunnel_list_lock); - list_del_rcu(&tunnel->list); - spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - + l2tp_tunnel_remove(tunnel->l2tp_net, tunnel); /* drop initial ref */ l2tp_tunnel_dec_refcount(tunnel); @@ -1384,8 +1385,6 @@ static int l2tp_tunnel_sock_create(struct net *net, return err; } -static struct lock_class_key l2tp_socket_class; - int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id, struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp) { @@ -1455,12 +1454,19 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net, int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, struct l2tp_tunnel_cfg *cfg) { - struct l2tp_tunnel *tunnel_walk; - struct l2tp_net *pn; + struct l2tp_net *pn = l2tp_pernet(net); + u32 tunnel_id = tunnel->tunnel_id; struct socket *sock; struct sock *sk; int ret; + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id, + GFP_ATOMIC); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); + if (ret) + return ret == -ENOSPC ? -EEXIST : ret; + if (tunnel->fd < 0) { ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id, tunnel->peer_tunnel_id, cfg, @@ -1474,31 +1480,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, } sk = sock->sk; + lock_sock(sk); write_lock_bh(&sk->sk_callback_lock); ret = l2tp_validate_socket(sk, net, tunnel->encap); - if (ret < 0) + if (ret < 0) { + release_sock(sk); goto err_inval_sock; + } rcu_assign_sk_user_data(sk, tunnel); write_unlock_bh(&sk->sk_callback_lock); - tunnel->l2tp_net = net; - pn = l2tp_pernet(net); - - sock_hold(sk); - tunnel->sock = sk; - - spin_lock_bh(&pn->l2tp_tunnel_list_lock); - list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) { - if (tunnel_walk->tunnel_id == tunnel->tunnel_id) { - spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - sock_put(sk); - ret = -EEXIST; - goto err_sock; - } - } - list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list); - spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - if (tunnel->encap == L2TP_ENCAPTYPE_UDP) { struct udp_tunnel_sock_cfg udp_cfg = { .sk_user_data = tunnel, @@ -1512,9 +1503,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, tunnel->old_sk_destruct = sk->sk_destruct; sk->sk_destruct = &l2tp_tunnel_destruct; - lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class, - "l2tp_sock"); sk->sk_allocation = GFP_ATOMIC; + release_sock(sk); + + sock_hold(sk); + tunnel->sock = sk; + tunnel->l2tp_net = net; + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); trace_register_tunnel(tunnel); @@ -1523,9 +1521,6 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, return 0; -err_sock: - write_lock_bh(&sk->sk_callback_lock); - rcu_assign_sk_user_data(sk, NULL); err_inval_sock: write_unlock_bh(&sk->sk_callback_lock); @@ -1534,6 +1529,7 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, else sockfd_put(sock); err: + l2tp_tunnel_remove(net, tunnel); return ret; } EXPORT_SYMBOL_GPL(l2tp_tunnel_register); @@ -1647,8 +1643,8 @@ static __net_init int l2tp_init_net(struct net *net) struct l2tp_net *pn = net_generic(net, l2tp_net_id); int hash; - INIT_LIST_HEAD(&pn->l2tp_tunnel_list); - spin_lock_init(&pn->l2tp_tunnel_list_lock); + idr_init(&pn->l2tp_tunnel_idr); + spin_lock_init(&pn->l2tp_tunnel_idr_lock); for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]); @@ -1662,11 +1658,13 @@ static __net_exit void l2tp_exit_net(struct net *net) { struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_tunnel *tunnel = NULL; + unsigned long tunnel_id, tmp; int hash; rcu_read_lock_bh(); - list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - l2tp_tunnel_delete(tunnel); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel) + l2tp_tunnel_delete(tunnel); } rcu_read_unlock_bh(); @@ -1676,6 +1674,7 @@ static __net_exit void l2tp_exit_net(struct net *net) for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash])); + idr_destroy(&pn->l2tp_tunnel_idr); } static struct pernet_operations l2tp_net_ops = {