commit aa8062e2ec9e04a3e9eacad7e385c10e834e7c69
tree cb1394e3dc7c92f4dca5b1d8f8bc5079d644fc52
parent fix: nf_ct_event register lock bug
author Chion Tang <sdspeedonion@gmail.com> 2018-03-14 02:08:40 +0000
committer Chion Tang <sdspeedonion@gmail.com> 2018-03-14 02:08:40 +0000

feature: hashtable for original src & original tuple
 xt_FULLCONENAT.c | 156 +++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 101 insertions(+), 55 deletions(-)
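Before this commit, every lookup other than by external port had to walk the entire mapping_table with hash_for_each. The commit adds two more hashtables, keyed by the original source (ip, port) and by the original conntrack tuple, so those lookups become O(1)-average bucket walks via hash_for_each_possible; each nat_mapping carries one hlist_node per table it participates in. A minimal sketch of that multi-index pattern as a hypothetical standalone module (all names below are illustrative, not from this commit):

/* multi_index_demo.c — hypothetical sketch, not part of this commit */
#include <linux/module.h>
#include <linux/hashtable.h>
#include <linux/slab.h>

struct item {
  u32 key_a;
  u32 key_b;
  struct hlist_node node_by_a; /* links this item into table_by_a */
  struct hlist_node node_by_b; /* links this item into table_by_b */
};

static DEFINE_HASHTABLE(table_by_a, 4);
static DEFINE_HASHTABLE(table_by_b, 4);

static int __init demo_init(void)
{
  struct item *it = kmalloc(sizeof(*it), GFP_KERNEL);
  struct item *found;

  if (it == NULL)
    return -ENOMEM;
  it->key_a = 42;
  it->key_b = 7;

  /* one allocation, two indexes: each hlist_node joins exactly one table */
  hash_add(table_by_a, &it->node_by_a, it->key_a);
  hash_add(table_by_b, &it->node_by_b, it->key_b);

  /* bucket walk by either key; re-check the field, since distinct
   * keys may share a bucket */
  hash_for_each_possible(table_by_b, found, node_by_b, 7) {
    if (found->key_b == 7)
      pr_info("found item via key_b\n");
  }
  return 0;
}

static void __exit demo_exit(void)
{
  struct item *it;
  struct hlist_node *tmp;
  int bkt;

  /* iterate one table only, but unlink from every table before freeing */
  hash_for_each_safe(table_by_a, bkt, tmp, it, node_by_a) {
    hash_del(&it->node_by_a);
    hash_del(&it->node_by_b);
    kfree(it);
  }
}

module_init(demo_init);
module_exit(demo_exit);
MODULE_LICENSE("GPL");

The same invariant drives kill_mapping and destroy_mappings in the diff below: a mapping is unlinked from all three tables before it is freed.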
diff --git a/xt_FULLCONENAT.c b/xt_FULLCONENAT.c
index 54e8c97..8e56e21 100644
--- a/xt_FULLCONENAT.c
+++ b/xt_FULLCONENAT.c
@@ -3,6 +3,8 @@
#include <linux/version.h>
#include <linux/types.h>
#include <linux/random.h>
+#include <linux/once.h>
+#include <linux/jhash.h>
#include <linux/list.h>
#include <linux/hashtable.h>
#include <linux/netdevice.h>
@@ -10,12 +12,17 @@
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/netfilter/x_tables.h>
+#include <net/netns/hash.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_ecache.h>
+#define HASH_2(x, y) (((x) + (y)) * ((x) + (y) + 1) / 2 + (y))
+
+#define HASHTABLE_BUCKET_BITS 10
+
#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0)
static inline int nf_ct_netns_get(struct net *net, u8 nfproto) { return 0; }
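HASH_2 is the Cantor pairing function, which maps a pair of non-negative integers to a single integer injectively; the product (x+y)(x+y+1) is always even, so the division is exact. It serves as the bucket key for the (int_addr, int_port) index below; for real 32-bit addresses the arithmetic wraps mod 2^32, so injectivity is lost, but that is harmless because every lookup re-checks the actual fields. A quick userspace check of the macro (hypothetical test harness, not part of the module):

/* pairing_check.c — hypothetical userspace check, not part of the module */
#include <stdio.h>
#include <stdint.h>

#define HASH_2(x, y) (((x) + (y)) * ((x) + (y) + 1) / 2 + (y))

int main(void)
{
  uint32_t x, y;

  /* enumerate small pairs: Cantor pairing gives each pair a distinct value */
  for (x = 0; x < 3; x++)
    for (y = 0; y < 3; y++)
      printf("HASH_2(%u, %u) = %u\n", x, y, (uint32_t)HASH_2(x, y));
  return 0;
}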
@@ -43,7 +50,10 @@ struct nat_mapping {
int ifindex; /* external interface index*/
struct nf_conntrack_tuple original_tuple;
- struct hlist_node node;
+ struct hlist_node node_by_ext_port;
+ struct hlist_node node_by_original_src;
+ struct hlist_node node_by_original_tuple;
+
};
struct nf_ct_net_event {
@@ -59,22 +69,34 @@ static LIST_HEAD(nf_ct_net_event_list);
static DEFINE_MUTEX(nf_ct_net_event_lock);
-static DEFINE_HASHTABLE(mapping_table, 10);
+static DEFINE_HASHTABLE(mapping_table_by_ext_port, HASHTABLE_BUCKET_BITS);
+static DEFINE_HASHTABLE(mapping_table_by_original_src, HASHTABLE_BUCKET_BITS);
+static DEFINE_HASHTABLE(mapping_table_by_original_tuple, HASHTABLE_BUCKET_BITS);
static DEFINE_SPINLOCK(fullconenat_lock);
-static struct nat_mapping* get_mapping(const uint16_t port, const int ifindex, const int create_new) {
- struct nat_mapping *p_current, *p_new;
-
- hash_for_each_possible(mapping_table, p_current, node, port) {
- if (p_current->port == port && p_current->ifindex == ifindex) {
- return p_current;
- }
- }
+static unsigned int nf_conntrack_hash_rnd __read_mostly;
+static u32 hash_conntrack_raw(const struct nf_conntrack_tuple *tuple,
+ const struct net *net) {
+ unsigned int n;
+ u32 seed;
+
+ get_random_once(&nf_conntrack_hash_rnd, sizeof(nf_conntrack_hash_rnd));
+
+ /* The direction must be ignored, so we hash everything up to the
+ * destination ports (which is a multiple of 4) and treat the last
+ * three bytes manually.
+ */
+ seed = nf_conntrack_hash_rnd ^ net_hash_mix(net);
+ n = (sizeof(tuple->src) + sizeof(tuple->dst.u3)) / sizeof(u32);
+ return jhash2((u32 *)tuple, n, seed ^
+ (((__force __u16)tuple->dst.u.all << 16) |
+ tuple->dst.protonum));
+}
- if (!create_new) {
- return NULL;
- }
+static struct nat_mapping* allocate_mapping(const struct net *net, const uint16_t port, const __be32 int_addr, const uint16_t int_port, const int ifindex, const struct nf_conntrack_tuple* original_tuple) {
+ struct nat_mapping *p_new;
+ u32 hash_tuple, hash_src;
p_new = kmalloc(sizeof(struct nat_mapping), GFP_ATOMIC);
if (p_new == NULL) {
@@ -82,34 +104,54 @@ static struct nat_mapping* get_mapping(const uint16_t port, const int ifindex, c
return NULL;
}
p_new->port = port;
- p_new->int_addr = 0;
- p_new->int_port = 0;
+ p_new->int_addr = int_addr;
+ p_new->int_port = int_port;
p_new->ifindex = ifindex;
- memset(&p_new->original_tuple, 0, sizeof(struct nf_conntrack_tuple));
+ memcpy(&p_new->original_tuple, original_tuple, sizeof(struct nf_conntrack_tuple));
- hash_add(mapping_table, &p_new->node, port);
+ hash_tuple = hash_conntrack_raw(original_tuple, net);
+ hash_src = HASH_2(int_addr, (u32)int_port);
+
+ hash_add(mapping_table_by_ext_port, &p_new->node_by_ext_port, port);
+ hash_add(mapping_table_by_original_tuple, &p_new->node_by_original_tuple, hash_tuple);
+ hash_add(mapping_table_by_original_src, &p_new->node_by_original_src, hash_src);
return p_new;
}
+static struct nat_mapping* get_mapping_by_ext_port(const uint16_t port, const int ifindex) {
+ struct nat_mapping *p_current;
+
+ hash_for_each_possible(mapping_table_by_ext_port, p_current, node_by_ext_port, port) {
+ if (p_current->port == port && p_current->ifindex == ifindex) {
+ return p_current;
+ }
+ }
+
+ return NULL;
+}
+
static struct nat_mapping* get_mapping_by_original_src(const __be32 src_ip, const uint16_t src_port, const int ifindex) {
struct nat_mapping *p_current;
- int i;
- hash_for_each(mapping_table, i, p_current, node) {
+ u32 hash_src = HASH_2(src_ip, (u32)src_port);
+
+ hash_for_each_possible(mapping_table_by_original_src, p_current, node_by_original_src, hash_src) {
if (p_current->int_addr == src_ip && p_current->int_port == src_port && p_current->ifindex == ifindex) {
return p_current;
}
}
+
return NULL;
}
-static struct nat_mapping* get_mapping_by_original_tuple(const struct nf_conntrack_tuple* tuple) {
+static struct nat_mapping* get_mapping_by_original_tuple(const struct net *net, const struct nf_conntrack_tuple* tuple) {
struct nat_mapping *p_current;
- int i;
- if (tuple == NULL) {
+ u32 hash_tuple;
+
+ if (net == NULL || tuple == NULL) {
return NULL;
}
+ hash_tuple = hash_conntrack_raw(tuple, net);
- hash_for_each(mapping_table, i, p_current, node) {
+ hash_for_each_possible(mapping_table_by_original_tuple, p_current, node_by_original_tuple, hash_tuple) {
if (nf_ct_tuple_equal(&p_current->original_tuple, tuple)) {
return p_current;
}
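hash_conntrack_raw above mirrors the helper of the same name in the kernel's nf_conntrack_core.c: jhash2 covers the source part and destination address of the tuple (a whole number of u32 words), while the destination port and protocol number are folded into the initval. The seed comes from get_random_once, which fills nf_conntrack_hash_rnd on first use so bucket keys differ between boots. A minimal sketch of that lazy-seed idiom (the demo_ names are hypothetical):

#include <linux/once.h>
#include <linux/jhash.h>

static u32 demo_seed __read_mostly;

static u32 demo_key(const u32 *words, unsigned int n)
{
  /* thread-safe one-time init: fills demo_seed with random bytes on
   * the first call, is a no-op afterwards */
  get_random_once(&demo_seed, sizeof(demo_seed));
  return jhash2(words, n, demo_seed);
}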
@@ -117,6 +159,16 @@ static struct nat_mapping* get_mapping_by_original_tuple(const struct nf_conntra
return NULL;
}
+static void kill_mapping(struct nat_mapping *mapping) {
+ if (mapping == NULL) {
+ return;
+ }
+ hash_del(&mapping->node_by_ext_port);
+ hash_del(&mapping->node_by_original_src);
+ hash_del(&mapping->node_by_original_tuple);
+ kfree(mapping);
+}
+
static void destroy_mappings(void) {
struct nat_mapping *p_current;
struct hlist_node *tmp;
@@ -124,9 +176,8 @@ static void destroy_mappings(void) {
spin_lock(&fullconenat_lock);
- hash_for_each_safe(mapping_table, i, tmp, p_current, node) {
- hash_del(&p_current->node);
- kfree(p_current);
+ hash_for_each_safe(mapping_table_by_ext_port, i, tmp, p_current, node_by_ext_port) {
+ kill_mapping(p_current);
}
spin_unlock(&fullconenat_lock);
@@ -135,7 +186,7 @@ static void destroy_mappings(void) {
/* check if a mapping is valid.
* possibly delete and free an invalid mapping.
* the mapping should not be used anymore after check_mapping() returns 0. */
-static int check_mapping(struct nat_mapping* mapping, struct net *net, struct nf_conntrack_zone *zone)
+static int check_mapping(struct net *net, struct nf_conntrack_zone *zone, struct nat_mapping* mapping)
{
struct nf_conntrack_tuple_hash *original_tuple_hash;
@@ -161,14 +212,14 @@ del_mapping:
/* for dying/unconfirmed conntracks, an IPCT_DESTROY event may NOT be fired.
* so we manually kill one of those conntracks once we acquire one. */
pr_debug("xt_FULLCONENAT: check_mapping(): kill dying/unconfirmed mapping at ext port %d\n", mapping->port);
- hash_del(&mapping->node);
- kfree(mapping);
+ kill_mapping(mapping);
return 0;
}
/* conntrack destroy event callback function */
static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
struct nf_conn *ct;
+ struct net *net;
struct nf_conntrack_tuple *ct_tuple_origin;
struct nat_mapping *mapping;
uint8_t protonum;
@@ -179,6 +230,8 @@ static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
return 0;
}
+ net = nf_ct_net(ct);
+
/* take the original tuple and find the corresponding mapping */
ct_tuple_origin = &(ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
@@ -189,7 +242,7 @@ static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
spin_lock(&fullconenat_lock);
- mapping = get_mapping_by_original_tuple(ct_tuple_origin);
+ mapping = get_mapping_by_original_tuple(net, ct_tuple_origin);
if (mapping == NULL) {
spin_unlock(&fullconenat_lock);
return 0;
@@ -197,8 +250,7 @@ static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
/* then kill it */
pr_debug("xt_FULLCONENAT: ct_event_cb(): kill expired mapping at ext port %d\n", mapping->port);
- hash_del(&mapping->node);
- kfree(mapping);
+ kill_mapping(mapping);
spin_unlock(&fullconenat_lock);
@@ -221,7 +273,7 @@ static __be32 get_device_ip(const struct net_device* dev) {
}
}
-static uint16_t find_appropriate_port(const uint16_t original_port, const struct nf_nat_ipv4_range *range, const int ifindex, struct net *net, struct nf_conntrack_zone *zone) {
+static uint16_t find_appropriate_port(struct net *net, struct nf_conntrack_zone *zone, const uint16_t original_port, const int ifindex, const struct nf_nat_ipv4_range *range) {
uint16_t min, start, selected, range_size, i;
struct nat_mapping* mapping = NULL;
@@ -245,8 +297,8 @@ static uint16_t find_appropriate_port(const uint16_t original_port, const struct
if ((original_port >= min && original_port <= min + range_size - 1)
|| !(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) {
/* 1. try to preserve the port if it's available */
- mapping = get_mapping(original_port, ifindex, 0);
- if (mapping == NULL || !(check_mapping(mapping, net, zone))) {
+ mapping = get_mapping_by_ext_port(original_port, ifindex);
+ if (mapping == NULL || !(check_mapping(net, zone, mapping))) {
return original_port;
}
}
@@ -258,14 +310,18 @@ static uint16_t find_appropriate_port(const uint16_t original_port, const struct
for (i = 0; i < range_size; i++) {
/* 2. try to find an available port */
selected = min + ((start + i) % range_size);
- mapping = get_mapping(selected, ifindex, 0);
- if (mapping == NULL || !(check_mapping(mapping, net, zone))) {
+ mapping = get_mapping_by_ext_port(selected, ifindex);
+ if (mapping == NULL || !(check_mapping(net, zone, mapping))) {
return selected;
}
}
- /* 3. at least we tried. rewrite a previous mapping. */
- return min + start;
+ /* 3. at least we tried. override a previous mapping. */
+ selected = min + start;
+ mapping = get_mapping_by_ext_port(selected, ifindex);
+ kill_mapping(mapping);
+
+ return selected;
}
static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_param *par)
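find_appropriate_port resolves an external port in three steps: keep the caller's source port when it lies in the allowed range and its mapping is absent or stale; otherwise probe the whole range starting from a random offset; as a last resort, evict whichever mapping holds a randomly chosen port and take it (the new kill_mapping call, replacing the old silent rewrite). A self-contained userspace sketch of the same selection logic (the occupancy array and all names here are hypothetical stand-ins for the mapping table):

/* port_select_demo.c — hypothetical sketch of the 3-step selection */
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

static int taken[65536]; /* stands in for get_mapping_by_ext_port() */

static uint16_t select_port(uint16_t original_port, uint16_t min,
                            uint16_t range_size)
{
  uint16_t start = rand() % range_size;
  uint16_t selected;
  uint16_t i;

  /* 1. preserve the original source port if it is in range and free */
  if (original_port >= min && original_port <= min + range_size - 1
      && !taken[original_port])
    return original_port;

  /* 2. probe every port in the range, starting at a random offset */
  for (i = 0; i < range_size; i++) {
    selected = min + ((start + i) % range_size);
    if (!taken[selected])
      return selected;
  }

  /* 3. range exhausted: evict the holder of a random port and take it */
  selected = min + start;
  taken[selected] = 0; /* stands in for kill_mapping() */
  return selected;
}

int main(void)
{
  taken[1024] = 1;
  printf("selected %u\n", select_port(1024, 1024, 64)); /* 1024 is taken */
  return 0;
}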
@@ -322,12 +378,12 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
spin_lock(&fullconenat_lock);
/* find an active mapping based on the inbound port */
- mapping = get_mapping(port, ifindex, 0);
+ mapping = get_mapping_by_ext_port(port, ifindex);
if (mapping == NULL) {
spin_unlock(&fullconenat_lock);
return ret;
}
- if (check_mapping(mapping, net, zone)) {
+ if (check_mapping(net, zone, mapping)) {
newrange.flags = NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED;
newrange.min_addr.ip = mapping->int_addr;
newrange.max_addr.ip = mapping->int_addr;
@@ -354,7 +410,7 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
original_port = be16_to_cpu((ct_tuple_origin->src).u.udp.port);
src_mapping = get_mapping_by_original_src(ip, original_port, ifindex);
- if (src_mapping != NULL && check_mapping(src_mapping, net, zone)) {
+ if (src_mapping != NULL && check_mapping(net, zone, src_mapping)) {
/* outbound nat: if a previously established mapping is active,
* we will reuse that mapping. */
@@ -364,7 +420,7 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
newrange.max_proto = newrange.min_proto;
} else {
- want_port = find_appropriate_port(original_port, range, ifindex, net, zone);
+ want_port = find_appropriate_port(net, zone, original_port, ifindex, range);
newrange.flags = NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED;
newrange.min_proto.udp.port = cpu_to_be16(want_port);
@@ -390,19 +446,9 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
port = be16_to_cpu((ct_tuple->dst).u.udp.port);
/* save the mapping information into our mapping table */
- mapping = get_mapping(port, ifindex, 1);
- if (mapping == NULL) {
- spin_unlock(&fullconenat_lock);
- return ret;
- }
- mapping->int_addr = ip;
- mapping->int_port = original_port;
- mapping->ifindex = ifindex;
- /* save the original source tuple */
- memcpy(&mapping->original_tuple, ct_tuple_origin, sizeof(struct nf_conntrack_tuple));
-
- spin_unlock(&fullconenat_lock);
+ mapping = allocate_mapping(net, port, ip, original_port, ifindex, ct_tuple_origin);
+ spin_unlock(&fullconenat_lock);
return ret;
}
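The diff shows only the target function and the mapping machinery; for orientation, an iptables extension target of this kind is typically registered roughly as below. This is a hedged sketch: the hook mask, table name, targetsize, and the presence or absence of a checkentry hook are assumptions for illustration, not copied from this module.

#include <linux/module.h>
#include <linux/netfilter.h>
#include <linux/netfilter/x_tables.h>

/* hypothetical registration sketch; fullconenat_tg is the target
 * function from the diff above */
static struct xt_target fullconenat_tg_reg __read_mostly = {
  .name       = "FULLCONENAT",
  .family     = NFPROTO_IPV4,
  .revision   = 0,
  .target     = fullconenat_tg,
  .targetsize = sizeof(struct nf_nat_ipv4_multi_range_compat),
  .table      = "nat",
  .hooks      = (1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_POST_ROUTING),
  .me         = THIS_MODULE,
};

static int __init fullconenat_tg_init(void)
{
  return xt_register_target(&fullconenat_tg_reg);
}

static void __exit fullconenat_tg_exit(void)
{
  xt_unregister_target(&fullconenat_tg_reg);
  destroy_mappings();
}

module_init(fullconenat_tg_init);
module_exit(fullconenat_tg_exit);
MODULE_LICENSE("GPL");

Registering in both PRE_ROUTING and POST_ROUTING matches the two paths visible in fullconenat_tg: the inbound branch translates destination by looking up the external port, and the outbound branch records or reuses a mapping for the original source.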