diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h index 9e34c877a77093..7ca75cbbf75e20 100644 --- a/include/net/inet6_hashtables.h +++ b/include/net/inet6_hashtables.h @@ -71,6 +71,8 @@ extern struct sock *__inet6_lookup_established(struct net *net, extern struct sock *inet6_lookup_listener(struct net *net, struct inet_hashinfo *hashinfo, + const struct in6_addr *saddr, + const __be16 sport, const struct in6_addr *daddr, const unsigned short hnum, const int dif); @@ -88,7 +90,8 @@ static inline struct sock *__inet6_lookup(struct net *net, if (sk) return sk; - return inet6_lookup_listener(net, hashinfo, daddr, hnum, dif); + return inet6_lookup_listener(net, hashinfo, saddr, sport, + daddr, hnum, dif); } static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo, diff --git a/include/net/netfilter/nf_tproxy_core.h b/include/net/netfilter/nf_tproxy_core.h index 193796445642a0..36d9379d4c4b49 100644 --- a/include/net/netfilter/nf_tproxy_core.h +++ b/include/net/netfilter/nf_tproxy_core.h @@ -152,6 +152,7 @@ nf_tproxy_get_sock_v6(struct net *net, const u8 protocol, break; case NFT_LOOKUP_LISTENER: sk = inet6_lookup_listener(net, &tcp_hashinfo, + saddr, sport, daddr, ntohs(dport), in->ifindex); diff --git a/net/ipv6/inet6_connection_sock.c b/net/ipv6/inet6_connection_sock.c index 30647857a375bc..e4297a393678c5 100644 --- a/net/ipv6/inet6_connection_sock.c +++ b/net/ipv6/inet6_connection_sock.c @@ -32,6 +32,9 @@ int inet6_csk_bind_conflict(const struct sock *sk, { const struct sock *sk2; const struct hlist_node *node; + int reuse = sk->sk_reuse; + int reuseport = sk->sk_reuseport; + int uid = sock_i_uid((struct sock *)sk); /* We must walk the whole port owner list in this case. -DaveM */ /* @@ -42,11 +45,17 @@ int inet6_csk_bind_conflict(const struct sock *sk, if (sk != sk2 && (!sk->sk_bound_dev_if || !sk2->sk_bound_dev_if || - sk->sk_bound_dev_if == sk2->sk_bound_dev_if) && - (!sk->sk_reuse || !sk2->sk_reuse || - sk2->sk_state == TCP_LISTEN) && - ipv6_rcv_saddr_equal(sk, sk2)) - break; + sk->sk_bound_dev_if == sk2->sk_bound_dev_if)) { + if ((!reuse || !sk2->sk_reuse || + sk2->sk_state == TCP_LISTEN) && + (!reuseport || !sk2->sk_reuseport || + (sk2->sk_state != TCP_TIME_WAIT && + !uid_eq(uid, + sock_i_uid((struct sock *)sk2))))) { + if (ipv6_rcv_saddr_equal(sk, sk2)) + break; + } + } } return node != NULL; diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c index dea17fd28e5037..32b4a1675d826d 100644 --- a/net/ipv6/inet6_hashtables.c +++ b/net/ipv6/inet6_hashtables.c @@ -158,25 +158,38 @@ static inline int compute_score(struct sock *sk, struct net *net, } struct sock *inet6_lookup_listener(struct net *net, - struct inet_hashinfo *hashinfo, const struct in6_addr *daddr, + struct inet_hashinfo *hashinfo, const struct in6_addr *saddr, + const __be16 sport, const struct in6_addr *daddr, const unsigned short hnum, const int dif) { struct sock *sk; const struct hlist_nulls_node *node; struct sock *result; - int score, hiscore; + int score, hiscore, matches = 0, reuseport = 0; + u32 phash = 0; unsigned int hash = inet_lhashfn(net, hnum); struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash]; rcu_read_lock(); begin: result = NULL; - hiscore = -1; + hiscore = 0; sk_nulls_for_each(sk, node, &ilb->head) { score = compute_score(sk, net, hnum, daddr, dif); if (score > hiscore) { hiscore = score; result = sk; + reuseport = sk->sk_reuseport; + if (reuseport) { + phash = inet6_ehashfn(net, daddr, hnum, + saddr, sport); + matches = 1; + } + } else if (score == hiscore && reuseport) { + matches++; + if (((u64)phash * matches) >> 32 == 0) + result = sk; + phash = next_pseudo_random32(phash); } } /* diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index 3701c3c6e2eb72..06087e58738a55 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -834,7 +834,8 @@ static void tcp_v6_send_reset(struct sock *sk, struct sk_buff *skb) * no RST generated if md5 hash doesn't match. */ sk1 = inet6_lookup_listener(dev_net(skb_dst(skb)->dev), - &tcp_hashinfo, &ipv6h->daddr, + &tcp_hashinfo, &ipv6h->saddr, + th->source, &ipv6h->daddr, ntohs(th->source), inet6_iif(skb)); if (!sk1) return; @@ -1598,6 +1599,7 @@ static int tcp_v6_rcv(struct sk_buff *skb) struct sock *sk2; sk2 = inet6_lookup_listener(dev_net(skb->dev), &tcp_hashinfo, + &ipv6_hdr(skb)->saddr, th->source, &ipv6_hdr(skb)->daddr, ntohs(th->dest), inet6_iif(skb)); if (sk2 != NULL) {