aboutsummaryrefslogtreecommitdiff
path: root/net/ipv4/inet_hashtables.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/inet_hashtables.c')
-rw-r--r--net/ipv4/inet_hashtables.c329
1 files changed, 219 insertions, 110 deletions
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index a5d57fa679ca..e8de5e699b3f 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -81,6 +81,41 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
return tb;
}
+struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
+ struct net *net,
+ struct inet_bind2_hashbucket *head,
+ const unsigned short port,
+ int l3mdev,
+ const struct sock *sk)
+{
+ struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
+
+ if (tb) {
+ write_pnet(&tb->ib_net, net);
+ tb->l3mdev = l3mdev;
+ tb->port = port;
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+ else
+#endif
+ tb->rcv_saddr = sk->sk_rcv_saddr;
+ INIT_HLIST_HEAD(&tb->owners);
+ hlist_add_head(&tb->node, &head->chain);
+ }
+ return tb;
+}
+
+static bool bind2_bucket_addr_match(struct inet_bind2_bucket *tb2, struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return ipv6_addr_equal(&tb2->v6_rcv_saddr,
+ &sk->sk_v6_rcv_saddr);
+#endif
+ return tb2->rcv_saddr == sk->sk_rcv_saddr;
+}
+
/*
* Caller must hold hashbucket lock for this tb with local BH disabled
*/
@@ -92,12 +127,25 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket
}
}
+/* Caller must hold the lock for the corresponding hashbucket in the bhash table
+ * with local BH disabled
+ */
+void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
+{
+ if (hlist_empty(&tb->owners)) {
+ __hlist_del(&tb->node);
+ kmem_cache_free(cachep, tb);
+ }
+}
+
void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
- const unsigned short snum)
+ struct inet_bind2_bucket *tb2, const unsigned short snum)
{
inet_sk(sk)->inet_num = snum;
sk_add_bind_node(sk, &tb->owners);
inet_csk(sk)->icsk_bind_hash = tb;
+ sk_add_bind2_node(sk, &tb2->owners);
+ inet_csk(sk)->icsk_bind2_hash = tb2;
}
/*
@@ -109,6 +157,7 @@ static void __inet_put_port(struct sock *sk)
const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
hashinfo->bhash_size);
struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
spin_lock(&head->lock);
@@ -117,6 +166,13 @@ static void __inet_put_port(struct sock *sk)
inet_csk(sk)->icsk_bind_hash = NULL;
inet_sk(sk)->inet_num = 0;
inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
+
+ if (inet_csk(sk)->icsk_bind2_hash) {
+ tb2 = inet_csk(sk)->icsk_bind2_hash;
+ __sk_del_bind2_node(sk);
+ inet_csk(sk)->icsk_bind2_hash = NULL;
+ inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
+ }
spin_unlock(&head->lock);
}
@@ -133,14 +189,19 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
unsigned short port = inet_sk(child)->inet_num;
const int bhash = inet_bhashfn(sock_net(sk), port,
- table->bhash_size);
+ table->bhash_size);
struct inet_bind_hashbucket *head = &table->bhash[bhash];
+ struct inet_bind2_hashbucket *head_bhash2;
+ bool created_inet_bind_bucket = false;
+ struct net *net = sock_net(sk);
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
int l3mdev;
spin_lock(&head->lock);
tb = inet_csk(sk)->icsk_bind_hash;
- if (unlikely(!tb)) {
+ tb2 = inet_csk(sk)->icsk_bind2_hash;
+ if (unlikely(!tb || !tb2)) {
spin_unlock(&head->lock);
return -ENOENT;
}
@@ -153,25 +214,45 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
* as that of the child socket. We have to look up or
* create a new bind bucket for the child here. */
inet_bind_bucket_for_each(tb, &head->chain) {
- if (net_eq(ib_net(tb), sock_net(sk)) &&
- tb->l3mdev == l3mdev && tb->port == port)
+ if (check_bind_bucket_match(tb, net, port, l3mdev))
break;
}
if (!tb) {
tb = inet_bind_bucket_create(table->bind_bucket_cachep,
- sock_net(sk), head, port,
- l3mdev);
+ net, head, port, l3mdev);
if (!tb) {
spin_unlock(&head->lock);
return -ENOMEM;
}
+ created_inet_bind_bucket = true;
}
inet_csk_update_fastreuse(tb, child);
+
+ goto bhash2_find;
+ } else if (!bind2_bucket_addr_match(tb2, child)) {
+ l3mdev = inet_sk_bound_l3mdev(sk);
+
+bhash2_find:
+ tb2 = inet_bind2_bucket_find(table, net, port, l3mdev, child,
+ &head_bhash2);
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
+ net, head_bhash2, port,
+ l3mdev, child);
+ if (!tb2)
+ goto error;
+ }
}
- inet_bind_hash(child, tb, port);
+ inet_bind_hash(child, tb, tb2, port);
spin_unlock(&head->lock);
return 0;
+
+error:
+ if (created_inet_bind_bucket)
+ inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
+ spin_unlock(&head->lock);
+ return -ENOMEM;
}
EXPORT_SYMBOL_GPL(__inet_inherit_port);
@@ -193,42 +274,6 @@ inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk)
return inet_lhash2_bucket(h, hash);
}
-static void inet_hash2(struct inet_hashinfo *h, struct sock *sk)
-{
- struct inet_listen_hashbucket *ilb2;
-
- if (!h->lhash2)
- return;
-
- ilb2 = inet_lhash2_bucket_sk(h, sk);
-
- spin_lock(&ilb2->lock);
- if (sk->sk_reuseport && sk->sk_family == AF_INET6)
- hlist_add_tail_rcu(&inet_csk(sk)->icsk_listen_portaddr_node,
- &ilb2->head);
- else
- hlist_add_head_rcu(&inet_csk(sk)->icsk_listen_portaddr_node,
- &ilb2->head);
- ilb2->count++;
- spin_unlock(&ilb2->lock);
-}
-
-static void inet_unhash2(struct inet_hashinfo *h, struct sock *sk)
-{
- struct inet_listen_hashbucket *ilb2;
-
- if (!h->lhash2 ||
- WARN_ON_ONCE(hlist_unhashed(&inet_csk(sk)->icsk_listen_portaddr_node)))
- return;
-
- ilb2 = inet_lhash2_bucket_sk(h, sk);
-
- spin_lock(&ilb2->lock);
- hlist_del_init_rcu(&inet_csk(sk)->icsk_listen_portaddr_node);
- ilb2->count--;
- spin_unlock(&ilb2->lock);
-}
-
static inline int compute_score(struct sock *sk, struct net *net,
const unsigned short hnum, const __be32 daddr,
const int dif, const int sdif)
@@ -282,12 +327,11 @@ static struct sock *inet_lhash2_lookup(struct net *net,
const __be32 daddr, const unsigned short hnum,
const int dif, const int sdif)
{
- struct inet_connection_sock *icsk;
struct sock *sk, *result = NULL;
+ struct hlist_nulls_node *node;
int score, hiscore = 0;
- inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) {
- sk = (struct sock *)icsk;
+ sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
score = compute_score(sk, net, hnum, daddr, dif, sdif);
if (score > hiscore) {
result = lookup_reuseport(net, sk, skb, doff,
@@ -410,13 +454,11 @@ begin:
sk_nulls_for_each_rcu(sk, node, &head->chain) {
if (sk->sk_hash != hash)
continue;
- if (likely(INET_MATCH(sk, net, acookie,
- saddr, daddr, ports, dif, sdif))) {
+ if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) {
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out;
- if (unlikely(!INET_MATCH(sk, net, acookie,
- saddr, daddr, ports,
- dif, sdif))) {
+ if (unlikely(!inet_match(net, sk, acookie,
+ ports, dif, sdif))) {
sock_gen_put(sk);
goto begin;
}
@@ -465,8 +507,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
if (sk2->sk_hash != hash)
continue;
- if (likely(INET_MATCH(sk2, net, acookie,
- saddr, daddr, ports, dif, sdif))) {
+ if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) {
if (sk2->sk_state == TCP_TIME_WAIT) {
tw = inet_twsk(sk2);
if (twsk_unique(sk, sk2, twp))
@@ -532,16 +573,14 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
if (esk->sk_hash != sk->sk_hash)
continue;
if (sk->sk_family == AF_INET) {
- if (unlikely(INET_MATCH(esk, net, acookie,
- sk->sk_daddr,
- sk->sk_rcv_saddr,
+ if (unlikely(inet_match(net, esk, acookie,
ports, dif, sdif))) {
return true;
}
}
#if IS_ENABLED(CONFIG_IPV6)
else if (sk->sk_family == AF_INET6) {
- if (unlikely(INET6_MATCH(esk, net,
+ if (unlikely(inet6_match(net, esk,
&sk->sk_v6_daddr,
&sk->sk_v6_rcv_saddr,
ports, dif, sdif))) {
@@ -633,7 +672,7 @@ static int inet_reuseport_add_sock(struct sock *sk,
int __inet_hash(struct sock *sk, struct sock *osk)
{
struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
- struct inet_listen_hashbucket *ilb;
+ struct inet_listen_hashbucket *ilb2;
int err = 0;
if (sk->sk_state != TCP_LISTEN) {
@@ -643,25 +682,23 @@ int __inet_hash(struct sock *sk, struct sock *osk)
return 0;
}
WARN_ON(!sk_unhashed(sk));
- ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
+ ilb2 = inet_lhash2_bucket_sk(hashinfo, sk);
- spin_lock(&ilb->lock);
+ spin_lock(&ilb2->lock);
if (sk->sk_reuseport) {
- err = inet_reuseport_add_sock(sk, ilb);
+ err = inet_reuseport_add_sock(sk, ilb2);
if (err)
goto unlock;
}
if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport &&
sk->sk_family == AF_INET6)
- __sk_nulls_add_node_tail_rcu(sk, &ilb->nulls_head);
+ __sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head);
else
- __sk_nulls_add_node_rcu(sk, &ilb->nulls_head);
- inet_hash2(hashinfo, sk);
- ilb->count++;
+ __sk_nulls_add_node_rcu(sk, &ilb2->nulls_head);
sock_set_flag(sk, SOCK_RCU_FREE);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
unlock:
- spin_unlock(&ilb->lock);
+ spin_unlock(&ilb2->lock);
return err;
}
@@ -678,23 +715,6 @@ int inet_hash(struct sock *sk)
}
EXPORT_SYMBOL_GPL(inet_hash);
-static void __inet_unhash(struct sock *sk, struct inet_listen_hashbucket *ilb)
-{
- if (sk_unhashed(sk))
- return;
-
- if (rcu_access_pointer(sk->sk_reuseport_cb))
- reuseport_stop_listen_sock(sk);
- if (ilb) {
- struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
-
- inet_unhash2(hashinfo, sk);
- ilb->count--;
- }
- __sk_nulls_del_node_init_rcu(sk);
- sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
-}
-
void inet_unhash(struct sock *sk)
{
struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
@@ -703,25 +723,109 @@ void inet_unhash(struct sock *sk)
return;
if (sk->sk_state == TCP_LISTEN) {
- struct inet_listen_hashbucket *ilb;
+ struct inet_listen_hashbucket *ilb2;
- ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
+ ilb2 = inet_lhash2_bucket_sk(hashinfo, sk);
/* Don't disable bottom halves while acquiring the lock to
* avoid circular locking dependency on PREEMPT_RT.
*/
- spin_lock(&ilb->lock);
- __inet_unhash(sk, ilb);
- spin_unlock(&ilb->lock);
+ spin_lock(&ilb2->lock);
+ if (sk_unhashed(sk)) {
+ spin_unlock(&ilb2->lock);
+ return;
+ }
+
+ if (rcu_access_pointer(sk->sk_reuseport_cb))
+ reuseport_stop_listen_sock(sk);
+
+ __sk_nulls_del_node_init_rcu(sk);
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
+ spin_unlock(&ilb2->lock);
} else {
spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash);
spin_lock_bh(lock);
- __inet_unhash(sk, NULL);
+ if (sk_unhashed(sk)) {
+ spin_unlock_bh(lock);
+ return;
+ }
+ __sk_nulls_del_node_init_rcu(sk);
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
spin_unlock_bh(lock);
}
}
EXPORT_SYMBOL_GPL(inet_unhash);
+static bool check_bind2_bucket_match(struct inet_bind2_bucket *tb,
+ struct net *net, unsigned short port,
+ int l3mdev, struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev &&
+ ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr);
+ else
+#endif
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
+}
+
+bool check_bind2_bucket_match_nulladdr(struct inet_bind2_bucket *tb,
+ struct net *net, const unsigned short port,
+ int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ struct in6_addr nulladdr = {};
+
+ if (sk->sk_family == AF_INET6)
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev &&
+ ipv6_addr_equal(&tb->v6_rcv_saddr, &nulladdr);
+ else
+#endif
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
+}
+
+static struct inet_bind2_hashbucket *
+inet_bhashfn_portaddr(struct inet_hashinfo *hinfo, const struct sock *sk,
+ const struct net *net, unsigned short port)
+{
+ u32 hash;
+
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ hash = ipv6_portaddr_hash(net, &sk->sk_v6_rcv_saddr, port);
+ else
+#endif
+ hash = ipv4_portaddr_hash(net, sk->sk_rcv_saddr, port);
+ return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+}
+
+/* This should only be called when the spinlock for the socket's corresponding
+ * bind_hashbucket is held
+ */
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
+ const unsigned short port, int l3mdev, struct sock *sk,
+ struct inet_bind2_hashbucket **head)
+{
+ struct inet_bind2_bucket *bhash2 = NULL;
+ struct inet_bind2_hashbucket *h;
+
+ h = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ inet_bind_bucket_for_each(bhash2, &h->chain) {
+ if (check_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
+ break;
+ }
+
+ if (head)
+ *head = h;
+
+ return bhash2;
+}
+
/* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm
* Note that we use 32bit integers (vs RFC 'short integers')
* because 2^16 is not a multiple of num_ephemeral and this
@@ -742,10 +846,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
{
struct inet_hashinfo *hinfo = death_row->hashinfo;
struct inet_timewait_sock *tw = NULL;
+ struct inet_bind2_hashbucket *head2;
struct inet_bind_hashbucket *head;
int port = inet_sk(sk)->inet_num;
struct net *net = sock_net(sk);
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
+ bool tb_created = false;
u32 remaining, offset;
int ret, i, low, high;
int l3mdev;
@@ -802,8 +909,7 @@ other_parity_scan:
* the established check is already unique enough.
*/
inet_bind_bucket_for_each(tb, &head->chain) {
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port) {
+ if (check_bind_bucket_match(tb, net, port, l3mdev)) {
if (tb->fastreuse >= 0 ||
tb->fastreuseport >= 0)
goto next_port;
@@ -821,6 +927,7 @@ other_parity_scan:
spin_unlock_bh(&head->lock);
return -ENOMEM;
}
+ tb_created = true;
tb->fastreuse = -1;
tb->fastreuseport = -1;
goto ok;
@@ -836,6 +943,17 @@ next_port:
return -EADDRNOTAVAIL;
ok:
+ /* Find the corresponding tb2 bucket since we need to
+ * add the socket to the bhash2 table as well
+ */
+ tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk, &head2);
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
+ head2, port, l3mdev, sk);
+ if (!tb2)
+ goto error;
+ }
+
/* Here we want to add a little bit of randomness to the next source
* port that will be chosen. We use a max() with a random here so that
* on low contention the randomness is maximal and on high contention
@@ -845,7 +963,7 @@ ok:
WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);
/* Head lock still held and bh's disabled */
- inet_bind_hash(sk, tb, port);
+ inet_bind_hash(sk, tb, tb2, port);
if (sk_unhashed(sk)) {
inet_sk(sk)->inet_sport = htons(port);
inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
@@ -857,6 +975,12 @@ ok:
inet_twsk_deschedule_put(tw);
local_bh_enable();
return 0;
+
+error:
+ if (tb_created)
+ inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+ spin_unlock_bh(&head->lock);
+ return -ENOMEM;
}
/*
@@ -874,29 +998,14 @@ int inet_hash_connect(struct inet_timewait_death_row *death_row,
}
EXPORT_SYMBOL_GPL(inet_hash_connect);
-void inet_hashinfo_init(struct inet_hashinfo *h)
-{
- int i;
-
- for (i = 0; i < INET_LHTABLE_SIZE; i++) {
- spin_lock_init(&h->listening_hash[i].lock);
- INIT_HLIST_NULLS_HEAD(&h->listening_hash[i].nulls_head,
- i + LISTENING_NULLS_BASE);
- h->listening_hash[i].count = 0;
- }
-
- h->lhash2 = NULL;
-}
-EXPORT_SYMBOL_GPL(inet_hashinfo_init);
-
static void init_hashinfo_lhash2(struct inet_hashinfo *h)
{
int i;
for (i = 0; i <= h->lhash2_mask; i++) {
spin_lock_init(&h->lhash2[i].lock);
- INIT_HLIST_HEAD(&h->lhash2[i].head);
- h->lhash2[i].count = 0;
+ INIT_HLIST_NULLS_HEAD(&h->lhash2[i].nulls_head,
+ i + LISTENING_NULLS_BASE);
}
}